Add minimal vLLM 0.16.1 build repo for BI-V150
This commit is contained in:
11
Dockerfile
Normal file
11
Dockerfile
Normal file
@@ -0,0 +1,11 @@
|
||||
FROM registry.iluvatar.com.cn:10443/customer/sz/vllm0.11.2-4.4.0-x86:v8

# Keep the runtime stack from the known-good v8 image, but replace the
# installed Python package with the repository's patched 0.16.1rc0 sources.

# Install location of the vllm package inside the base image, and the name of
# the dist-info directory shipped in this repository. Both are declared once
# so a Python or package version bump is a single-line change (or a
# `--build-arg` override); the defaults match the paths used previously.
ARG SITE_PACKAGES=/usr/local/lib/python3.12/dist-packages
ARG VLLM_DIST_INFO=vllm-0.16.1rc0+corex.4.4.0.dist-info

# No step below depends on the working directory (RUN uses absolute paths and
# COPY sources are resolved against the build context); this only sets the
# default working directory of the resulting image.
WORKDIR /tmp

# Remove the vllm package and every vllm-*.dist-info already present in the
# base image (presumably the 0.11.2 install its tag refers to -- confirm), so
# no stale modules or duplicate package metadata remain beside the copied tree.
# The glob is left outside the quotes so the shell expands it.
RUN rm -rf "${SITE_PACKAGES}/vllm" \
    "${SITE_PACKAGES}"/vllm-*.dist-info

# Drop in the patched sources and the matching package metadata.
COPY vllm ${SITE_PACKAGES}/vllm
COPY ${VLLM_DIST_INFO} ${SITE_PACKAGES}/${VLLM_DIST_INFO}
|
||||
60
README.md
Normal file
60
README.md
Normal file
@@ -0,0 +1,60 @@
|
||||
# bi_150-vllm
|
||||
|
||||
基于 `registry.iluvatar.com.cn:10443/customer/sz/vllm0.11.2-4.4.0-x86:v8` 的
|
||||
`vLLM 0.16.1rc0` 构建仓库,用于在 BI-V150 虚拟机环境中生成可直接运行的镜像。
|
||||
|
||||
## 改动说明
|
||||
|
||||
本仓库只保留构建镜像所需的最小内容:
|
||||
|
||||
- `vllm/`
|
||||
当前运行代码
|
||||
- `vllm-0.16.1rc0+corex.4.4.0.dist-info/`
|
||||
对应的包元数据
|
||||
- `Dockerfile`
|
||||
构建最终镜像
|
||||
|
||||
与基础镜像相比,本仓库保留的关键代码改动如下:
|
||||
|
||||
- 在 `vllm/platforms/__init__.py` 中修复 CUDA 平台识别逻辑
|
||||
- 当 NVML 不可用且出现 `NVML Shared Library Not Found` 一类错误时
|
||||
不再直接判定为非 CUDA 平台
|
||||
- 改为回退到 `torch.cuda.is_available()` 和
|
||||
`torch.cuda.device_count()` 继续判断 CUDA 是否可用
|
||||
|
||||
这个修复用于解决如下启动失败:
|
||||
|
||||
```text
|
||||
RuntimeError: Failed to infer device type
|
||||
```
|
||||
|
||||
## 构建镜像
|
||||
|
||||
在仓库根目录执行:
|
||||
|
||||
```bash
|
||||
docker build -t bi_150_vllm:0.16.1 .
|
||||
```
|
||||
|
||||
## 启动镜像
|
||||
|
||||
```bash
|
||||
docker run -dit \
|
||||
--name iluvatar_sut_submit_325811 \
|
||||
-p 38047:8000 \
|
||||
--privileged \
|
||||
-v /lib/modules:/lib/modules \
|
||||
-v /dev:/dev \
|
||||
-v /usr/src:/usr/src \
|
||||
-v /mnt/gpfs/leaderboard/modelHubXC/Amu/t1-1.5B:/model \
|
||||
-e CUDA_VISIBLE_DEVICES=0 \
|
||||
--entrypoint vllm \
|
||||
bi_150_vllm:0.16.1 \
|
||||
serve /model \
|
||||
--port 8000 \
|
||||
--served-model-name llm \
|
||||
--max-model-len 2048 \
|
||||
--enforce-eager \
|
||||
--trust-remote-code \
|
||||
-tp 1
|
||||
```
|
||||
1
vllm-0.16.1rc0+corex.4.4.0.dist-info/INSTALLER
Normal file
1
vllm-0.16.1rc0+corex.4.4.0.dist-info/INSTALLER
Normal file
@@ -0,0 +1 @@
|
||||
pip
|
||||
211
vllm-0.16.1rc0+corex.4.4.0.dist-info/METADATA
Normal file
211
vllm-0.16.1rc0+corex.4.4.0.dist-info/METADATA
Normal file
@@ -0,0 +1,211 @@
|
||||
Metadata-Version: 2.4
|
||||
Name: vllm
|
||||
Version: 0.16.1rc0+corex.4.4.0
|
||||
Summary: A high-throughput and memory-efficient inference and serving engine for LLMs
|
||||
Author: vLLM Team
|
||||
License-Expression: Apache-2.0
|
||||
Project-URL: Homepage, https://github.com/vllm-project/vllm
|
||||
Project-URL: Documentation, https://docs.vllm.ai/en/latest/
|
||||
Project-URL: Slack, https://slack.vllm.ai/
|
||||
Classifier: Programming Language :: Python :: 3.10
|
||||
Classifier: Programming Language :: Python :: 3.11
|
||||
Classifier: Programming Language :: Python :: 3.12
|
||||
Classifier: Programming Language :: Python :: 3.13
|
||||
Classifier: Intended Audience :: Developers
|
||||
Classifier: Intended Audience :: Information Technology
|
||||
Classifier: Intended Audience :: Science/Research
|
||||
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
||||
Classifier: Topic :: Scientific/Engineering :: Information Analysis
|
||||
Requires-Python: <3.14,>=3.10
|
||||
Description-Content-Type: text/markdown
|
||||
License-File: LICENSE
|
||||
Requires-Dist: regex
|
||||
Requires-Dist: cachetools
|
||||
Requires-Dist: psutil
|
||||
Requires-Dist: sentencepiece
|
||||
Requires-Dist: numpy==1.26.4
|
||||
Requires-Dist: requests>=2.26.0
|
||||
Requires-Dist: tqdm
|
||||
Requires-Dist: blake3
|
||||
Requires-Dist: py-cpuinfo
|
||||
Requires-Dist: transformers<5,>=4.56.0
|
||||
Requires-Dist: tokenizers>=0.21.1
|
||||
Requires-Dist: protobuf!=6.30.*,!=6.31.*,!=6.32.*,!=6.33.0.*,!=6.33.1.*,!=6.33.2.*,!=6.33.3.*,!=6.33.4.*,>=5.29.6
|
||||
Requires-Dist: fastapi[standard]>=0.115.0
|
||||
Requires-Dist: aiohttp>=3.13.3
|
||||
Requires-Dist: openai>=1.99.1
|
||||
Requires-Dist: pydantic>=2.12.0
|
||||
Requires-Dist: prometheus_client>=0.18.0
|
||||
Requires-Dist: pillow
|
||||
Requires-Dist: prometheus-fastapi-instrumentator>=7.0.0
|
||||
Requires-Dist: tiktoken>=0.6.0
|
||||
Requires-Dist: lm-format-enforcer==0.11.3
|
||||
Requires-Dist: llguidance<1.4.0,>=1.3.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64" or platform_machine == "s390x" or platform_machine == "ppc64le"
|
||||
Requires-Dist: outlines_core==0.2.11
|
||||
Requires-Dist: diskcache==5.6.3
|
||||
Requires-Dist: lark==1.2.2
|
||||
Requires-Dist: xgrammar==0.1.29; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64" or platform_machine == "s390x" or platform_machine == "ppc64le"
|
||||
Requires-Dist: typing_extensions>=4.10
|
||||
Requires-Dist: filelock>=3.16.1
|
||||
Requires-Dist: partial-json-parser
|
||||
Requires-Dist: pyzmq>=25.0.0
|
||||
Requires-Dist: msgspec
|
||||
Requires-Dist: gguf>=0.17.0
|
||||
Requires-Dist: mistral_common[image]>=1.9.1
|
||||
Requires-Dist: pyyaml
|
||||
Requires-Dist: six>=1.16.0; python_version > "3.11"
|
||||
Requires-Dist: setuptools<81.0.0,>=77.0.3; python_version > "3.11"
|
||||
Requires-Dist: einops
|
||||
Requires-Dist: compressed-tensors==0.13.0
|
||||
Requires-Dist: depyf==0.20.0
|
||||
Requires-Dist: cloudpickle
|
||||
Requires-Dist: watchfiles
|
||||
Requires-Dist: python-json-logger
|
||||
Requires-Dist: ninja
|
||||
Requires-Dist: pybase64
|
||||
Requires-Dist: cbor2
|
||||
Requires-Dist: ijson
|
||||
Requires-Dist: setproctitle
|
||||
Requires-Dist: openai-harmony>=0.0.3
|
||||
Requires-Dist: anthropic>=0.71.0
|
||||
Requires-Dist: model-hosting-container-standards<1.0.0,>=0.1.13
|
||||
Requires-Dist: mcp
|
||||
Requires-Dist: grpcio
|
||||
Requires-Dist: grpcio-reflection
|
||||
Requires-Dist: opentelemetry-sdk>=1.27.0
|
||||
Requires-Dist: opentelemetry-api>=1.27.0
|
||||
Requires-Dist: opentelemetry-exporter-otlp>=1.27.0
|
||||
Requires-Dist: opentelemetry-semantic-conventions-ai>=0.4.1
|
||||
Requires-Dist: numba==0.61.2
|
||||
Requires-Dist: ray[cgraph]>=2.48.0
|
||||
Provides-Extra: bench
|
||||
Requires-Dist: pandas; extra == "bench"
|
||||
Requires-Dist: matplotlib; extra == "bench"
|
||||
Requires-Dist: seaborn; extra == "bench"
|
||||
Requires-Dist: datasets; extra == "bench"
|
||||
Requires-Dist: scipy; extra == "bench"
|
||||
Provides-Extra: tensorizer
|
||||
Requires-Dist: tensorizer==2.10.1; extra == "tensorizer"
|
||||
Provides-Extra: fastsafetensors
|
||||
Requires-Dist: fastsafetensors>=0.2.2; extra == "fastsafetensors"
|
||||
Provides-Extra: runai
|
||||
Requires-Dist: runai-model-streamer[gcs,s3]>=0.15.3; extra == "runai"
|
||||
Provides-Extra: audio
|
||||
Requires-Dist: librosa; extra == "audio"
|
||||
Requires-Dist: scipy; extra == "audio"
|
||||
Requires-Dist: soundfile; extra == "audio"
|
||||
Requires-Dist: mistral_common[audio]; extra == "audio"
|
||||
Provides-Extra: video
|
||||
Provides-Extra: flashinfer
|
||||
Provides-Extra: otel
|
||||
Requires-Dist: opentelemetry-sdk>=1.26.0; extra == "otel"
|
||||
Requires-Dist: opentelemetry-api>=1.26.0; extra == "otel"
|
||||
Requires-Dist: opentelemetry-exporter-otlp>=1.26.0; extra == "otel"
|
||||
Requires-Dist: opentelemetry-semantic-conventions-ai>=0.4.1; extra == "otel"
|
||||
Dynamic: license-file
|
||||
Dynamic: provides-extra
|
||||
Dynamic: requires-dist
|
||||
|
||||
<!-- markdownlint-disable MD001 MD041 -->
|
||||
<p align="center">
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/assets/logos/vllm-logo-text-dark.png">
|
||||
<img alt="vLLM" src="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/assets/logos/vllm-logo-text-light.png" width=55%>
|
||||
</picture>
|
||||
</p>
|
||||
|
||||
<h3 align="center">
|
||||
Easy, fast, and cheap LLM serving for everyone
|
||||
</h3>
|
||||
|
||||
<p align="center">
|
||||
| <a href="https://docs.vllm.ai"><b>Documentation</b></a> | <a href="https://blog.vllm.ai/"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://x.com/vllm_project"><b>Twitter/X</b></a> | <a href="https://discuss.vllm.ai"><b>User Forum</b></a> | <a href="https://slack.vllm.ai"><b>Developer Slack</b></a> |
|
||||
</p>
|
||||
|
||||
🔥 We have built a vllm website to help you get started with vllm. Please visit [vllm.ai](https://vllm.ai) to learn more.
|
||||
For events, please visit [vllm.ai/events](https://vllm.ai/events) to join us.
|
||||
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
vLLM is a fast and easy-to-use library for LLM inference and serving.
|
||||
|
||||
Originally developed in the [Sky Computing Lab](https://sky.cs.berkeley.edu) at UC Berkeley, vLLM has evolved into a community-driven project with contributions from both academia and industry.
|
||||
|
||||
vLLM is fast with:
|
||||
|
||||
- State-of-the-art serving throughput
|
||||
- Efficient management of attention key and value memory with [**PagedAttention**](https://blog.vllm.ai/2023/06/20/vllm.html)
|
||||
- Continuous batching of incoming requests
|
||||
- Fast model execution with CUDA/HIP graph
|
||||
- Quantizations: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [AutoRound](https://arxiv.org/abs/2309.05516), INT4, INT8, and FP8
|
||||
- Optimized CUDA kernels, including integration with FlashAttention and FlashInfer
|
||||
- Speculative decoding
|
||||
- Chunked prefill
|
||||
|
||||
vLLM is flexible and easy to use with:
|
||||
|
||||
- Seamless integration with popular Hugging Face models
|
||||
- High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more
|
||||
- Tensor, pipeline, data and expert parallelism support for distributed inference
|
||||
- Streaming outputs
|
||||
- OpenAI-compatible API server
|
||||
- Support for NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, Arm CPUs, and TPU. Additionally, support for diverse hardware plugins such as Intel Gaudi, IBM Spyre and Huawei Ascend.
|
||||
- Prefix caching support
|
||||
- Multi-LoRA support
|
||||
|
||||
vLLM seamlessly supports most popular open-source models on HuggingFace, including:
|
||||
|
||||
- Transformer-like LLMs (e.g., Llama)
|
||||
- Mixture-of-Expert LLMs (e.g., Mixtral, Deepseek-V2 and V3)
|
||||
- Embedding Models (e.g., E5-Mistral)
|
||||
- Multi-modal LLMs (e.g., LLaVA)
|
||||
|
||||
Find the full list of supported models [here](https://docs.vllm.ai/en/latest/models/supported_models.html).
|
||||
|
||||
## Getting Started
|
||||
|
||||
Install vLLM with `pip` or [from source](https://docs.vllm.ai/en/latest/getting_started/installation/gpu/index.html#build-wheel-from-source):
|
||||
|
||||
```bash
|
||||
pip install vllm
|
||||
```
|
||||
|
||||
Visit our [documentation](https://docs.vllm.ai/en/latest/) to learn more.
|
||||
|
||||
- [Installation](https://docs.vllm.ai/en/latest/getting_started/installation.html)
|
||||
- [Quickstart](https://docs.vllm.ai/en/latest/getting_started/quickstart.html)
|
||||
- [List of Supported Models](https://docs.vllm.ai/en/latest/models/supported_models.html)
|
||||
|
||||
## Contributing
|
||||
|
||||
We welcome and value any contributions and collaborations.
|
||||
Please check out [Contributing to vLLM](https://docs.vllm.ai/en/latest/contributing/index.html) for how to get involved.
|
||||
|
||||
## Citation
|
||||
|
||||
If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs/2309.06180):
|
||||
|
||||
```bibtex
|
||||
@inproceedings{kwon2023efficient,
|
||||
title={Efficient Memory Management for Large Language Model Serving with PagedAttention},
|
||||
author={Woosuk Kwon and Zhuohan Li and Siyuan Zhuang and Ying Sheng and Lianmin Zheng and Cody Hao Yu and Joseph E. Gonzalez and Hao Zhang and Ion Stoica},
|
||||
booktitle={Proceedings of the ACM SIGOPS 29th Symposium on Operating Systems Principles},
|
||||
year={2023}
|
||||
}
|
||||
```
|
||||
|
||||
## Contact Us
|
||||
|
||||
<!-- --8<-- [start:contact-us] -->
|
||||
- For technical questions and feature requests, please use GitHub [Issues](https://github.com/vllm-project/vllm/issues)
|
||||
- For discussing with fellow users, please use the [vLLM Forum](https://discuss.vllm.ai)
|
||||
- For coordinating contributions and development, please use [Slack](https://slack.vllm.ai)
|
||||
- For security disclosures, please use GitHub's [Security Advisories](https://github.com/vllm-project/vllm/security/advisories) feature
|
||||
- For collaborations and partnerships, please contact us at [collaboration@vllm.ai](mailto:collaboration@vllm.ai)
|
||||
<!-- --8<-- [end:contact-us] -->
|
||||
|
||||
## Media Kit
|
||||
|
||||
- If you wish to use vLLM's logo, please refer to [our media kit repo](https://github.com/vllm-project/media-kit)
|
||||
3248
vllm-0.16.1rc0+corex.4.4.0.dist-info/RECORD
Normal file
3248
vllm-0.16.1rc0+corex.4.4.0.dist-info/RECORD
Normal file
File diff suppressed because it is too large
Load Diff
0
vllm-0.16.1rc0+corex.4.4.0.dist-info/REQUESTED
Normal file
0
vllm-0.16.1rc0+corex.4.4.0.dist-info/REQUESTED
Normal file
5
vllm-0.16.1rc0+corex.4.4.0.dist-info/WHEEL
Normal file
5
vllm-0.16.1rc0+corex.4.4.0.dist-info/WHEEL
Normal file
@@ -0,0 +1,5 @@
|
||||
Wheel-Version: 1.0
|
||||
Generator: setuptools (82.0.0)
|
||||
Root-Is-Purelib: true
|
||||
Tag: py3-none-any
|
||||
|
||||
1
vllm-0.16.1rc0+corex.4.4.0.dist-info/direct_url.json
Normal file
1
vllm-0.16.1rc0+corex.4.4.0.dist-info/direct_url.json
Normal file
@@ -0,0 +1 @@
|
||||
{"archive_info": {"hash": "sha256=f19c1c4880bc4199fc95de8d77590fcdb4912b1fad9bf939883005812f2338c1", "hashes": {"sha256": "f19c1c4880bc4199fc95de8d77590fcdb4912b1fad9bf939883005812f2338c1"}}, "url": "file:///workspace/vllm-0.16.1rc0%2Bcorex.4.4.0-py3-none-any.whl"}
|
||||
6
vllm-0.16.1rc0+corex.4.4.0.dist-info/entry_points.txt
Normal file
6
vllm-0.16.1rc0+corex.4.4.0.dist-info/entry_points.txt
Normal file
@@ -0,0 +1,6 @@
|
||||
[console_scripts]
|
||||
vllm = vllm.entrypoints.cli.main:main
|
||||
|
||||
[vllm.general_plugins]
|
||||
lora_filesystem_resolver = vllm.plugins.lora_resolvers.filesystem_resolver:register_filesystem_resolver
|
||||
lora_hf_hub_resolver = vllm.plugins.lora_resolvers.hf_hub_resolver:register_hf_hub_resolver
|
||||
201
vllm-0.16.1rc0+corex.4.4.0.dist-info/licenses/LICENSE
Normal file
201
vllm-0.16.1rc0+corex.4.4.0.dist-info/licenses/LICENSE
Normal file
@@ -0,0 +1,201 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
1
vllm-0.16.1rc0+corex.4.4.0.dist-info/top_level.txt
Normal file
1
vllm-0.16.1rc0+corex.4.4.0.dist-info/top_level.txt
Normal file
@@ -0,0 +1 @@
|
||||
vllm
|
||||
244
vllm/.gitignore
vendored
Normal file
244
vllm/.gitignore
vendored
Normal file
@@ -0,0 +1,244 @@
|
||||
# version file generated by setuptools-scm
|
||||
/vllm/_version.py
|
||||
|
||||
# vllm-flash-attn built from source
|
||||
vllm/vllm_flash_attn/*
|
||||
!vllm/vllm_flash_attn/__init__.py
|
||||
!vllm/vllm_flash_attn/flash_attn_interface.py
|
||||
|
||||
# OpenAI triton kernels copied from source
|
||||
vllm/third_party/triton_kernels/*
|
||||
|
||||
# FlashMLA interface copied from source
|
||||
vllm/third_party/flashmla/flash_mla_interface.py
|
||||
|
||||
# triton jit
|
||||
.triton
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
cmake-build-*/
|
||||
CMakeUserPresets.json
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
/.deps/
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# generated files
|
||||
**/generated/**
|
||||
|
||||
# uv
|
||||
uv.lock
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/#use-with-ide
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
docs/argparse
|
||||
docs/examples/*
|
||||
!docs/examples/README.md
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
.idea/
|
||||
|
||||
# VSCode
|
||||
.vscode/
|
||||
|
||||
# Claude
|
||||
.claude/
|
||||
|
||||
# Codex
|
||||
.codex/
|
||||
|
||||
# Cursor
|
||||
.cursor/
|
||||
|
||||
# DS Store
|
||||
.DS_Store
|
||||
|
||||
# Results
|
||||
*.csv
|
||||
|
||||
# Python pickle files
|
||||
*.pkl
|
||||
|
||||
# Sphinx documentation
|
||||
_build/
|
||||
|
||||
# vim swap files
|
||||
*.swo
|
||||
*.swp
|
||||
|
||||
# hip files generated by PyTorch
|
||||
*.hip
|
||||
*_hip*
|
||||
hip_compat.h
|
||||
|
||||
# Benchmark dataset
|
||||
benchmarks/**/*.json
|
||||
|
||||
# Linting
|
||||
actionlint
|
||||
shellcheck*/
|
||||
|
||||
# Ignore moe/marlin_moe gen code
|
||||
csrc/moe/marlin_moe_wna16/kernel_*
|
||||
|
||||
# Ignore ep_kernels_workspace folder
|
||||
ep_kernels_workspace/
|
||||
|
||||
# Allow tracked library source folders under submodules (e.g., benchmarks/lib)
|
||||
!vllm/benchmarks/lib/
|
||||
|
||||
# Generated gRPC protobuf files (compiled at build time from vllm_engine.proto)
|
||||
vllm/grpc/vllm_engine_pb2.py
|
||||
vllm/grpc/vllm_engine_pb2_grpc.py
|
||||
vllm/grpc/vllm_engine_pb2.pyi
|
||||
|
||||
# Ignore generated cpu headers
|
||||
csrc/cpu/cpu_attn_dispatch_generated.h
|
||||
|
||||
107
vllm/__init__.py
Normal file
107
vllm/__init__.py
Normal file
@@ -0,0 +1,107 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""vLLM: a high-throughput and memory-efficient inference engine for LLMs"""
|
||||
|
||||
# The version.py should be independent library, and we always import the
|
||||
# version library first. Such assumption is critical for some customization.
|
||||
from .version import __version__, __version_tuple__ # isort:skip
|
||||
|
||||
import typing
|
||||
|
||||
# The environment variables override should be imported before any other
|
||||
# modules to ensure that the environment variables are set before any
|
||||
# other modules are imported.
|
||||
import vllm.env_override # noqa: F401
|
||||
|
||||
# Table of lazily exported public names.
# Each value has the form "<relative module path>:<attribute name>"; the
# module-level `__getattr__` in this file splits the value on ":" and imports
# the module relative to this package when the name is first requested.
MODULE_ATTRS = {
    "bc_linter_skip": "._bc_linter:bc_linter_skip",
    "bc_linter_include": "._bc_linter:bc_linter_include",
    "AsyncEngineArgs": ".engine.arg_utils:AsyncEngineArgs",
    "EngineArgs": ".engine.arg_utils:EngineArgs",
    "AsyncLLMEngine": ".engine.async_llm_engine:AsyncLLMEngine",
    "LLMEngine": ".engine.llm_engine:LLMEngine",
    "LLM": ".entrypoints.llm:LLM",
    "initialize_ray_cluster": ".v1.executor.ray_utils:initialize_ray_cluster",
    "PromptType": ".inputs:PromptType",
    "TextPrompt": ".inputs:TextPrompt",
    "TokensPrompt": ".inputs:TokensPrompt",
    "ModelRegistry": ".model_executor.models:ModelRegistry",
    "SamplingParams": ".sampling_params:SamplingParams",
    "PoolingParams": ".pooling_params:PoolingParams",
    "ClassificationOutput": ".outputs:ClassificationOutput",
    "ClassificationRequestOutput": ".outputs:ClassificationRequestOutput",
    "CompletionOutput": ".outputs:CompletionOutput",
    "EmbeddingOutput": ".outputs:EmbeddingOutput",
    "EmbeddingRequestOutput": ".outputs:EmbeddingRequestOutput",
    "PoolingOutput": ".outputs:PoolingOutput",
    "PoolingRequestOutput": ".outputs:PoolingRequestOutput",
    "RequestOutput": ".outputs:RequestOutput",
    "ScoringOutput": ".outputs:ScoringOutput",
    "ScoringRequestOutput": ".outputs:ScoringRequestOutput",
}
|
||||
|
||||
if typing.TYPE_CHECKING:
    # Eager imports that only static analysers evaluate; at runtime the same
    # names are resolved on demand by `__getattr__` in the `else` branch.
    from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
    from vllm.engine.async_llm_engine import AsyncLLMEngine
    from vllm.engine.llm_engine import LLMEngine
    from vllm.entrypoints.llm import LLM
    from vllm.inputs import PromptType, TextPrompt, TokensPrompt
    from vllm.model_executor.models import ModelRegistry
    from vllm.outputs import (
        ClassificationOutput,
        ClassificationRequestOutput,
        CompletionOutput,
        EmbeddingOutput,
        EmbeddingRequestOutput,
        PoolingOutput,
        PoolingRequestOutput,
        RequestOutput,
        ScoringOutput,
        ScoringRequestOutput,
    )
    from vllm.pooling_params import PoolingParams
    from vllm.sampling_params import SamplingParams
    from vllm.v1.executor.ray_utils import initialize_ray_cluster

    from ._bc_linter import bc_linter_include, bc_linter_skip
else:

    def __getattr__(name: str) -> typing.Any:
        """Resolve a lazily exported attribute of this package.

        Args:
            name: Attribute requested on the package.

        Returns:
            The attribute fetched from the module registered for ``name`` in
            ``MODULE_ATTRS``.

        Raises:
            AttributeError: If ``name`` is not a key of ``MODULE_ATTRS``.
        """
        from importlib import import_module

        # Guard clause: unknown names fail fast with the usual error type.
        if name not in MODULE_ATTRS:
            raise AttributeError(f"module {__package__} has no attribute {name}")

        relative_module, attribute = MODULE_ATTRS[name].split(":")
        owner = import_module(relative_module, __package__)
        return getattr(owner, attribute)
|
||||
|
||||
|
||||
# Public API of the package: the two version symbols imported eagerly at the
# top of this file plus the lazily exported names.
__all__ = [
    "__version__",
    "bc_linter_skip",
    "bc_linter_include",
    "__version_tuple__",
    "LLM",
    "ModelRegistry",
    "PromptType",
    "TextPrompt",
    "TokensPrompt",
    "SamplingParams",
    "RequestOutput",
    "CompletionOutput",
    "PoolingOutput",
    "PoolingRequestOutput",
    "EmbeddingOutput",
    "EmbeddingRequestOutput",
    "ClassificationOutput",
    "ClassificationRequestOutput",
    "ScoringOutput",
    "ScoringRequestOutput",
    "LLMEngine",
    "EngineArgs",
    "AsyncLLMEngine",
    "AsyncEngineArgs",
    "initialize_ray_cluster",
    "PoolingParams",
]
|
||||
1810
vllm/_aiter_ops.py
Normal file
1810
vllm/_aiter_ops.py
Normal file
File diff suppressed because it is too large
Load Diff
54
vllm/_bc_linter.py
Normal file
54
vllm/_bc_linter.py
Normal file
@@ -0,0 +1,54 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# vllm/_bc_linter.py
|
||||
from collections.abc import Callable
|
||||
from typing import Any, TypeVar, overload
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@overload
def bc_linter_skip(obj: T) -> T: ...


@overload
def bc_linter_skip(*, reason: str | None = ...) -> Callable[[T], T]: ...


def bc_linter_skip(obj: Any = None, *, reason: str | None = None):
    """
    No-op decorator to mark symbols/files for BC-linter suppression.

    Works both bare (``@bc_linter_skip``) and called
    (``@bc_linter_skip(reason="...")``). At runtime the decorated object is
    returned unchanged and ``reason`` is not used.

    Usage:
        @bc_linter_skip
        def legacy_api(...): ...
    """
    if obj is not None:
        # Bare usage: hand the decorated object straight back.
        return obj

    def _identity(target: T) -> T:
        return target

    # Called usage: return a decorator that leaves its argument untouched.
    return _identity
|
||||
|
||||
|
||||
@overload
def bc_linter_include(obj: T) -> T: ...


@overload
def bc_linter_include(*, reason: str | None = ...) -> Callable[[T], T]: ...


def bc_linter_include(obj: Any = None, *, reason: str | None = None):
    """
    No-op decorator; presumably the opt-in counterpart of ``bc_linter_skip``
    (confirm against the BC-linter configuration).

    Works both bare (``@bc_linter_include``) and called
    (``@bc_linter_include(reason="...")``). At runtime the decorated object is
    returned unchanged and ``reason`` is not used.

    Usage:
        @bc_linter_include
        def public_api(...): ...
    """
    if obj is not None:
        # Bare usage: hand the decorated object straight back.
        return obj

    def _identity(target: T) -> T:
        return target

    # Called usage: return a decorator that leaves its argument untouched.
    return _identity
|
||||
|
||||
|
||||
__all__ = ["bc_linter_skip", "bc_linter_include"]
|
||||
4238
vllm/_custom_ops.py
Normal file
4238
vllm/_custom_ops.py
Normal file
File diff suppressed because it is too large
Load Diff
96
vllm/_oink_ops.py
Normal file
96
vllm/_oink_ops.py
Normal file
@@ -0,0 +1,96 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Small helper wrappers for external Oink Blackwell custom ops.
|
||||
|
||||
vLLM does not depend on the external Oink repository/package. When an external
|
||||
plugin registers torch.library.custom_op entrypoints under the `oink::`
|
||||
namespace (e.g. via vLLM's general_plugins mechanism) and
|
||||
`VLLM_USE_OINK_OPS=1` is set, vLLM can route eligible calls to those ops.
|
||||
|
||||
This module provides:
|
||||
- A single place to probe Oink op availability at module init time
|
||||
(outside torch.compile tracing), and
|
||||
- Thin wrappers around the torch.ops entrypoints for use in CUDA fast paths,
|
||||
without introducing graph breaks.
|
||||
|
||||
Important:
|
||||
Do not call the availability helpers in a compiled region. They may call
|
||||
functions decorated with `torch._dynamo.disable` to safely check
|
||||
conditions that should not be traced.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
|
||||
# Prefer the real `torch._dynamo.disable`. If importing it fails for any
# reason, fall back to a decorator that returns the function unchanged so the
# `@_dynamo_disable` usages in this module still work.
try:
    from torch._dynamo import disable as _dynamo_disable  # type: ignore[attr-defined]
except Exception:  # pragma: no cover

    def _dynamo_disable(fn: Callable):  # type: ignore[misc]
        # Identity fallback: no tracing guard is applied.
        return fn
|
||||
|
||||
|
||||
def _has_oink_op(op_name: str) -> bool:
|
||||
"""Check if a specific oink op is registered."""
|
||||
return hasattr(torch.ops, "oink") and hasattr(torch.ops.oink, op_name)
|
||||
|
||||
|
||||
@_dynamo_disable
def is_oink_available_for_device(device_index: int) -> bool:
    """Return True if Oink ops are registered and device is SM100+.

    Intended to be called while modules are being initialised
    (e.g., in RMSNorm.__init__), not in the forward path.

    External plugins are expected to gate registration on SM100+ and
    VLLM_USE_OINK_OPS=1, so if the ops are present they should be usable.

    Args:
        device_index: Index passed to ``torch.cuda.get_device_capability``.

    Returns:
        False when CUDA is unavailable, when the capability query raises,
        or when ``10 * major + minor`` is below 100; otherwise whether the
        ``rmsnorm`` op is registered under ``torch.ops.oink``.
    """
    if not torch.cuda.is_available():
        return False

    try:
        major, minor = torch.cuda.get_device_capability(device_index)
        below_sm100 = (10 * major + minor) < 100
    except Exception:
        # Any failure while probing the device is treated as "unavailable".
        return False

    if below_sm100:
        return False
    return _has_oink_op("rmsnorm")
|
||||
|
||||
|
||||
def has_fused_add_rms_norm() -> bool:
    """Return True if the in-place fused op is registered.

    Checks for ``fused_add_rms_norm`` under ``torch.ops.oink`` via
    ``_has_oink_op``; no device capability check is performed here.
    """
    return _has_oink_op("fused_add_rms_norm")
|
||||
|
||||
|
||||
def rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    """Call `torch.ops.oink.rmsnorm`.

    This wrapper is safe to call in torch.compile regions.

    Args:
        x: Input tensor, forwarded unchanged to the op.
        weight: Weight tensor, forwarded unchanged to the op.
        eps: Epsilon value, forwarded unchanged to the op.

    Returns:
        Whatever `torch.ops.oink.rmsnorm(x, weight, eps)` returns.
    """
    return torch.ops.oink.rmsnorm(x, weight, eps)
|
||||
|
||||
|
||||
def fused_add_rms_norm_(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
) -> None:
    """Call `torch.ops.oink.fused_add_rms_norm` (mutates x and residual).

    All four arguments are forwarded positionally; the op's return value is
    discarded and this function returns None.
    """
    torch.ops.oink.fused_add_rms_norm(x, residual, weight, eps)
|
||||
|
||||
|
||||
def fused_add_rms_norm(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Convenience wrapper returning (x, residual) after in-place mutation.

    Delegates to `fused_add_rms_norm_` and then returns the very same tensor
    objects that were passed in as `x` and `residual`.
    """
    fused_add_rms_norm_(x, residual, weight, eps)
    return x, residual
|
||||
159
vllm/_xpu_ops.py
Normal file
159
vllm/_xpu_ops.py
Normal file
@@ -0,0 +1,159 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from vllm_xpu_kernels.flash_attn_interface import flash_attn_varlen_func
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
    # Stub evaluated only by type checkers; never executed at runtime.

    def register_fake(fn):
        return lambda name: fn

else:
    try:
        from torch.library import register_fake
    except ImportError:
        # `register_fake` is missing from this torch build (presumably an
        # older release); expose `impl_abstract` under the same name instead.
        from torch.library import impl_abstract as register_fake
|
||||
|
||||
# Only register the fake implementation when the real op exists.
if hasattr(torch.ops._xpu_C, "fp8_gemm_w8a16"):

    @register_fake("_xpu_C::fp8_gemm_w8a16")
    def _fp8_gemm_w8a16_fake(
        input: torch.Tensor,
        q_weight: torch.Tensor,
        weight_scale: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Fake implementation of ``_xpu_C::fp8_gemm_w8a16``.

        Returns an uninitialised ``(M, N)`` tensor with the dtype and device
        of ``input``, where ``M`` is the row count of ``input`` flattened to
        2-D over its last dimension and ``N`` is ``q_weight.size(1)``.
        """
        rows = input.view(-1, input.shape[-1]).size(0)
        cols = q_weight.size(1)
        return torch.empty((rows, cols), dtype=input.dtype, device=input.device)
|
||||
|
||||
|
||||
# Only register the fake implementation when the real op exists.
if hasattr(torch.ops._xpu_C, "int4_gemm_w4a16"):

    @register_fake("_xpu_C::int4_gemm_w4a16")
    def _int4_gemm_w4a16_fake(
        input: torch.Tensor,
        q_weight: torch.Tensor,
        bias: torch.Tensor | None,
        weight_scale: torch.Tensor,
        qzeros: torch.Tensor,
        group_size: int,
        group_idx: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Fake implementation of ``_xpu_C::int4_gemm_w4a16``.

        Returns an uninitialised ``(M, N)`` tensor with the dtype and device
        of ``input``, where ``M`` is the row count of ``input`` flattened to
        2-D over its last dimension and ``N`` is ``q_weight.size(1)``.
        """
        rows = input.view(-1, input.shape[-1]).size(0)
        cols = q_weight.size(1)
        return torch.empty((rows, cols), dtype=input.dtype, device=input.device)
|
||||
|
||||
|
||||
class xpu_ops:
    """Static helpers wrapping the XPU kernels behind a CUDA-compatible API."""

    @staticmethod
    def flash_attn_varlen_func(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        cu_seqlens_q: torch.Tensor,
        max_seqlen_q: int,
        max_seqlen_k: int,
        softmax_scale: float | None = None,
        causal: bool = False,
        out: torch.Tensor | None = None,
        block_table: torch.Tensor | None = None,
        alibi_slopes: torch.Tensor | None = None,
        window_size: list[int] | None = None,
        softcap: float | None = 0.0,
        seqused_k: torch.Tensor | None = None,
        cu_seqlens_k: torch.Tensor | None = None,
        # passed in qwen vl
        dropout_p: float = 0.0,
        # The following parameters are not used in xpu kernel currently,
        # we keep API compatible to CUDA's.
        scheduler_metadata=None,
        fa_version: int = 2,
        q_descale=None,
        k_descale=None,
        v_descale=None,
        num_splits=0,
        return_softmax_lse: bool | None = False,
        s_aux: torch.Tensor | None = None,
    ):
        """Validate the arguments and forward to the XPU varlen kernel.

        Exactly one of ``cu_seqlens_k`` / ``seqused_k`` must be given:
        ``seqused_k`` together with ``block_table``, ``cu_seqlens_k``
        without it. ``out`` is allocated with the shape, dtype and device of
        ``q`` when omitted. ``q`` is always made contiguous; ``k`` and ``v``
        are made contiguous only when no ``block_table`` is supplied.

        ``alibi_slopes``, ``softcap``, ``dropout_p``, ``scheduler_metadata``,
        ``fa_version``, ``q_descale``, ``k_descale``, ``v_descale`` and
        ``num_splits`` are accepted but not forwarded to the kernel.

        Returns:
            The return value of the underlying ``flash_attn_varlen_func``.
        """
        assert cu_seqlens_k is not None or seqused_k is not None, (
            "cu_seqlens_k or seqused_k must be provided"
        )
        assert cu_seqlens_k is None or seqused_k is None, (
            "cu_seqlens_k and seqused_k cannot be provided at the same time"
        )
        assert block_table is None or seqused_k is not None, (
            "when enable block_table, seqused_k is needed"
        )
        assert block_table is not None or cu_seqlens_k is not None, (
            "when block_table is disabled, cu_seqlens_k is needed"
        )

        if out is None:
            out = torch.empty(q.shape, dtype=q.dtype, device=q.device)

        # Normalise the optional window into the 2-tuple handed to the kernel.
        real_window_size: tuple[int, int] = (-1, -1)
        if window_size is not None:
            assert len(window_size) == 2
            real_window_size = (window_size[0], window_size[1])

        # In encode attention, k and v maybe not contiguous and current
        # kernel can't handle it
        if block_table is None:
            k, v = k.contiguous(), v.contiguous()

        kernel_kwargs = dict(
            out=out,
            q=q.contiguous(),
            k=k,
            v=v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            seqused_k=seqused_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            softmax_scale=softmax_scale,
            causal=causal,
            block_table=block_table,
            s_aux=s_aux,
            window_size=real_window_size,
            # alibi_slopes = alibi_slopes,
            # softcap=softcap,
            return_softmax_lse=return_softmax_lse,
        )
        return flash_attn_varlen_func(**kernel_kwargs)

    @staticmethod
    def get_scheduler_metadata(
        batch_size,
        max_seqlen_q,
        max_seqlen_k,
        num_heads_q,
        num_heads_kv,
        headdim,
        cache_seqlens: torch.Tensor,
        qkv_dtype=torch.bfloat16,
        headdim_v=None,
        cu_seqlens_q: torch.Tensor | None = None,
        cu_seqlens_k_new: torch.Tensor | None = None,
        cache_leftpad: torch.Tensor | None = None,
        page_size: int | None = None,
        max_seqlen_k_new=0,
        causal=False,
        window_size=(-1, -1),  # -1 means infinite context window
        has_softcap=False,
        num_splits=0,  # Can be tuned for speed
        pack_gqa=None,  # Can be tuned for speed
        sm_margin=0,  # Can be tuned if some SMs are used for communication
    ) -> None:
        """Placeholder: log a one-time warning and return None.

        All arguments are accepted and ignored.
        """
        message = (
            "get_scheduler_metadata is not implemented for xpu_ops, returning None."
        )
        logger.warning_once(message)
        return None
|
||||
0
vllm/assets/__init__.py
Normal file
0
vllm/assets/__init__.py
Normal file
43
vllm/assets/audio.py
Normal file
43
vllm/assets/audio.py
Normal file
@@ -0,0 +1,43 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import numpy.typing as npt
|
||||
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
from .base import VLLM_S3_BUCKET_URL, get_vllm_public_assets
|
||||
|
||||
try:
|
||||
import librosa
|
||||
except ImportError:
|
||||
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
|
||||
|
||||
ASSET_DIR = "multimodal_asset"
|
||||
|
||||
AudioAssetName = Literal["winning_call", "mary_had_lamb"]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class AudioAsset:
    """A named ``.ogg`` audio asset fetched from the vLLM public asset bucket."""

    name: AudioAssetName

    @property
    def filename(self) -> str:
        """File name of the asset: ``<name>.ogg``."""
        return f"{self.name}.ogg"

    def get_local_path(self) -> Path:
        """Return the local path of the asset as given by ``get_vllm_public_assets``."""
        return get_vllm_public_assets(filename=self.filename, s3_prefix=ASSET_DIR)

    @property
    def audio_and_sample_rate(self) -> tuple[npt.NDArray, float]:
        """Result of ``librosa.load`` on the local asset file with ``sr=None``."""
        local_file = self.get_local_path()
        return librosa.load(local_file, sr=None)

    @property
    def url(self) -> str:
        """URL of the asset, joined onto ``VLLM_S3_BUCKET_URL``."""
        relative_path = f"{ASSET_DIR}/{self.name}.ogg"
        return urljoin(VLLM_S3_BUCKET_URL, relative_path)
|
||||
40
vllm/assets/base.py
Normal file
40
vllm/assets/base.py
Normal file
@@ -0,0 +1,40 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.connections import global_http_connection
|
||||
|
||||
VLLM_S3_BUCKET_URL = "https://vllm-public-assets.s3.us-west-2.amazonaws.com"
|
||||
|
||||
|
||||
def get_cache_dir() -> Path:
    """Return the asset cache directory, creating it if it does not exist.

    The location is read from ``envs.VLLM_ASSETS_CACHE``.
    """
    cache_root = Path(envs.VLLM_ASSETS_CACHE)
    cache_root.mkdir(parents=True, exist_ok=True)
    return cache_root
|
||||
|
||||
|
||||
@lru_cache
def get_vllm_public_assets(filename: str, s3_prefix: str | None = None) -> Path:
    """
    Download an asset file from `s3://vllm-public-assets`
    and return the path to the downloaded file.

    The file is stored as ``<cache dir>/vllm_public_assets/<filename>`` and
    is only downloaded when that path does not exist yet. Results are
    memoised per ``(filename, s3_prefix)`` by ``lru_cache``.

    Args:
        filename: Name of the asset; also used as the local file name.
        s3_prefix: Optional key prefix prepended (with "/") to ``filename``
            when building the download URL.
    """
    local_dir = get_cache_dir() / "vllm_public_assets"
    local_dir.mkdir(parents=True, exist_ok=True)

    local_file = local_dir / filename
    if local_file.exists():
        # Already on disk: skip the download.
        return local_file

    remote_key = filename if s3_prefix is None else s3_prefix + "/" + filename
    global_http_connection.download_file(
        f"{VLLM_S3_BUCKET_URL}/{remote_key}",
        local_file,
        timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
    )
    return local_file
|
||||
62
vllm/assets/image.py
Normal file
62
vllm/assets/image.py
Normal file
@@ -0,0 +1,62 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from .base import get_vllm_public_assets
|
||||
|
||||
VLM_IMAGES_DIR = "vision_model_images"
|
||||
|
||||
ImageAssetName = Literal[
|
||||
"stop_sign",
|
||||
"cherry_blossom",
|
||||
"hato",
|
||||
"2560px-Gfp-wisconsin-madison-the-nature-boardwalk",
|
||||
"Grayscale_8bits_palette_sample_image",
|
||||
"1280px-Venn_diagram_rgb",
|
||||
"RGBA_comp",
|
||||
"237-400x300",
|
||||
"231-200x300",
|
||||
"27-500x500",
|
||||
"17-150x600",
|
||||
"handelsblatt-preview",
|
||||
"paper-11",
|
||||
]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ImageAsset:
    """A named image asset fetched from the vLLM public asset bucket."""

    name: ImageAssetName

    def get_path(self, ext: str) -> Path:
        """
        Return the local path of ``<name>.<ext>`` as given by
        ``get_vllm_public_assets`` (looked up under ``VLM_IMAGES_DIR``).
        """
        asset_file = f"{self.name}.{ext}"
        return get_vllm_public_assets(filename=asset_file, s3_prefix=VLM_IMAGES_DIR)

    @property
    def pil_image(self) -> Image.Image:
        """The ``jpg`` variant of the asset opened with PIL."""
        return self.pil_image_ext(ext="jpg")

    def pil_image_ext(self, ext: str) -> Image.Image:
        """Open the variant of the asset with the given extension using PIL."""
        return Image.open(self.get_path(ext=ext))

    @property
    def image_embeds(self) -> torch.Tensor:
        """
        Image embeddings, only used for testing purposes with llava 1.5.

        Loaded from the ``pt`` variant onto the CPU with ``weights_only=True``.
        """
        embeds_file = self.get_path("pt")
        return torch.load(embeds_file, map_location="cpu", weights_only=True)

    def read_bytes(self, ext: str) -> bytes:
        """Return the raw bytes of the variant with the given extension."""
        return Path(self.get_path(ext)).read_bytes()
|
||||
149
vllm/assets/video.py
Normal file
149
vllm/assets/video.py
Normal file
@@ -0,0 +1,149 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from typing import Any, ClassVar, Literal
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
from .base import get_cache_dir
|
||||
|
||||
try:
|
||||
import librosa
|
||||
except ImportError:
|
||||
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
|
||||
|
||||
|
||||
@lru_cache
def download_video_asset(filename: str) -> str:
    """
    Download a video file from huggingface and return its path as a string.

    repo: raushan-testing-hf/videos-test

    If ``<cache dir>/video-example-data/<filename>`` already exists, that
    path is returned; otherwise the path returned by ``hf_hub_download``
    (called with that directory as ``cache_dir``) is returned. Results are
    memoised per ``filename`` by ``lru_cache``.
    """
    cache_root = get_cache_dir() / "video-example-data"
    cache_root.mkdir(parents=True, exist_ok=True)

    local_file = cache_root / filename
    if local_file.exists():
        return str(local_file)

    return hf_hub_download(
        repo_id="raushan-testing-hf/videos-test",
        filename=filename,
        repo_type="dataset",
        cache_dir=cache_root,
    )
|
||||
|
||||
|
||||
def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray:
    """Decode frames of a video file into one stacked RGB array.

    Args:
        path: Path of the video file, opened with ``cv2.VideoCapture``.
        num_frames: Number of frames to sample, spread evenly over the video
            with ``np.linspace``. Any value <= 0 selects every frame.

    Returns:
        ``np.stack`` of the selected frames, each converted from BGR to RGB.

    Raises:
        ValueError: If the file cannot be opened, or if fewer than
            ``num_frames`` frames (or no frames at all) could be read.
    """
    import cv2

    cap = cv2.VideoCapture(path)
    try:
        if not cap.isOpened():
            raise ValueError(f"Could not open video file {path}")

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        num_frames = num_frames if num_frames > 0 else total_frames
        frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
        # A set gives O(1) membership tests instead of scanning the array
        # once per decoded frame.
        wanted = set(frame_indices.tolist())

        frames = []
        for idx in range(total_frames):
            if not cap.grab():  # advance to the next frame
                break
            if idx in wanted:  # only decompress the frames we need
                ret, frame = cap.retrieve()
                if ret:
                    # OpenCV uses BGR format, we need to convert it to RGB
                    # for PIL and transformers compatibility
                    frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Always free the underlying decoder / file handle.
        cap.release()

    # Check before stacking: np.stack([]) would raise an unrelated error.
    if not frames or len(frames) < num_frames:
        raise ValueError(
            f"Could not read enough frames from video file {path}"
            f" (expected {num_frames} frames, got {len(frames)})"
        )
    return np.stack(frames)
|
||||
|
||||
|
||||
def video_to_pil_images_list(path: str, num_frames: int = -1) -> list[Image.Image]:
    """Decode a video with `video_to_ndarrays` and wrap each frame as a PIL image.

    `path` and `num_frames` are forwarded unchanged to `video_to_ndarrays`.
    """
    frames = video_to_ndarrays(path, num_frames)
    return [Image.fromarray(frame) for frame in frames]
|
||||
|
||||
|
||||
def video_get_metadata(path: str, num_frames: int = -1) -> dict[str, Any]:
    """Build a metadata dict for a video file using OpenCV properties.

    Args:
        path: Path of the video file, opened with ``cv2.VideoCapture``.
        num_frames: Number of frames to describe. ``-1`` or any value larger
            than the frame count reported by OpenCV means "all frames".

    Returns:
        A dict with the keys ``total_num_frames``, ``fps``, ``duration``,
        ``video_backend``, ``frames_indices`` and ``do_sample_frames``.

    Raises:
        ValueError: If the file cannot be opened.
    """
    import cv2

    cap = cv2.VideoCapture(path)
    try:
        if not cap.isOpened():
            raise ValueError(f"Could not open video file {path}")

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
    finally:
        # Only properties are needed; free the capture right away.
        cap.release()

    duration = total_frames / fps if fps > 0 else 0

    if num_frames == -1 or num_frames > total_frames:
        num_frames = total_frames

    metadata = {
        "total_num_frames": num_frames,
        # NOTE(review): this is duration / num_frames (seconds per frame),
        # not frames per second. Kept as-is because consumers may rely on
        # it; confirm the intended meaning. Guarded so a video reporting
        # zero frames does not raise ZeroDivisionError.
        "fps": duration / num_frames if num_frames else 0,
        "duration": duration,
        "video_backend": "opencv",
        "frames_indices": list(range(num_frames)),
        # extra field used to control hf processor's video
        # sampling behavior
        "do_sample_frames": num_frames == total_frames,
    }
    return metadata
|
||||
|
||||
|
||||
VideoAssetName = Literal["baby_reading"]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class VideoAsset:
    """A named example video plus the number of frames to sample from it."""

    name: VideoAssetName
    num_frames: int = -1

    # Maps each asset name to the file name passed to `download_video_asset`.
    _NAME_TO_FILE: ClassVar[dict[VideoAssetName, str]] = {
        "baby_reading": "sample_demo_1.mp4",
    }

    @property
    def filename(self) -> str:
        """File name registered for this asset's name."""
        return self._NAME_TO_FILE[self.name]

    @property
    def video_path(self) -> str:
        """Path of the video as returned by ``download_video_asset``."""
        return download_video_asset(self.filename)

    @property
    def pil_images(self) -> list[Image.Image]:
        """Sampled frames as PIL images (see ``video_to_pil_images_list``)."""
        return video_to_pil_images_list(self.video_path, self.num_frames)

    @property
    def np_ndarrays(self) -> npt.NDArray:
        """Sampled frames as one stacked array (see ``video_to_ndarrays``)."""
        return video_to_ndarrays(self.video_path, self.num_frames)

    @property
    def metadata(self) -> dict[str, Any]:
        """Metadata dict for the video (see ``video_get_metadata``)."""
        return video_get_metadata(self.video_path, self.num_frames)

    def get_audio(self, sampling_rate: float | None = None) -> npt.NDArray:
        """
        Read audio data from the video asset, used in Qwen2.5-Omni examples.

        Returns the first element of ``librosa.load(video_path, sr=sampling_rate)``.

        See also: examples/offline_inference/qwen2_5_omni/only_thinker.py
        """
        loaded = librosa.load(self.video_path, sr=sampling_rate)
        return loaded[0]
|
||||
109
vllm/beam_search.py
Normal file
109
vllm/beam_search.py
Normal file
@@ -0,0 +1,109 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from vllm.inputs import TokenInputs, token_inputs
|
||||
from vllm.logprobs import Logprob
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal.inputs import MultiModalInputs, mm_inputs
|
||||
|
||||
|
||||
@dataclass
class BeamSearchSequence:
    """A sequence for beam search.

    Tracks the tokens and the log probability of the sequence. The text
    field is optional and will only be filled when the sequence is about to
    be returned to the user.
    """

    orig_prompt: TokenInputs | MultiModalInputs

    # The tokens include the prompt.
    tokens: list[int]
    logprobs: list[dict[int, Logprob]]
    lora_request: LoRARequest | None = None
    cum_logprob: float = 0.0
    text: str | None = None
    finish_reason: str | None = None
    stop_reason: int | str | None = None

    def get_prompt(self):
        """Rebuild engine inputs that use this sequence's current tokens.

        The prompt text and cache salt are copied from ``orig_prompt``. A
        prompt whose ``"type"`` is ``"token"`` yields ``token_inputs``; any
        other type yields ``mm_inputs`` carrying the original multimodal
        kwargs, hashes and placeholders.
        """
        original = self.orig_prompt
        shared = dict(
            prompt=original.get("prompt"),
            cache_salt=original.get("cache_salt"),
        )

        if original["type"] != "token":
            return mm_inputs(
                prompt_token_ids=self.tokens,
                mm_kwargs=original["mm_kwargs"],
                mm_hashes=original["mm_hashes"],
                mm_placeholders=original["mm_placeholders"],
                **shared,
            )

        return token_inputs(self.tokens, **shared)
|
||||
|
||||
|
||||
@dataclass
class BeamSearchOutput:
    """The output of beam search.
    It contains the list of the best beam search sequences.
    The length of the list is equal to the beam width.
    """

    # Sequences returned by the search.
    sequences: list[BeamSearchSequence]
|
||||
|
||||
|
||||
class BeamSearchInstance:
    """Beam-search state for one prompt: live beams and completed sequences."""

    def __init__(
        self,
        prompt: TokenInputs | MultiModalInputs,
        lora_request: LoRARequest | None = None,
        logprobs: list[dict[int, Logprob]] | None = None,
        **kwargs,
    ):
        """Start with a single beam built from the prompt.

        Args:
            prompt: Processed prompt; its ``"prompt_token_ids"`` become the
                initial beam's tokens.
            lora_request: Optional LoRA request stored on the initial beam.
            logprobs: Optional initial logprobs; copied into a new list.
            **kwargs: Extra fields forwarded to ``BeamSearchSequence``.
        """
        prompt_tokens = prompt["prompt_token_ids"]
        initial_logprobs = list(logprobs) if logprobs is not None else []
        root_beam = BeamSearchSequence(
            orig_prompt=prompt,
            tokens=prompt_tokens,
            logprobs=initial_logprobs,
            lora_request=lora_request,
            **kwargs,
        )
        self.beams: list[BeamSearchSequence] = [root_beam]
        self.completed: list[BeamSearchSequence] = []
|
||||
|
||||
|
||||
def get_beam_search_score(
    tokens: list[int],
    cumulative_logprob: float,
    eos_token_id: int,
    length_penalty: float = 1.0,
) -> float:
    """Calculate the beam search score with length penalty.

    The score is ``cumulative_logprob / length ** length_penalty``, where
    ``length`` is the number of tokens, not counting a trailing EOS token.

    Adapted from

    https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
    """
    ends_with_eos = tokens[-1] == eos_token_id
    effective_len = len(tokens) - 1 if ends_with_eos else len(tokens)
    return cumulative_logprob / (effective_len**length_penalty)
|
||||
|
||||
|
||||
def create_sort_beams_key_function(eos_token_id: int, length_penalty: float):
    """Build a sort key that scores beams with ``get_beam_search_score``.

    Args:
        eos_token_id: EOS token id bound into the returned key function.
        length_penalty: Length penalty bound into the returned key function.

    Returns:
        A callable mapping a ``BeamSearchSequence`` to its score.
    """

    def sort_beams_key(x: BeamSearchSequence) -> float:
        return get_beam_search_score(
            tokens=x.tokens,
            cumulative_logprob=x.cum_logprob,
            eos_token_id=eos_token_id,
            length_penalty=length_penalty,
        )

    return sort_beams_key
|
||||
0
vllm/benchmarks/__init__.py
Normal file
0
vllm/benchmarks/__init__.py
Normal file
3453
vllm/benchmarks/datasets.py
Normal file
3453
vllm/benchmarks/datasets.py
Normal file
File diff suppressed because it is too large
Load Diff
172
vllm/benchmarks/latency.py
Normal file
172
vllm/benchmarks/latency.py
Normal file
@@ -0,0 +1,172 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Benchmark the latency of processing a single batch of requests."""
|
||||
|
||||
import argparse
|
||||
import dataclasses
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.inputs import PromptType
|
||||
from vllm.sampling_params import BeamSearchParams
|
||||
|
||||
|
||||
def save_to_pytorch_benchmark_format(
    args: argparse.Namespace, results: dict[str, Any]
) -> None:
    """Write the latency results in PyTorch benchmark format.

    Records are built by ``convert_to_pytorch_benchmark_format``. If that
    yields nothing, no file is written; otherwise they are written next to
    ``args.output_json`` with its extension replaced by ``.pytorch.json``.

    Args:
        args: Parsed CLI arguments; ``output_json`` is read here.
        results: Must contain ``latencies``, ``avg_latency`` and
            ``percentiles``.
    """
    records = convert_to_pytorch_benchmark_format(
        args=args,
        metrics={"latency": results["latencies"]},
        extra_info={
            "avg_latency": results["avg_latency"],
            "percentiles": results["percentiles"],
        },
    )
    if not records:
        return

    base_path, _ = os.path.splitext(args.output_json)
    write_to_json(f"{base_path}.pytorch.json", records)
|
||||
|
||||
|
||||
def add_cli_args(parser: argparse.ArgumentParser):
    """Register the latency-benchmark options and the engine options.

    Benchmark-specific flags are added first, then ``EngineArgs`` adds its
    own, and finally the ``enable_prefix_caching`` default is set to False.
    """
    parser.add_argument("--input-len", type=int, default=32)
    parser.add_argument("--output-len", type=int, default=128)
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument(
        "--n", type=int, default=1, help="Number of generated sequences per prompt."
    )
    parser.add_argument("--use-beam-search", action="store_true")
    parser.add_argument(
        "--num-iters-warmup",
        type=int,
        default=10,
        help="Number of iterations to run for warmup.",
    )
    parser.add_argument(
        "--num-iters",
        type=int,
        default=30,
        help="Number of iterations to run.",
    )
    parser.add_argument(
        "--profile",
        action="store_true",
        help="profile the generation process of a single batch",
    )
    parser.add_argument(
        "--output-json",
        type=str,
        default=None,
        help="Path to save the latency results in JSON format.",
    )
    detokenize_help = (
        "Do not detokenize responses (i.e. do not include "
        "detokenization time in the latency measurement)"
    )
    parser.add_argument(
        "--disable-detokenize", action="store_true", help=detokenize_help
    )

    # V1 enables prefix caching by default which skews the latency
    # numbers. We need to disable prefix caching by default.
    engine_parser = EngineArgs.add_cli_args(parser)
    engine_parser.set_defaults(enable_prefix_caching=False)
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
    """Measure generation latency for one fixed-size batch of dummy prompts.

    An ``LLM`` is built from the engine options in ``args``. After
    ``args.num_iters_warmup`` warmup iterations (their timings are discarded)
    the function either profiles a single iteration when ``args.profile`` is
    set and returns, or times ``args.num_iters`` iterations. The mean and the
    10/25/50/75/90/99th percentile latencies are printed and, when
    ``args.output_json`` is set, written to that file and handed to
    ``save_to_pytorch_benchmark_format``.
    """
    engine_args = EngineArgs.from_cli_args(args)

    # Lazy import to avoid importing LLM when the bench command is not selected.
    from vllm import LLM, SamplingParams

    # NOTE(woosuk): If the request cannot be processed in a single batch,
    # the engine will automatically process the request in multiple batches.
    llm = LLM(**dataclasses.asdict(engine_args))
    tokens_needed = args.input_len + args.output_len
    assert llm.llm_engine.model_config.max_model_len >= tokens_needed, (
        "Please ensure that max_model_len is greater than"
        " the sum of input_len and output_len."
    )

    sampling_params = SamplingParams(
        n=args.n,
        temperature=1.0,
        top_p=1.0,
        ignore_eos=True,
        max_tokens=args.output_len,
        detokenize=not args.disable_detokenize,
    )
    # Random token ids in [0, 10000), one row per prompt in the batch.
    token_matrix = np.random.randint(10000, size=(args.batch_size, args.input_len))
    dummy_prompts: list[PromptType] = [
        {"prompt_token_ids": row} for row in token_matrix.tolist()
    ]

    def llm_generate():
        """Run one generation pass over the dummy batch."""
        if args.use_beam_search:
            llm.beam_search(
                dummy_prompts,
                BeamSearchParams(
                    beam_width=args.n,
                    max_tokens=args.output_len,
                    ignore_eos=True,
                ),
            )
            return
        llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False)

    def run_to_completion(do_profile: bool = False):
        """Return the elapsed seconds of one pass, or ``None`` when profiling."""
        if not do_profile:
            began = time.perf_counter()
            llm_generate()
            finished = time.perf_counter()
            return finished - began
        llm.start_profile()
        llm_generate()
        llm.stop_profile()

    print("Warming up...")
    for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
        run_to_completion(do_profile=False)

    if args.profile:
        profiler_config = engine_args.profiler_config
        selected_profiler = profiler_config.profiler
        if selected_profiler == "torch":
            print(
                "Profiling with torch profiler (results will be saved to"
                f" {profiler_config.torch_profiler_dir})..."
            )
        elif selected_profiler == "cuda":
            print("Profiling with cuda profiler ...")
        run_to_completion(do_profile=True)
        return

    # Benchmark.
    latencies = np.array(
        [
            run_to_completion(do_profile=False)
            for _ in tqdm(range(args.num_iters), desc="Bench iterations")
        ]
    )
    percentages = [10, 25, 50, 75, 90, 99]
    percentiles = np.percentile(latencies, percentages)
    print(f"Avg latency: {np.mean(latencies)} seconds")
    for percentage, percentile in zip(percentages, percentiles):
        print(f"{percentage}% percentile latency: {percentile} seconds")

    # Output JSON results if specified
    if not args.output_json:
        return
    results = {
        "avg_latency": np.mean(latencies),
        "latencies": latencies.tolist(),
        "percentiles": dict(zip(percentages, percentiles.tolist())),
    }
    with open(args.output_json, "w") as f:
        json.dump(results, f, indent=4)
    save_to_pytorch_benchmark_format(args, results)
|
||||
539
vllm/benchmarks/mm_processor.py
Normal file
539
vllm/benchmarks/mm_processor.py
Normal file
@@ -0,0 +1,539 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
r"""Benchmark multimodal processor latency.
|
||||
|
||||
This benchmark measures the latency of the mm processor module
|
||||
using multimodal prompts from datasets.
|
||||
MM processor stats are automatically enabled.
|
||||
|
||||
Run:
|
||||
vllm bench mm-processor \
|
||||
--model <your_model> \
|
||||
--dataset-name random-mm \
|
||||
--num-prompts 10 \
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import dataclasses
|
||||
import json
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
import numpy as np
|
||||
|
||||
from vllm.benchmarks.datasets import (
|
||||
MultiModalConversationDataset,
|
||||
VisionArenaDataset,
|
||||
)
|
||||
from vllm.benchmarks.throughput import get_requests
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.utils.gc_utils import freeze_gc_heap
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
pd = PlaceholderModule("pandas")
|
||||
|
||||
if TYPE_CHECKING: # Avoid having to mock during docs build
|
||||
from vllm.v1.engine.llm_engine import LLMEngine
|
||||
else:
|
||||
LLMEngine = object
|
||||
|
||||
|
||||
def get_timing_stats_from_engine(llm_engine: LLMEngine) -> dict[str, dict[str, float]]:
    """
    Gather every multimodal timing stat the engine currently holds.

    Preprocessing stats are read from the renderer's timing registry and
    encoder stats are fetched from the workers via
    ``collective_rpc("get_encoder_timing_stats")``; the two are merged per
    request id.

    Args:
        llm_engine: The LLM engine to query.

    Returns:
        Mapping of request id to a dict holding both preprocessing and
        encoder timing metrics. Empty when MM processor stats are not
        enabled in the observability config.

    Example:
        {
            'request-123': {
                'get_mm_hashes_secs': 0.02,
                'get_cache_missing_items_secs': 0.01,
                'apply_hf_processor_secs': 0.45,
                'merge_mm_kwargs_secs': 0.01,
                'apply_prompt_updates_secs': 0.03,
                'preprocessor_total_secs': 0.51,
                'encoder_forward_secs': 0.23,
                'num_encoder_calls': 1
            }
        }
    """
    obs_cfg = llm_engine.vllm_config.observability_config
    if not (obs_cfg and obs_cfg.enable_mm_processor_stats):
        return {}

    preprocessing = llm_engine.renderer._mm_timing_registry.stat()

    per_request_encoder: dict[str, dict[str, float]] = {}
    for worker_reply in llm_engine.collective_rpc("get_encoder_timing_stats"):
        if not worker_reply:
            continue

        for req_id, worker_entry in worker_reply.items():
            known = per_request_encoder.get(req_id)
            if known is None:
                per_request_encoder[req_id] = dict(worker_entry)
                continue

            # Several workers reported this request: keep the largest value
            # of each metric rather than summing them.
            known["encoder_forward_secs"] = max(
                known.get("encoder_forward_secs", 0.0),
                worker_entry.get("encoder_forward_secs", 0.0),
            )
            known["num_encoder_calls"] = max(
                known.get("num_encoder_calls", 0),
                worker_entry.get("num_encoder_calls", 0),
            )

    merged = {req_id: dict(entry) for req_id, entry in preprocessing.items()}

    for req_id, enc_entry in per_request_encoder.items():
        target_id = req_id
        if target_id not in merged:
            # In V1 engine, the request_id in encoder_stats has a suffix
            # appended to the original request_id (which is used in
            # preprocessing_stats).
            # We try to strip the suffix to find the matching request.
            stripped_id = req_id.rpartition("-")[0]
            if not (stripped_id and stripped_id in merged):
                merged[req_id] = dict(enc_entry)
                continue
            target_id = stripped_id
        merged[target_id].update(enc_entry)

    return merged
|
||||
|
||||
|
||||
def collect_mm_processor_stats(llm_engine: LLMEngine) -> dict[str, list[float]]:
    """
    Regroup the per-request multimodal timing stats by stage.

    Returns a dictionary mapping each stage name to the list of values
    reported for that stage, one entry per request that carried it.
    """
    grouped = defaultdict[str, list[float]](list)

    for per_request in get_timing_stats_from_engine(llm_engine).values():
        for stage, value in per_request.items():
            grouped[stage].append(value)

    return grouped
|
||||
|
||||
|
||||
def calculate_mm_processor_metrics(
    stats_by_stage: dict[str, list[float]],
    selected_percentiles: list[float],
    *,
    unit: Literal["us", "ms", "s"] = "ms",
) -> dict[str, dict[str, float]]:
    """
    Summarise each stage's samples as mean, median, std and percentiles.

    Args:
        stats_by_stage: Stage name to list of samples.
        selected_percentiles: Percentiles to report; each is emitted under
            the key ``f"p{p}"``.
        unit: Unit to convert samples into. ``"_secs"`` in a stage name is
            replaced with ``"_" + unit`` in the output key.

    Returns:
        Mapping of the renamed stage to its summary dict. A stage without
        samples gets an all-zero summary. ``num_encoder_calls`` is a count,
        so its samples are not scaled.
    """
    scale = {"us": 1000000, "ms": 1000, "s": 1}[unit]

    summary = {}

    for stage, samples in stats_by_stage.items():
        label = stage.replace("_secs", "_" + unit)

        if samples:
            if stage == "num_encoder_calls":
                scaled = samples
            else:
                scaled = [sample * scale for sample in samples]
            entry = {
                "mean": float(np.mean(scaled)),
                "median": float(np.median(scaled)),
                "std": float(np.std(scaled)),
            }
            for p in selected_percentiles:
                entry[f"p{p}"] = float(np.percentile(scaled, p))
        else:
            entry = {"mean": 0.0, "median": 0.0, "std": 0.0}
            for p in selected_percentiles:
                entry[f"p{p}"] = 0.0

        summary[label] = entry

    return summary
|
||||
|
||||
|
||||
def validate_args(args):
    """
    Fill in missing attributes on ``args`` and check the dataset selection.

    ``args.tokenizer`` falls back to ``args.model``; ``dataset_path``,
    ``lora_path`` and ``max_loras`` are created as ``None`` when absent.

    Raises:
        ValueError: If ``--dataset-name hf`` is used without a dataset path,
            or with a path that is not one of the supported multimodal
            datasets.
    """
    if not getattr(args, "tokenizer", None):
        args.tokenizer = args.model
    for optional_attr in ("dataset_path", "lora_path", "max_loras"):
        if not hasattr(args, optional_attr):
            setattr(args, optional_attr, None)

    # Only the HuggingFace dataset mode needs further checks.
    if args.dataset_name != "hf":
        return

    if not args.dataset_path:
        raise ValueError(
            "--dataset-path is required when using --dataset-name hf. "
            "For multimodal benchmarking, specify a dataset like "
            "'lmarena-ai/VisionArena-Chat'."
        )

    supported_mm_datasets = (
        VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
        | MultiModalConversationDataset.SUPPORTED_DATASET_PATHS
    )
    if args.dataset_path not in supported_mm_datasets:
        raise ValueError(
            f"{args.dataset_path} is not a supported multimodal dataset. "
            f"Supported multimodal datasets are: {sorted(supported_mm_datasets)}"
        )
|
||||
|
||||
|
||||
def benchmark_multimodal_processor(
    args: argparse.Namespace,
) -> dict[str, Any]:
    """
    Run the multimodal processor benchmark.

    Builds an ``LLM`` from ``args``, optionally runs warmup requests, sends
    the benchmark requests through ``llm.chat`` and summarises both the
    collected MM processor stats and the end-to-end latencies.

    Returns:
        Dict with ``completed``/``failed`` counts, end-to-end latency
        statistics in milliseconds, ``mm_processor_stats`` and
        ``encoder_summary``.
    """
    from vllm import LLM, SamplingParams

    validate_args(args)

    if args.seed is None:
        args.seed = 0

    engine_args = EngineArgs.from_cli_args(args)
    llm = LLM(**dataclasses.asdict(engine_args))

    tokenizer = llm.get_tokenizer()
    requests = get_requests(args, tokenizer)

    assert all(
        llm.llm_engine.model_config.max_model_len
        >= (req.prompt_len + req.expected_output_len)
        for req in requests
    ), (
        "Please ensure that max_model_len is greater than the sum of "
        "prompt_len and expected_output_len for all requests."
    )

    prompts = [req.prompt for req in requests]
    sampling_params = [
        SamplingParams(
            n=1,
            temperature=0.0,
            max_tokens=req.expected_output_len,
            detokenize=True,
        )
        for req in requests
    ]

    selected_percentiles = [
        float(p) for p in getattr(args, "metric_percentiles", "99").split(",")
    ]
    show_progress = not getattr(args, "disable_tqdm", False)

    freeze_gc_heap()

    num_warmups = getattr(args, "num_warmups", 0)
    if num_warmups > 0:
        print(f"Processing {num_warmups} warmup requests...")
        # Create a temporary args object for warmup requests
        warmup_args = argparse.Namespace(**vars(args))
        warmup_args.num_prompts = num_warmups
        warmup_args.seed += 1
        warmup_requests = get_requests(warmup_args, tokenizer)
        llm.chat(
            [req.prompt for req in warmup_requests],
            [
                SamplingParams(max_tokens=req.expected_output_len)
                for req in warmup_requests
            ],
            use_tqdm=show_progress,
        )

        # Clear stats from warmup requests
        collect_mm_processor_stats(llm.llm_engine)

    print(f"Processing {len(prompts)} requests...")
    start_time = time.perf_counter()
    outputs = llm.chat(prompts, sampling_params, use_tqdm=show_progress)
    total_time = time.perf_counter() - start_time

    mm_stats_by_stage = collect_mm_processor_stats(llm.llm_engine)

    if not any(mm_stats_by_stage.values()):
        print(
            "\n⚠️ Warning: No MM processor stats found in registry.\n"
            " This may indicate that:\n"
            " - No multimodal requests were processed\n"
            " - Stats were already retrieved (registry is cleared after retrieval)\n"
        )

    mm_processor_metrics = calculate_mm_processor_metrics(
        mm_stats_by_stage, selected_percentiles
    )

    completed = sum(1 for out in outputs if out.finished)
    failed = len(outputs) - completed

    timing_fields = ("first_token_latency", "last_token_ts", "first_token_ts")
    e2el_times = []
    for out in outputs:
        if not out.finished or out.metrics is None:
            continue
        timing = out.metrics
        # Calculate E2E latency as: TTFT + (last_token_ts - first_token_ts)
        if all(getattr(timing, field, None) is not None for field in timing_fields):
            # Decode time is the duration between the first and last token generation
            decode_time = max(0.0, timing.last_token_ts - timing.first_token_ts)
            e2el_times.append((timing.first_token_latency + decode_time) * 1000)

    if not e2el_times and completed > 0:
        print(
            "\n⚠️ Warning: Detailed end-to-end latency metrics not available.\n"
            " Falling back to average request latency "
            "(total_time / num_completed_requests).\n"
        )
        avg_time_per_request = total_time / completed
        e2el_times = [avg_time_per_request * 1000] * completed

    if e2el_times:
        mean_e2el_ms = float(np.mean(e2el_times))
        median_e2el_ms = float(np.median(e2el_times))
        std_e2el_ms = float(np.std(e2el_times))
        percentiles_e2el_ms = [
            (p, float(np.percentile(e2el_times, p))) for p in selected_percentiles
        ]
    else:
        mean_e2el_ms = median_e2el_ms = std_e2el_ms = 0.0
        percentiles_e2el_ms = [(p, 0.0) for p in selected_percentiles]

    encoder_summary = {}
    # ``.get`` avoids creating the key on the defaultdict when it is absent.
    encoder_calls = mm_stats_by_stage.get("num_encoder_calls")
    if encoder_calls:
        encoder_summary = {
            "total_encoder_calls": int(sum(encoder_calls)),
            "num_requests_with_encoder_calls": len(encoder_calls),
        }

    return {
        "completed": completed,
        "failed": failed,
        "mean_e2el_ms": mean_e2el_ms,
        "median_e2el_ms": median_e2el_ms,
        "std_e2el_ms": std_e2el_ms,
        "percentiles_e2el_ms": percentiles_e2el_ms,
        "mm_processor_stats": mm_processor_metrics,
        "encoder_summary": encoder_summary,
    }
|
||||
|
||||
|
||||
def add_cli_args(parser: argparse.ArgumentParser) -> None:
    """Add CLI arguments for the multimodal processor benchmark."""
    from vllm.engine.arg_utils import EngineArgs

    def _register(option_table):
        # Add each (flag, keyword options) pair to the parser, in order.
        for flag, options in option_table:
            parser.add_argument(flag, **options)

    EngineArgs.add_cli_args(parser)

    # This benchmark reads the MM processor stats, so collect them by default.
    parser.set_defaults(enable_mm_processor_stats=True)

    _register(
        (
            (
                "--dataset-name",
                {
                    "type": str,
                    "default": "random-mm",
                    "choices": ["random-mm", "hf"],
                    "help": "Name of the dataset to benchmark on. Defaults to 'random-mm'.",
                },
            ),
            (
                "--num-prompts",
                {
                    "type": int,
                    "default": 10,
                    "help": "Number of prompts to process.",
                },
            ),
            (
                "--num-warmups",
                {
                    "type": int,
                    "default": 1,
                    "help": "Number of warmup prompts to process.",
                },
            ),
        )
    )

    from vllm.benchmarks.datasets import (
        add_random_dataset_base_args,
        add_random_multimodal_dataset_args,
    )

    add_random_dataset_base_args(parser)
    add_random_multimodal_dataset_args(parser)

    _register(
        (
            # HuggingFace dataset arguments
            (
                "--dataset-path",
                {
                    "type": str,
                    "default": None,
                    "help": "Path to the dataset file or HuggingFace dataset name "
                    "(e.g., 'yale-nlp/MMVU', 'lmarena-ai/VisionArena-Chat').",
                },
            ),
            (
                "--hf-subset",
                {
                    "type": str,
                    "default": None,
                    "help": "Subset of the HuggingFace dataset (optional).",
                },
            ),
            (
                "--hf-split",
                {
                    "type": str,
                    "default": None,
                    "help": "Split of the HuggingFace dataset (e.g., 'train', 'test', 'validation').",
                },
            ),
            (
                "--output-len",
                {
                    "type": int,
                    "default": None,
                    "help": "Output length for each request. "
                    "Overrides the default output lengths from the dataset.",
                },
            ),
            (
                "--output-json",
                {
                    "type": str,
                    "default": None,
                    "help": "Path to save the benchmark results in JSON format.",
                },
            ),
            (
                "--metric-percentiles",
                {
                    "type": str,
                    "default": "99",
                    "help": "Comma-separated list of percentiles to calculate (e.g., '50,90,99').",
                },
            ),
            (
                "--disable-tqdm",
                {
                    "action": "store_true",
                    "help": "Disable tqdm progress bar.",
                },
            ),
        )
    )
|
||||
|
||||
|
||||
def main(args: argparse.Namespace) -> None:
    """Main entry point for the multimodal processor benchmark.

    Runs the benchmark, prints the MM processor and end-to-end latency
    tables, and writes the result (plus run config and timestamp) to
    ``args.output_json`` when that path is set.
    """

    def _requested_percentiles():
        # Parse the comma-separated --metric-percentiles value ("99" if absent).
        return [
            float(p) for p in getattr(args, "metric_percentiles", "99").split(",")
        ]

    print("Starting multimodal processor benchmark...")
    result = benchmark_multimodal_processor(args)

    banner = "=" * 80
    print("\n" + banner)
    print("Multimodal Processor Benchmark Results")
    print(banner)

    if "mm_processor_stats" in result:
        print("\nMM Processor Metrics:")
        selected_percentiles = _requested_percentiles()
        mm_data = []
        for stage, stage_metrics in result["mm_processor_stats"].items():
            row = {
                "Stage": stage,
                "Mean": f"{stage_metrics['mean']:.2f}",
                "Median": f"{stage_metrics['median']:.2f}",
                "Std": f"{stage_metrics['std']:.2f}",
            }
            for p in selected_percentiles:
                percentile_cell = stage_metrics.get(f"p{p}", 0.0)
                row[f"P{p}"] = f"{percentile_cell:.2f}"
            mm_data.append(row)

        print(pd.DataFrame(mm_data).to_string(index=False))

        encoder_summary = result.get("encoder_summary")
        if encoder_summary:
            total_calls = encoder_summary["total_encoder_calls"]
            num_requests = encoder_summary["num_requests_with_encoder_calls"]
            print(
                f"\nSummary: {total_calls} total encoder calls "
                f"across {num_requests} requests."
            )

    if "mean_e2el_ms" in result:
        print("\nEnd-to-End Latency (ms):")
        selected_percentiles = _requested_percentiles()

        e2el_data = [
            {"Metric": label, "Value (ms)": f"{result[key]:.2f}"}
            for label, key in (
                ("Mean", "mean_e2el_ms"),
                ("Median", "median_e2el_ms"),
                ("Std", "std_e2el_ms"),
            )
        ]

        for p in selected_percentiles:
            # First stored value recorded for this percentile, 0.0 if none.
            percentile_value = next(
                (val for pct, val in result["percentiles_e2el_ms"] if pct == p),
                0.0,
            )
            e2el_data.append(
                {
                    "Metric": f"P{p}",
                    "Value (ms)": f"{percentile_value:.2f}",
                }
            )

        print(pd.DataFrame(e2el_data).to_string(index=False))

    if not args.output_json:
        return

    result["config"] = {
        "model": args.model,
        "num_prompts": args.num_prompts,
        "input_len": getattr(args, "random_input_len", None),
        "output_len": getattr(args, "random_output_len", None),
    }
    result["timestamp"] = datetime.now().isoformat()

    with open(args.output_json, "w") as f:
        json.dump(result, f, indent=2)
    print(f"\nResults saved to {args.output_json}")
|
||||
|
||||
|
||||
# Script entry point: build the parser, parse the command line, run the benchmark.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark mm processor latency",
    )
    add_cli_args(parser)
    args = parser.parse_args()
    main(args)
|
||||
1816
vllm/benchmarks/serve.py
Normal file
1816
vllm/benchmarks/serve.py
Normal file
File diff suppressed because it is too large
Load Diff
321
vllm/benchmarks/startup.py
Normal file
321
vllm/benchmarks/startup.py
Normal file
@@ -0,0 +1,321 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Benchmark the cold and warm startup time of vLLM models.
|
||||
|
||||
This script measures total startup time (including model loading, compilation,
|
||||
and cache operations) for both cold and warm scenarios:
|
||||
- Cold startup: Fresh start with no caches (temporary cache directories)
|
||||
- Warm startup: Using cached compilation and model info
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import dataclasses
|
||||
import json
|
||||
import multiprocessing
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from vllm.benchmarks.lib.utils import (
|
||||
convert_to_pytorch_benchmark_format,
|
||||
write_to_json,
|
||||
)
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
|
||||
|
||||
@contextmanager
def cold_startup():
    """
    Context manager to measure cold startup time:
    1. Uses a temporary directory for vLLM cache to avoid any pollution
       between cold startup iterations.
    2. Uses inductor's fresh_cache to clear torch.compile caches.

    On exit the temporary directory is removed and ``VLLM_CACHE_ROOT`` is put
    back exactly as it was on entry: restored to its previous value (including
    an empty string) if it was set, removed if it was not.
    """
    from torch._inductor.utils import fresh_cache

    # Use temporary directory for caching to avoid any pollution between cold startups
    original_cache_root = os.environ.get("VLLM_CACHE_ROOT")
    temp_cache_dir = tempfile.mkdtemp(prefix="vllm_startup_bench_cold_")
    try:
        os.environ["VLLM_CACHE_ROOT"] = temp_cache_dir
        with fresh_cache():
            yield
    finally:
        # Clean up temporary cache directory
        shutil.rmtree(temp_cache_dir, ignore_errors=True)
        # Compare against None rather than truthiness: a variable that was
        # set to "" on entry must be restored, not deleted.
        if original_cache_root is not None:
            os.environ["VLLM_CACHE_ROOT"] = original_cache_root
        else:
            os.environ.pop("VLLM_CACHE_ROOT", None)
|
||||
|
||||
|
||||
def run_startup_in_subprocess(engine_args, result_queue):
    """
    Build an ``LLM`` from ``engine_args`` and report timings on ``result_queue``.

    On success one dict with ``total_startup_time`` and ``compilation_time``
    is put on the queue. If an ``Exception`` is raised, ``None`` is put
    first, followed by the exception text.
    Intended to be the target of a subprocess so that iterations are isolated.
    """
    try:
        # Import inside the subprocess to avoid issues with forking
        from vllm import LLM

        # Measure total startup time
        began = time.perf_counter()
        llm = LLM(**dataclasses.asdict(engine_args))
        total_startup_time = time.perf_counter() - began

        # Extract compilation time if available
        compilation_time = 0.0
        if hasattr(llm.llm_engine, "vllm_config"):
            compilation_config = getattr(
                llm.llm_engine.vllm_config, "compilation_config", None
            )
            if compilation_config is not None:
                compilation_time = compilation_config.compilation_time

        timings = {
            "total_startup_time": total_startup_time,
            "compilation_time": compilation_time,
        }
        result_queue.put(timings)

    except Exception as e:
        # Two-message failure protocol: a None marker, then the error text.
        result_queue.put(None)
        result_queue.put(str(e))
|
||||
|
||||
|
||||
def save_to_pytorch_benchmark_format(
    args: argparse.Namespace, results: dict[str, Any]
) -> None:
    """Write the startup results as four PyTorch-benchmark-format JSON files.

    One file each is produced for cold startup, cold compilation, warm
    startup and warm compilation, named after ``args.output_json`` with its
    extension replaced. A file is only written when the conversion returns
    records.
    """
    base_name = os.path.splitext(args.output_json)[0]

    def _emit(metric_key: str, detail_keys: tuple[str, str], destination: str) -> None:
        # Convert one average (plus its raw samples and percentiles) and save it.
        records = convert_to_pytorch_benchmark_format(
            args=args,
            metrics={metric_key: [results[metric_key]]},
            extra_info={key: results[key] for key in detail_keys},
        )
        if records:
            write_to_json(destination, records)

    _emit(
        "avg_cold_startup_time",
        ("cold_startup_times", "cold_startup_percentiles"),
        f"{base_name}.cold_startup.pytorch.json",
    )
    _emit(
        "avg_cold_compilation_time",
        ("cold_compilation_times", "cold_compilation_percentiles"),
        f"{base_name}.cold_compilation.pytorch.json",
    )
    _emit(
        "avg_warm_startup_time",
        ("warm_startup_times", "warm_startup_percentiles"),
        f"{base_name}.warm_startup.pytorch.json",
    )
    _emit(
        "avg_warm_compilation_time",
        ("warm_compilation_times", "warm_compilation_percentiles"),
        f"{base_name}.warm_compilation.pytorch.json",
    )
|
||||
|
||||
|
||||
def add_cli_args(parser: argparse.ArgumentParser):
    """Register the startup-benchmark flags and the engine flags on ``parser``.

    Returns whatever ``EngineArgs.add_cli_args`` returns for the parser.
    """
    # Iteration counts: (flag, default, help text).
    for flag, fallback, description in (
        ("--num-iters-cold", 3, "Number of cold startup iterations."),
        (
            "--num-iters-warmup",
            1,
            "Number of warmup iterations before benchmarking warm startups.",
        ),
        ("--num-iters-warm", 3, "Number of warm startup iterations."),
    ):
        parser.add_argument(flag, type=int, default=fallback, help=description)

    parser.add_argument(
        "--output-json",
        type=str,
        default=None,
        help="Path to save the startup time results in JSON format.",
    )

    return EngineArgs.add_cli_args(parser)
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
    """Benchmark cold and warm vLLM startup time and report the statistics.

    Every measurement creates the ``LLM`` in a freshly spawned subprocess.
    Cold iterations run inside ``cold_startup()``; warm iterations run after
    ``args.num_iters_warmup`` unmeasured warmup startups. Averages and the
    10/25/50/75/90/99th percentiles of total startup time and compilation
    time are printed and, when ``args.output_json`` is set, saved as JSON and
    passed to ``save_to_pytorch_benchmark_format``.
    """
    # Set multiprocessing start method to 'spawn' for clean process isolation
    # This ensures each subprocess starts fresh without inheriting state
    multiprocessing.set_start_method("spawn", force=True)

    engine_args = EngineArgs.from_cli_args(args)

    def create_llm_and_measure_startup():
        """
        Create LLM instance in a subprocess and measure startup time.
        Returns timing metrics, using subprocess for complete isolation.

        Raises RuntimeError when the subprocess reports a failure or
        leaves nothing on the queue.
        """
        # Create a queue for inter-process communication
        result_queue = multiprocessing.Queue()
        process = multiprocessing.Process(
            target=run_startup_in_subprocess,
            args=(engine_args, result_queue),
        )
        process.start()
        process.join()

        if result_queue.empty():
            raise RuntimeError("Subprocess did not return a result")

        result = result_queue.get()
        if result is not None:
            return result

        # A None marker means failure; the error text follows if it was sent.
        if result_queue.empty():
            raise RuntimeError("Subprocess failed with unknown error")
        error_msg = result_queue.get()
        raise RuntimeError(f"Subprocess failed: {error_msg}")

    os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
    print("Setting VLLM_ENABLE_V1_MULTIPROCESSING=0 to collect startup metrics.\n")

    print("Measuring cold startup time...\n")
    cold_startup_times = []
    cold_compilation_times = []
    for _ in tqdm(range(args.num_iters_cold), desc="Cold startup iterations"):
        with cold_startup():
            metrics = create_llm_and_measure_startup()
        cold_startup_times.append(metrics["total_startup_time"])
        cold_compilation_times.append(metrics["compilation_time"])

    # Warmup for warm startup
    print("\nWarming up for warm startup measurement...\n")
    for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
        create_llm_and_measure_startup()

    print("\nMeasuring warm startup time...\n")
    warm_startup_times = []
    warm_compilation_times = []
    for _ in tqdm(range(args.num_iters_warm), desc="Warm startup iterations"):
        metrics = create_llm_and_measure_startup()
        warm_startup_times.append(metrics["total_startup_time"])
        warm_compilation_times.append(metrics["compilation_time"])

    # Calculate statistics
    cold_startup_array = np.array(cold_startup_times)
    cold_compilation_array = np.array(cold_compilation_times)
    warm_startup_array = np.array(warm_startup_times)
    warm_compilation_array = np.array(warm_compilation_times)

    avg_cold_startup = np.mean(cold_startup_array)
    avg_cold_compilation = np.mean(cold_compilation_array)
    avg_warm_startup = np.mean(warm_startup_array)
    avg_warm_compilation = np.mean(warm_compilation_array)

    percentages = [10, 25, 50, 75, 90, 99]
    cold_startup_percentiles = np.percentile(cold_startup_array, percentages)
    cold_compilation_percentiles = np.percentile(cold_compilation_array, percentages)
    warm_startup_percentiles = np.percentile(warm_startup_array, percentages)
    warm_compilation_percentiles = np.percentile(warm_compilation_array, percentages)

    def _print_phase(title, avg_startup, avg_compilation, startup_pcts, compile_pcts):
        # Print the averages and both percentile tables for one phase.
        print(title)
        print(f"Avg total startup time: {avg_startup:.2f} seconds")
        print(f"Avg compilation time: {avg_compilation:.2f} seconds")
        print("Startup time percentiles:")
        for percentage, percentile in zip(percentages, startup_pcts):
            print(f" {percentage}%: {percentile:.2f} seconds")
        print("Compilation time percentiles:")
        for percentage, percentile in zip(percentages, compile_pcts):
            print(f" {percentage}%: {percentile:.2f} seconds")

    rule = "=" * 60
    print("\n" + rule)
    print("STARTUP TIME BENCHMARK RESULTS")
    print(rule)

    # Cold startup statistics
    _print_phase(
        "\nCOLD STARTUP:",
        avg_cold_startup,
        avg_cold_compilation,
        cold_startup_percentiles,
        cold_compilation_percentiles,
    )

    # Warm startup statistics
    _print_phase(
        "\nWARM STARTUP:",
        avg_warm_startup,
        avg_warm_compilation,
        warm_startup_percentiles,
        warm_compilation_percentiles,
    )

    print(rule)

    # Output JSON results if specified
    if not args.output_json:
        return

    def _as_mapping(values):
        # Pair each percentage with its percentile as plain Python floats.
        return dict(zip(percentages, values.tolist()))

    results = {
        "avg_cold_startup_time": float(avg_cold_startup),
        "avg_cold_compilation_time": float(avg_cold_compilation),
        "cold_startup_times": cold_startup_times,
        "cold_compilation_times": cold_compilation_times,
        "cold_startup_percentiles": _as_mapping(cold_startup_percentiles),
        "cold_compilation_percentiles": _as_mapping(cold_compilation_percentiles),
        "avg_warm_startup_time": float(avg_warm_startup),
        "avg_warm_compilation_time": float(avg_warm_compilation),
        "warm_startup_times": warm_startup_times,
        "warm_compilation_times": warm_compilation_times,
        "warm_startup_percentiles": _as_mapping(warm_startup_percentiles),
        "warm_compilation_percentiles": _as_mapping(warm_compilation_percentiles),
    }
    with open(args.output_json, "w") as f:
        json.dump(results, f, indent=4)
    save_to_pytorch_benchmark_format(args, results)
|
||||
0
vllm/benchmarks/sweep/__init__.py
Normal file
0
vllm/benchmarks/sweep/__init__.py
Normal file
44
vllm/benchmarks/sweep/cli.py
Normal file
44
vllm/benchmarks/sweep/cli.py
Normal file
@@ -0,0 +1,44 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
|
||||
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG
|
||||
|
||||
from .plot import SweepPlotArgs
|
||||
from .plot import main as plot_main
|
||||
from .plot_pareto import SweepPlotParetoArgs
|
||||
from .plot_pareto import main as plot_pareto_main
|
||||
from .serve import SweepServeArgs
|
||||
from .serve import main as serve_main
|
||||
from .serve_sla import SweepServeSLAArgs
|
||||
from .serve_sla import main as serve_sla_main
|
||||
from .startup import SweepStartupArgs
|
||||
from .startup import main as startup_main
|
||||
|
||||
# Registry of `vllm bench sweep` subcommands.
# Each entry pairs an argument class (providing `parser_name`, `parser_help`
# and `add_cli_args`) with the entry point that `add_cli_args` below stores
# as the sub-parser's `dispatch_function`.
SUBCOMMANDS = (
    (SweepServeArgs, serve_main),
    (SweepServeSLAArgs, serve_sla_main),
    (SweepStartupArgs, startup_main),
    (SweepPlotArgs, plot_main),
    (SweepPlotParetoArgs, plot_pareto_main),
)
|
||||
|
||||
|
||||
def add_cli_args(parser: argparse.ArgumentParser):
    """
    Register one sub-parser per entry of `SUBCOMMANDS` on `parser`.

    The selected subcommand name is stored in `sweep_type`, and the matching
    entry point is stored in `dispatch_function` for `main` to call.
    """
    sweep_subparsers = parser.add_subparsers(required=True, dest="sweep_type")

    for args_cls, handler in SUBCOMMANDS:
        subcmd_name = args_cls.parser_name

        sub = sweep_subparsers.add_parser(
            subcmd_name,
            description=args_cls.parser_help,
            usage=f"vllm bench sweep {subcmd_name} [options]",
        )
        sub.set_defaults(dispatch_function=handler)

        # Let the argument class declare its own options before the shared
        # epilog is attached.
        args_cls.add_cli_args(sub)
        sub.epilog = VLLM_SUBCMD_PARSER_EPILOG.format(subcmd=f"sweep {subcmd_name}")
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
    """Invoke the entry point that the chosen subcommand stored on `args`."""
    handler = args.dispatch_function
    handler(args)
|
||||
159
vllm/benchmarks/sweep/param_sweep.py
Normal file
159
vllm/benchmarks/sweep/param_sweep.py
Normal file
@@ -0,0 +1,159 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import json
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ParameterSweep(list["ParameterSweepItem"]):
    """An ordered list of parameter combinations (`ParameterSweepItem`)."""

    @classmethod
    def read_json(cls, filepath: os.PathLike):
        """
        Load a parameter sweep from a JSON file.

        The file may contain either a list of records or a dict that maps
        a benchmark name to its parameters.
        """
        with open(filepath, "rb") as f:
            data = json.load(f)

        # Support both list and dict formats
        if isinstance(data, dict):
            return cls.read_from_dict(data)

        return cls.from_records(data)

    @classmethod
    def read_from_dict(cls, data: dict[str, dict[str, object]]):
        """
        Read parameter sweep from a dict format where keys are names.

        Example:
            {
                "experiment1": {"max_tokens": 100, "temperature": 0.7},
                "experiment2": {"max_tokens": 200, "temperature": 0.9}
            }
        """
        records = [{"_benchmark_name": name, **params} for name, params in data.items()]
        return cls.from_records(records)

    @classmethod
    def from_records(cls, records: list[dict[str, object]]):
        """
        Build a sweep from a list of parameter dictionaries.

        Raises:
            TypeError: If `records` is not a list, or any entry is not a dict.
            ValueError: If two records share the same `_benchmark_name`.
        """
        if not isinstance(records, list):
            raise TypeError(
                f"The parameter sweep should be a list of dictionaries, "
                f"but found type: {type(records)}"
            )

        # Check the entry types before scanning for names; otherwise a
        # non-dict entry (e.g. an int) fails inside the `in` test below with
        # an unrelated "argument of type ... is not iterable" message.
        for record in records:
            if not isinstance(record, dict):
                raise TypeError(
                    f"Each item in the parameter sweep should be a dictionary, "
                    f"but found type: {type(record)}"
                )

        # Validate that all _benchmark_name values are unique if provided
        names = [r["_benchmark_name"] for r in records if "_benchmark_name" in r]
        seen = set()
        duplicates = set()
        for name in names:
            # Single pass instead of calling `names.count` for every name.
            (duplicates if name in seen else seen).add(name)

        if duplicates:
            raise ValueError(
                f"Duplicate _benchmark_name values found: {set(duplicates)}. "
                f"All _benchmark_name values must be unique."
            )

        return cls(ParameterSweepItem.from_record(record) for record in records)
|
||||
|
||||
|
||||
class ParameterSweepItem(dict[str, object]):
    """One parameter combination of a sweep, stored as a key/value mapping."""

    @classmethod
    def from_record(cls, record: dict[str, object]):
        """Wrap `record`, raising `TypeError` if it is not a dictionary."""
        if isinstance(record, dict):
            return cls(record)

        raise TypeError(
            f"Each item in the parameter sweep should be a dictionary, "
            f"but found type: {type(record)}"
        )

    def __or__(self, other: dict[str, Any]):
        # Re-wrap the merged mapping so the result keeps this class.
        merged = super().__or__(other)
        return type(self)(merged)

    @property
    def name(self) -> str:
        """
        The display name of this item.

        This is the '_benchmark_name' entry when there is one; otherwise
        all parameters are rendered as text joined by "-".
        """
        if "_benchmark_name" not in self:
            return self.as_text(sep="-")

        return str(self["_benchmark_name"])

    # In JSON, we prefer "_"
    def _iter_param_key_candidates(self, param_key: str):
        """Yield the spellings under which `param_key` may appear."""
        if "." not in param_key:
            yield param_key
            yield param_key.replace("-", "_")
            yield param_key.replace("_", "-")
            return

        # Inner config arguments are not converted by the CLI, so only the
        # part before the first "." is varied.
        head, _, tail = param_key.partition(".")
        yield from (
            head_variant + "." + tail
            for head_variant in self._iter_param_key_candidates(head)
        )

    # In CLI, we prefer "-"
    def _iter_cmd_key_candidates(self, param_key: str):
        """Yield `--`-prefixed spellings, in reverse order of the JSON ones."""
        variants = list(self._iter_param_key_candidates(param_key))
        for variant in variants[::-1]:
            yield "--" + variant

    def _normalize_cmd_key(self, param_key: str):
        """Return the preferred command-line spelling of `param_key`."""
        cmd_keys = self._iter_cmd_key_candidates(param_key)
        return next(cmd_keys)

    def has_param(self, param_key: str) -> bool:
        """Whether any spelling of `param_key` is a key of this item."""
        for candidate in self._iter_param_key_candidates(param_key):
            if candidate in self:
                return True

        return False

    def _normalize_cmd_kv_pair(self, k: str, v: object) -> list[str]:
        """
        Turn a key-value pair into command-line arguments.

        The result holds either:
        - one element for booleans (e.g., ['--flag'] or ['--flag=true'])
        - two elements for everything else (e.g., ['--key', 'value'])
        """
        if not isinstance(v, bool):
            return [self._normalize_cmd_key(k), str(v)]

        # For nested params (containing "."), use =true/false syntax
        if "." in k:
            literal = "true" if v else "false"
            return [f"{self._normalize_cmd_key(k)}={literal}"]

        flag_name = k if v else "no-" + k
        return [self._normalize_cmd_key(flag_name)]

    def apply_to_cmd(self, cmd: list[str]) -> list[str]:
        """
        Return a copy of `cmd` with this item's parameters applied.

        A parameter whose flag is already in `cmd` overwrites it in place;
        any other parameter is appended to the end.
        """
        result = list(cmd)

        for key, value in self.items():
            # Skip the '_benchmark_name' field, not a parameter
            if key == "_benchmark_name":
                continue

            # Serialize dict values as JSON
            if isinstance(value, dict):
                value = json.dumps(value)

            existing_idx = next(
                (
                    result.index(flag)
                    for flag in self._iter_cmd_key_candidates(key)
                    if flag in result
                ),
                None,
            )
            replacement = self._normalize_cmd_kv_pair(key, value)

            if existing_idx is None:
                # Add new parameter
                result.extend(replacement)
                continue

            # Replace existing parameter: the flag itself, then its value
            # when the replacement is a key-value pair.
            result[existing_idx] = replacement[0]
            if len(replacement) != 1:
                result[existing_idx + 1] = replacement[1]

        return result

    def as_text(self, sep: str = ", ") -> str:
        """Render every parameter except '_benchmark_name' as `key=value`."""
        pieces = [
            f"{key}={value}"
            for key, value in self.items()
            if key != "_benchmark_name"
        ]
        return sep.join(pieces)
|
||||
683
vllm/benchmarks/sweep/plot.py
Normal file
683
vllm/benchmarks/sweep/plot.py
Normal file
@@ -0,0 +1,683 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from types import TracebackType
|
||||
from typing import ClassVar
|
||||
|
||||
from typing_extensions import Self, override
|
||||
|
||||
from vllm.utils.collection_utils import full_groupby
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
from .utils import sanitize_filename
|
||||
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
except ImportError:
|
||||
plt = PlaceholderModule("matplotlib").placeholder_attr("pyplot")
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
pd = PlaceholderModule("pandas")
|
||||
|
||||
try:
    import seaborn as sns
except ImportError:
    # Bind the placeholder to `sns`, the name the plotting code uses.
    # Binding it to `seaborn` left `sns` undefined, so a missing seaborn
    # surfaced as a NameError at the first `sns.` access.
    sns = PlaceholderModule("seaborn")
|
||||
|
||||
|
||||
@dataclass
class PlotFilterBase(ABC):
    """
    Base class of a single `<var><operator><target>` plot filter.

    Attributes:
        var: Name of the DataFrame column to test.
        target: The right-hand side of the statement, kept as text.
    """

    var: str
    target: str

    @classmethod
    def parse_str(cls, s: str):
        """
        Parse a statement such as `max_concurrency<1000` into a filter.

        The operators in `PLOT_FILTERS` are tried in that dict's order, and
        the statement is split at the first occurrence of the first operator
        found. Quotes surrounding the target are removed.

        Raises:
            ValueError: If `s` contains none of the known operators.
        """
        for op_key in PLOT_FILTERS:
            if op_key in s:
                # Split once only: a target that itself contains the operator
                # (e.g. "name==a==b") used to break the two-name unpacking.
                key, _, value = s.partition(op_key)
                return PLOT_FILTERS[op_key](
                    key,
                    value.strip("'").strip('"'),
                )

        raise ValueError(
            f"Invalid operator for plot filter '{s}'. "
            f"Valid operators are: {sorted(PLOT_FILTERS)}",
        )

    @abstractmethod
    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
        """Applies this filter to a DataFrame."""
        raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
class PlotEqualTo(PlotFilterBase):
    """Keeps the rows whose `var` column equals `target`."""

    @override
    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
        # Compare as a number when the target parses as one; otherwise
        # compare against the raw text.
        try:
            wanted: float | str = float(self.target)
        except ValueError:
            wanted = self.target

        mask = df[self.var] == wanted
        return df[mask]
|
||||
|
||||
|
||||
@dataclass
class PlotNotEqualTo(PlotFilterBase):
    """Keeps the rows whose `var` column differs from `target`."""

    @override
    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
        # Compare as a number when the target parses as one; otherwise
        # compare against the raw text.
        try:
            unwanted: float | str = float(self.target)
        except ValueError:
            unwanted = self.target

        mask = df[self.var] != unwanted
        return df[mask]
|
||||
|
||||
|
||||
@dataclass
class PlotLessThan(PlotFilterBase):
    """Keeps the rows whose `var` column is below the numeric `target`."""

    @override
    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
        column = df[self.var]
        mask = column < float(self.target)
        return df[mask]
|
||||
|
||||
|
||||
@dataclass
class PlotLessThanOrEqualTo(PlotFilterBase):
    """Keeps the rows whose `var` column is at most the numeric `target`."""

    @override
    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
        column = df[self.var]
        mask = column <= float(self.target)
        return df[mask]
|
||||
|
||||
|
||||
@dataclass
class PlotGreaterThan(PlotFilterBase):
    """Keeps the rows whose `var` column is above the numeric `target`."""

    @override
    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
        column = df[self.var]
        mask = column > float(self.target)
        return df[mask]
|
||||
|
||||
|
||||
@dataclass
class PlotGreaterThanOrEqualTo(PlotFilterBase):
    """Keeps the rows whose `var` column is at least the numeric `target`."""

    @override
    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
        column = df[self.var]
        mask = column >= float(self.target)
        return df[mask]
|
||||
|
||||
|
||||
# Maps each comparison operator to the filter class that implements it.
# NOTE: The ordering is important! `PlotFilterBase.parse_str` tries the keys
# in this order and takes the first one contained in the statement, so the
# two-character operators must come before "<" and ">", which they contain.
PLOT_FILTERS: dict[str, type[PlotFilterBase]] = {
    "==": PlotEqualTo,
    "!=": PlotNotEqualTo,
    "<=": PlotLessThanOrEqualTo,
    ">=": PlotGreaterThanOrEqualTo,
    "<": PlotLessThan,
    ">": PlotGreaterThan,
}
|
||||
|
||||
|
||||
class PlotFilters(list[PlotFilterBase]):
    """A list of filters that are applied one after another."""

    @classmethod
    def parse_str(cls, s: str):
        """Parse a comma-separated list of filter statements (may be empty)."""
        if s:
            return cls([PlotFilterBase.parse_str(stmt) for stmt in s.split(",")])

        return cls()

    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
        """Run `df` through every filter, in list order."""
        filtered = df
        for plot_filter in self:
            filtered = plot_filter.apply(filtered)

        return filtered
|
||||
|
||||
|
||||
@dataclass
class PlotBinner:
    """
    Rounds the values of one column down to a multiple of a bin size.

    Attributes:
        var: Name of the DataFrame column to bin.
        bin_size: Width of each bin.
    """

    var: str
    bin_size: float

    @classmethod
    def parse_str(cls, s: str):
        """
        Parse a statement such as `request_throughput%1` into a binner.

        Raises:
            ValueError: If `s` contains none of the operators in
                `PLOT_BINNERS`, or the bin size is not a number.
        """
        for op_key in PLOT_BINNERS:
            if op_key in s:
                # Split once only, so a repeated operator is reported by
                # `float` as a bad bin size rather than as an unpacking error.
                key, _, value = s.partition(op_key)
                return PLOT_BINNERS[op_key](key, float(value))

        raise ValueError(
            f"Invalid operator for plot binner '{s}'. "
            f"Valid operators are: {sorted(PLOT_BINNERS)}",
        )

    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
        """Applies this binner to a DataFrame."""
        # Work on a copy so the caller's frame is left untouched.
        df = df.copy()
        df[self.var] = df[self.var] // self.bin_size * self.bin_size
        return df
|
||||
|
||||
|
||||
# Maps the operator that separates the variable from the bin size in a
# binning statement to the binner class; read by `PlotBinner.parse_str`.
PLOT_BINNERS: dict[str, type[PlotBinner]] = {
    "%": PlotBinner,
}
|
||||
|
||||
|
||||
class PlotBinners(list[PlotBinner]):
    """A list of binners that are applied one after another."""

    @classmethod
    def parse_str(cls, s: str):
        """Parse a comma-separated list of binning statements (may be empty)."""
        if s:
            return cls([PlotBinner.parse_str(stmt) for stmt in s.split(",")])

        return cls()

    def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
        """Run `df` through every binner, in list order."""
        binned = df
        for binner in self:
            binned = binner.apply(binned)

        return binned
|
||||
|
||||
|
||||
def _json_load_bytes(path: Path) -> list[dict[str, object]]:
|
||||
with path.open("rb") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def _convert_inf_nan_strings(data: list[dict[str, object]]) -> list[dict[str, object]]:
|
||||
"""
|
||||
Convert string values "inf", "-inf", and "nan" to their float equivalents.
|
||||
|
||||
This handles the case where JSON serialization represents inf/nan as strings.
|
||||
"""
|
||||
converted_data = []
|
||||
for record in data:
|
||||
converted_record = {}
|
||||
for key, value in record.items():
|
||||
if isinstance(value, str):
|
||||
if value in ["inf", "-inf", "nan"]:
|
||||
converted_record[key] = float(value)
|
||||
else:
|
||||
converted_record[key] = value
|
||||
else:
|
||||
converted_record[key] = value
|
||||
converted_data.append(converted_record)
|
||||
return converted_data
|
||||
|
||||
|
||||
def _get_metric(run_data: dict[str, object], metric_key: str):
|
||||
try:
|
||||
return run_data[metric_key]
|
||||
except KeyError as exc:
|
||||
raise ValueError(f"Cannot find metric {metric_key!r} in {run_data=}") from exc
|
||||
|
||||
|
||||
def _get_group(run_data: dict[str, object], group_keys: list[str]):
    """Return `(key, str(value))` pairs for `group_keys`, as a tuple."""
    pairs = [(key, str(_get_metric(run_data, key))) for key in group_keys]
    return tuple(pairs)
|
||||
|
||||
|
||||
def _get_fig_path(fig_dir: Path, group: tuple[tuple[str, str], ...], fig_name: str):
    """
    Build the PNG path of one figure inside `fig_dir`.

    The file name is `fig_name` followed by one `key=value` part per group
    entry, joined with "-" and passed through `sanitize_filename`.
    """
    # `or ()` keeps a falsy group (empty or None) from contributing parts.
    stem = "-".join([fig_name, *(f"{k}={v}" for k, v in group or ())])

    return fig_dir / sanitize_filename(stem + ".png")
|
||||
|
||||
|
||||
class DummyExecutor:
|
||||
map = map
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
exc_traceback: TracebackType | None,
|
||||
) -> None:
|
||||
return None
|
||||
|
||||
|
||||
def _require_columns(df: "pd.DataFrame", label: str, keys: list[str]) -> None:
    """Raise `ValueError` if any of `keys` is not a column of `df`."""
    for k in keys:
        if k not in df.columns:
            raise ValueError(
                f"Cannot find {label}={k!r} in parameter sweep results. "
                f"Available variables: {df.columns.tolist()}"
            )


def _group_labels(df: "pd.DataFrame", keys: list[str]):
    """Build one `key=value` label per row, newline-joined across `keys`."""
    if not keys:
        return "(All)"

    return pd.concat(
        [k + "=" + df[k].astype(str) for k in keys],
        axis=1,
    ).agg("\n".join, axis=1)


def _plot_fig(
    fig_dir: Path,
    fig_group_data: tuple[tuple[tuple[str, str], ...], list[dict[str, object]]],
    row_by: list[str],
    col_by: list[str],
    curve_by: list[str],
    *,
    var_x: str,
    var_y: str,
    filter_by: PlotFilters,
    bin_by: PlotBinners,
    scale_x: str | None,
    scale_y: str | None,
    dry_run: bool,
    fig_name: str,
    error_bars: bool,
    fig_height: float,
    fig_dpi: int,
):
    """
    Draw one figure (a grid of line plots) and save it as a PNG.

    Args:
        fig_dir: Directory the figure is written to.
        fig_group_data: The figure's group key together with its run records.
        row_by: Variables that split the data into subplot rows.
        col_by: Variables that split the data into subplot columns.
        curve_by: Variables that split the data into curves. Up to three are
            mapped to hue, style and size; more are merged into one hue label.
        var_x: Column plotted on the x-axis.
        var_y: Column plotted on the y-axis.
        filter_by: Filters applied to the data before plotting.
        bin_by: Binners applied to the data after filtering.
        scale_x: Optional x-axis scale.
        scale_y: Optional y-axis scale.
        dry_run: If set, only print the figure information.
        fig_name: Prefix of the output file name.
        error_bars: Whether to draw standard-deviation error bars.
        fig_height: Height passed to `sns.relplot`.
        fig_dpi: Resolution used when saving the figure.

    Raises:
        ValueError: If a requested variable is not a column of the results.
    """
    fig_group, fig_data = fig_group_data

    row_groups = full_groupby(
        fig_data,
        key=lambda item: _get_group(item, row_by),
    )
    num_rows = len(row_groups)
    num_cols = max(
        len(full_groupby(row_data, key=lambda item: _get_group(item, col_by)))
        for _, row_data in row_groups
    )

    fig_path = _get_fig_path(fig_dir, fig_group, fig_name)

    print("[BEGIN FIGURE]")
    print(f"Group: {dict(fig_group)}")
    print(f"Grid: {num_rows} rows x {num_cols} cols")
    print(f"Output file: {fig_path}")

    if dry_run:
        print("[END FIGURE]")
        return

    # Convert string "inf", "-inf", and "nan" to their float equivalents
    fig_data = _convert_inf_nan_strings(fig_data)
    df = pd.DataFrame.from_records(fig_data)

    # Same order and messages as before: axes first, then the groupings.
    _require_columns(df, "var_x", [var_x])
    _require_columns(df, "var_y", [var_y])
    _require_columns(df, "row_by", row_by)
    _require_columns(df, "col_by", col_by)
    _require_columns(df, "curve_by", curve_by)

    df = filter_by.apply(df)
    df = bin_by.apply(df)

    # Filters can remove every point; skip the figure in that case, as the
    # Pareto plot does, instead of handing an empty frame to seaborn.
    if df.empty:
        print("No data points available after filtering; skipping.")
        print("[END FIGURE]")
        return

    # Sort by curve_by columns alphabetically for consistent legend ordering
    if curve_by:
        df = df.sort_values(by=curve_by)

    df["row_group"] = _group_labels(df, row_by)
    df["col_group"] = _group_labels(df, col_by)

    if len(curve_by) <= 3:
        # Pad with None so that unused aesthetics are simply left unset.
        hue, style, size, *_ = (*curve_by, None, None, None)
    else:
        # Too many variables for separate aesthetics: merge them into one
        # label column and colour by it.
        df["curve_group"] = _group_labels(df, curve_by)
        hue, style, size = "curve_group", None, None

    g = sns.relplot(
        df,
        x=var_x,
        y=var_y,
        hue=hue,
        style=style,
        size=size,
        markers=True,
        errorbar="sd" if error_bars else None,
        kind="line",
        row="row_group",
        col="col_group",
        height=fig_height,
    )

    if row_by and col_by:
        g.set_titles("{row_name}\n{col_name}")
    elif row_by:
        g.set_titles("{row_name}")
    elif col_by:
        g.set_titles("{col_name}")
    else:
        g.set_titles("")

    if scale_x:
        g.set(xscale=scale_x)
    if scale_y:
        g.set(yscale=scale_y)

    g.savefig(fig_path, dpi=fig_dpi)
    # Close the figure explicitly once it has been written.
    plt.close(g.figure)

    print("[END FIGURE]")
|
||||
|
||||
|
||||
def plot(
    output_dir: Path,
    fig_dir: Path,
    fig_by: list[str],
    row_by: list[str],
    col_by: list[str],
    curve_by: list[str],
    *,
    var_x: str,
    var_y: str,
    filter_by: PlotFilters,
    bin_by: PlotBinners,
    scale_x: str | None,
    scale_y: str | None,
    dry_run: bool,
    fig_name: str = "FIGURE",
    error_bars: bool = True,
    fig_height: float = 6.4,
    fig_dpi: int = 300,
):
    """
    Plot every `summary.json` found under `output_dir`.

    The run records are grouped by `fig_by` and each group is drawn by
    `_plot_fig` as its own figure in `fig_dir`. A single group is drawn in
    this process; several groups are handed to a `ProcessPoolExecutor`.

    Raises:
        ValueError: If no run records are found under `output_dir`.
    """
    all_data = [
        run_data
        for path in output_dir.rglob("**/summary.json")
        for run_data in _json_load_bytes(path)
    ]

    if not all_data:
        raise ValueError(f"Did not find any parameter sweep results under {output_dir}")

    fig_dir.mkdir(parents=True, exist_ok=True)

    fig_groups = full_groupby(
        all_data,
        key=lambda item: _get_group(item, fig_by),
    )

    plot_one = partial(
        _plot_fig,
        fig_dir,
        row_by=row_by,
        col_by=col_by,
        curve_by=curve_by,
        var_x=var_x,
        var_y=var_y,
        filter_by=filter_by,
        bin_by=bin_by,
        scale_x=scale_x,
        scale_y=scale_y,
        dry_run=dry_run,
        fig_name=fig_name,
        error_bars=error_bars,
        fig_height=fig_height,
        fig_dpi=fig_dpi,
    )

    with DummyExecutor() if len(fig_groups) <= 1 else ProcessPoolExecutor() as executor:
        # Consume every result so that each figure is drawn and any worker
        # exception is re-raised here. `all(...)` stopped at the first result,
        # because `_plot_fig` returns None, and so never surfaced failures of
        # the remaining figures.
        for _ in executor.map(plot_one, fig_groups):
            pass
|
||||
|
||||
|
||||
@dataclass
class SweepPlotArgs:
    """
    Parsed arguments of the `plot` sweep subcommand.

    The fields mirror the parameters of `plot`; `parser_name` and
    `parser_help` describe the subcommand to the CLI.
    """

    output_dir: Path
    fig_dir: Path
    fig_by: list[str]
    row_by: list[str]
    col_by: list[str]
    curve_by: list[str]
    var_x: str
    var_y: str
    filter_by: PlotFilters
    bin_by: PlotBinners
    scale_x: str | None
    scale_y: str | None
    dry_run: bool
    fig_name: str = "FIGURE"
    error_bars: bool = True
    fig_height: float = 6.4
    fig_dpi: int = 300

    parser_name: ClassVar[str] = "plot"
    parser_help: ClassVar[str] = "Plot performance curves from parameter sweep results."

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """
        Build the arguments from a parsed command line.

        Raises:
            ValueError: If the output directory does not exist.
        """
        output_dir = Path(args.OUTPUT_DIR)
        if not output_dir.exists():
            raise ValueError(f"No parameter sweep results under {output_dir}")

        # An empty or missing option means "no grouping variables".
        curve_by = [] if not args.curve_by else args.curve_by.split(",")
        row_by = [] if not args.row_by else args.row_by.split(",")
        col_by = [] if not args.col_by else args.col_by.split(",")
        fig_by = [] if not args.fig_by else args.fig_by.split(",")

        return cls(
            output_dir=output_dir,
            fig_dir=output_dir / args.fig_dir,
            fig_by=fig_by,
            row_by=row_by,
            col_by=col_by,
            curve_by=curve_by,
            var_x=args.var_x,
            var_y=args.var_y,
            filter_by=PlotFilters.parse_str(args.filter_by),
            bin_by=PlotBinners.parse_str(args.bin_by),
            scale_x=args.scale_x,
            scale_y=args.scale_y,
            dry_run=args.dry_run,
            fig_name=args.fig_name,
            error_bars=not args.no_error_bars,
            fig_height=args.fig_height,
            fig_dpi=args.fig_dpi,
        )

    @classmethod
    def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Declare the subcommand's arguments on `parser` and return it."""
        parser.add_argument(
            "OUTPUT_DIR",
            type=str,
            # argparse only applies `default` to a positional argument that
            # may be omitted; without nargs="?" the default was never used.
            nargs="?",
            default="results",
            help="The directory containing the results to plot, "
            "i.e., the `--output-dir` argument to the parameter sweep script.",
        )
        parser.add_argument(
            "--fig-dir",
            type=str,
            default="",
            help="The directory to save the figures, relative to `OUTPUT_DIR`. "
            "By default, the same directory is used.",
        )
        parser.add_argument(
            "--fig-by",
            type=str,
            default="",
            help="A comma-separated list of variables, such that a separate figure "
            "is created for each combination of these variables.",
        )
        parser.add_argument(
            "--row-by",
            type=str,
            default="",
            help="A comma-separated list of variables, such that a separate row "
            "is created for each combination of these variables.",
        )
        parser.add_argument(
            "--col-by",
            type=str,
            default="",
            help="A comma-separated list of variables, such that a separate column "
            "is created for each combination of these variables.",
        )
        parser.add_argument(
            "--curve-by",
            type=str,
            default=None,
            help="A comma-separated list of variables, such that a separate curve "
            "is created for each combination of these variables.",
        )
        parser.add_argument(
            "--var-x",
            type=str,
            default="request_throughput",
            help="The variable for the x-axis.",
        )
        parser.add_argument(
            "--var-y",
            type=str,
            default="p99_ttft_ms",
            help="The variable for the y-axis.",
        )
        parser.add_argument(
            "--filter-by",
            type=str,
            default="",
            help="A comma-separated list of statements indicating values to filter by. "
            "This is useful to remove outliers. "
            "Example: `max_concurrency<1000,max_num_batched_tokens<=4096` means "
            "plot only the points where `max_concurrency` is less than 1000 and "
            "`max_num_batched_tokens` is no greater than 4096.",
        )
        parser.add_argument(
            "--bin-by",
            type=str,
            default="",
            help="A comma-separated list of statements indicating values to bin by. "
            "This is useful to avoid plotting points that are too close together. "
            "Example: `request_throughput%%1` means "
            "use a bin size of 1 for the `request_throughput` variable.",
        )
        parser.add_argument(
            "--scale-x",
            type=str,
            default=None,
            help="The scale to use for the x-axis. "
            "Currently only accepts string values such as 'log' and 'sqrt'. "
            "See also: https://seaborn.pydata.org/generated/seaborn.objects.Plot.scale.html",
        )
        parser.add_argument(
            "--scale-y",
            type=str,
            default=None,
            help="The scale to use for the y-axis. "
            "Currently only accepts string values such as 'log' and 'sqrt'. "
            "See also: https://seaborn.pydata.org/generated/seaborn.objects.Plot.scale.html",
        )
        parser.add_argument(
            "--fig-name",
            type=str,
            default="FIGURE",
            help="Name prefix for the output figure file. "
            "Group data is always appended when present. "
            "Default: 'FIGURE'. Example: --fig-name my_performance_plot",
        )
        parser.add_argument(
            "--no-error-bars",
            action="store_true",
            help="If set, disables error bars on the plot. "
            "By default, error bars are shown.",
        )
        parser.add_argument(
            "--fig-height",
            type=float,
            default=6.4,
            help="Height of each subplot in inches. Default: 6.4",
        )
        parser.add_argument(
            "--fig-dpi",
            type=int,
            default=300,
            help="Resolution of the output figure in dots per inch. Default: 300",
        )
        parser.add_argument(
            "--dry-run",
            action="store_true",
            help="If set, prints the information about each figure to plot, "
            "then exits without drawing them.",
        )

        return parser
|
||||
|
||||
|
||||
def run_main(args: SweepPlotArgs):
    """Call `plot` with the fields of `args` and return its result."""
    # The directory and grouping arguments are passed by position; the
    # remaining parameters of `plot` are keyword-only.
    return plot(
        args.output_dir,
        args.fig_dir,
        args.fig_by,
        args.row_by,
        args.col_by,
        args.curve_by,
        var_x=args.var_x,
        var_y=args.var_y,
        filter_by=args.filter_by,
        bin_by=args.bin_by,
        scale_x=args.scale_x,
        scale_y=args.scale_y,
        dry_run=args.dry_run,
        fig_name=args.fig_name,
        error_bars=args.error_bars,
        fig_height=args.fig_height,
        fig_dpi=args.fig_dpi,
    )
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
    """Convert the parsed command line to `SweepPlotArgs` and run the plot."""
    plot_args = SweepPlotArgs.from_cli_args(args)
    run_main(plot_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Stand-alone use: build the same parser the subcommand registers.
    # `add_cli_args` returns the parser it was given.
    cli_parser = SweepPlotArgs.add_cli_args(
        argparse.ArgumentParser(description=SweepPlotArgs.parser_help)
    )

    main(cli_parser.parse_args())
|
||||
399
vllm/benchmarks/sweep/plot_pareto.py
Normal file
399
vllm/benchmarks/sweep/plot_pareto.py
Normal file
@@ -0,0 +1,399 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
import math
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import ClassVar
|
||||
|
||||
from vllm.utils.collection_utils import full_groupby
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
from .plot import DummyExecutor, _json_load_bytes
|
||||
from .utils import sanitize_filename
|
||||
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
except ImportError:
|
||||
plt = PlaceholderModule("matplotlib").placeholder_attr("pyplot")
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
pd = PlaceholderModule("pandas")
|
||||
|
||||
try:
    import seaborn as sns
except ImportError:
    # Bind the placeholder to `sns`, the name the plotting code uses.
    # Binding it to `seaborn` left `sns` undefined, so a missing seaborn
    # surfaced as a NameError at the first `sns.` access.
    sns = PlaceholderModule("seaborn")
|
||||
|
||||
|
||||
def _first_present(run_data: dict[str, object], keys: list[str]):
|
||||
for key in keys:
|
||||
for candidate in {key, key.replace("_", "-"), key.replace("-", "_")}:
|
||||
if candidate in run_data:
|
||||
return run_data[candidate]
|
||||
return None
|
||||
|
||||
|
||||
def _get_numeric(
    run_data: dict[str, object],
    keys: list[str],
    *,
    allow_zero: bool = True,
) -> float | None:
    """
    Look up the first of `keys` present in `run_data` and return it as a float.

    Returns None when no key is present (or its value is None), and also
    when the value is zero while `allow_zero` is False.

    Raises:
        ValueError: If the value found cannot be converted to a float.
    """
    raw = _first_present(run_data, keys)
    if raw is None:
        return None

    try:
        as_float = float(raw)
    except (TypeError, ValueError) as exc:
        raise ValueError(
            f"Expected numeric value for one of {keys}, "
            f"but found {raw!r} in {run_data=}"
        ) from exc

    zero_rejected = as_float == 0 and not allow_zero
    return None if zero_rejected else as_float
|
||||
|
||||
|
||||
def _infer_user_count(
    run_data: dict[str, object],
    user_count_var: str | None,
) -> float | None:
    """
    Estimate the number of users of a run.

    `user_count_var` (when given) and then "request_rate" are looked up
    first; "max_concurrent_requests" is the fallback. Zero counts as missing.
    """
    lookup_keys = [*([user_count_var] if user_count_var else []), "request_rate"]

    configured = _get_numeric(run_data, lookup_keys, allow_zero=False)
    if configured is None:
        # Fallback to the observed peak if configured value is missing.
        return _get_numeric(run_data, ["max_concurrent_requests"], allow_zero=False)

    return configured
|
||||
|
||||
|
||||
def _infer_gpu_count(
    run_data: dict[str, object],
    gpu_count_var: str | None,
) -> float:
    """
    Determine the number of GPUs used by a run.

    The value of `gpu_count_var` is returned when it is present and
    non-zero. Otherwise the result is the product of the tensor-, pipeline-
    and data-parallel sizes that are present and non-zero, starting from 1.0.
    """
    explicit = _get_numeric(
        run_data,
        [gpu_count_var] if gpu_count_var else [],
        allow_zero=False,
    )
    if explicit:
        return explicit

    parallel_sizes = [
        _get_numeric(run_data, aliases)
        for aliases in (
            ["tensor_parallel_size", "tp"],
            ["pipeline_parallel_size", "pp"],
            ["data_parallel_size", "dp"],
        )
    ]

    world_size = 1.0
    for parallel_size in parallel_sizes:
        # Missing (None) and zero sizes leave the product unchanged.
        if parallel_size:
            world_size *= parallel_size

    return world_size
|
||||
|
||||
|
||||
def _get_throughput(
    run_data: dict[str, object],
    throughput_var: str,
) -> float:
    """
    Return the throughput metric `throughput_var` of a run as a float.

    Raises:
        ValueError: If the metric is not present in `run_data`.
    """
    measured = _get_numeric(run_data, [throughput_var])
    if measured is not None:
        return measured

    raise ValueError(
        f"Cannot find throughput metric {throughput_var!r} in run data. "
        f"Available keys: {sorted(run_data)}"
    )
|
||||
|
||||
|
||||
def _prepare_records(
    all_data: list[dict[str, object]],
    *,
    user_count_var: str | None,
    gpu_count_var: str | None,
) -> tuple[list[dict[str, object]], int]:
    """
    Add the per-user and per-GPU throughput to every run record.

    Records for which no user count can be inferred are left out.

    Returns:
        The enriched records, and the number of records that were left out.
    """
    enriched_records = []
    num_skipped = 0

    for run in all_data:
        output_throughput = _get_throughput(run, "output_throughput")

        num_users = _infer_user_count(run, user_count_var)
        if num_users is None:
            num_skipped += 1
            continue

        num_gpus = _infer_gpu_count(run, gpu_count_var)

        # Copy the record, then add (or overwrite) the derived fields.
        enriched = dict(run)
        enriched.update(
            tokens_per_user=output_throughput / num_users,
            tokens_per_gpu=output_throughput / num_gpus,
            user_count_estimate=num_users,
            gpu_count=num_gpus,
        )
        enriched_records.append(enriched)

    return enriched_records, num_skipped
|
||||
|
||||
|
||||
def _pareto_frontier(
|
||||
df: "pd.DataFrame",
|
||||
x_col: str,
|
||||
y_col: str,
|
||||
*,
|
||||
epsilon: float = 1e-9,
|
||||
) -> "pd.DataFrame":
|
||||
sorted_df = df.sort_values([x_col, y_col], ascending=[False, False])
|
||||
frontier_indices = []
|
||||
best_y = -math.inf
|
||||
|
||||
for idx, row in sorted_df.iterrows():
|
||||
y_val = row[y_col]
|
||||
if y_val >= best_y - epsilon:
|
||||
frontier_indices.append(idx)
|
||||
best_y = max(best_y, y_val)
|
||||
|
||||
return df.loc[frontier_indices]
|
||||
|
||||
|
||||
def _get_fig_path(
    fig_dir: Path,
    fig_group: tuple[tuple[str, str], ...],
) -> Path:
    """Build the PNG path for a figure group inside ``fig_dir``.

    The file name is ``PARETO`` followed by one ``key=value`` segment per
    entry of ``fig_group``, joined with ``-`` and passed through
    ``sanitize_filename``.
    """
    segments = ["PARETO", *(f"{k}={v}" for k, v in fig_group)]
    return fig_dir / sanitize_filename("-".join(segments) + ".png")
|
||||
|
||||
|
||||
def _plot_fig(
    fig_dir: Path,
    fig_group_data: tuple[tuple[tuple[str, str], ...], list[dict[str, object]]],
    label_by: list[str],
    *,
    dry_run: bool,
):
    """Draw one Pareto figure (all runs plus the frontier) and save it as PNG.

    Args:
        fig_dir: Directory in which the figure file is created.
        fig_group_data: Pair of (group key, records belonging to that group).
        label_by: Record fields to annotate next to each frontier point.
        dry_run: If true, only print what would be plotted.
    """
    fig_group, fig_data = fig_group_data
    fig_path = _get_fig_path(fig_dir, fig_group)

    print("[BEGIN FIGURE]")
    print(f"Group: {dict(fig_group)}")
    print(f"Output file: {fig_path}")

    if dry_run:
        print("[END FIGURE]")
        return

    df = pd.DataFrame.from_records(fig_data)
    # Rows lacking either axis value cannot be placed on the plot.
    df = df.dropna(subset=["tokens_per_user", "tokens_per_gpu"])

    if df.empty:
        print("No data points available after filtering; skipping.")
        print("[END FIGURE]")
        return

    frontier = _pareto_frontier(df, "tokens_per_user", "tokens_per_gpu")
    # Sort by x so the frontier line is drawn left to right.
    frontier = frontier.sort_values("tokens_per_user")

    fig, ax = plt.subplots()
    sns.scatterplot(
        data=df,
        x="tokens_per_user",
        y="tokens_per_gpu",
        color="0.5",
        alpha=0.6,
        ax=ax,
        label="All runs",
    )
    sns.lineplot(
        data=frontier,
        x="tokens_per_user",
        y="tokens_per_gpu",
        marker="o",
        ax=ax,
        label="Pareto frontier",
    )

    if label_by:
        for _, row in frontier.iterrows():
            label_parts = []
            for key in label_by:
                # Fields missing from the data are silently left out.
                if key in row:
                    label_parts.append(f"{key}={row[key]}")
            if label_parts:
                ax.text(
                    row["tokens_per_user"],
                    row["tokens_per_gpu"],
                    "\n".join(label_parts),
                    fontsize=8,
                )

    ax.set_xlabel("Tokens/s/user")
    ax.set_ylabel("Tokens/s/GPU")
    ax.grid(True, linestyle="--", linewidth=0.5, alpha=0.6)
    fig.tight_layout()
    fig.savefig(fig_path)
    # Close the figure explicitly so it is released once saved.
    plt.close(fig)

    print(
        f"Plotted {len(df)} points; Pareto frontier size: {len(frontier)}.",
    )
    print("[END FIGURE]")
|
||||
|
||||
|
||||
def plot_pareto(
    output_dir: Path,
    user_count_var: str | None,
    gpu_count_var: str | None,
    label_by: list[str],
    *,
    dry_run: bool,
):
    """Plot the tokens/s/user vs. tokens/s/GPU Pareto frontier of a sweep.

    Every ``summary.json`` found recursively under ``output_dir`` is loaded,
    the records are augmented by ``_prepare_records``, and the figures are
    written to ``output_dir / "pareto"``.

    Args:
        output_dir: Root directory of the parameter sweep results.
        user_count_var: Result key that stores the concurrent user count.
        gpu_count_var: Result key that stores the GPU count, if any.
        label_by: Record fields to annotate on frontier points.
        dry_run: If true, print the figures to plot without drawing them.

    Raises:
        ValueError: If no results are found, or if no record has both a
            throughput and a user count.
    """
    fig_dir = output_dir / "pareto"
    raw_data = [
        run_data
        # `rglob` is already recursive, so no `**/` prefix is needed.
        for path in output_dir.rglob("summary.json")
        for run_data in _json_load_bytes(path)
    ]

    if not raw_data:
        raise ValueError(f"Did not find any parameter sweep results under {output_dir}")

    fig_dir.mkdir(parents=True, exist_ok=True)

    prepared_data, skipped_missing_users = _prepare_records(
        raw_data,
        user_count_var=user_count_var,
        gpu_count_var=gpu_count_var,
    )

    if skipped_missing_users:
        print(
            f"Skipped {skipped_missing_users} runs without a user count "
            "(`max_concurrency` or `max_concurrent_requests`).",
        )

    if not prepared_data:
        raise ValueError(
            "No data points with both throughput and user count available "
            "to plot Pareto frontier.",
        )

    # All records currently share one (empty) group key, i.e. a single figure.
    fig_groups = full_groupby(
        prepared_data,
        key=lambda item: tuple(),
    )

    with DummyExecutor() if len(fig_groups) <= 1 else ProcessPoolExecutor() as executor:
        plot_results = executor.map(
            partial(
                _plot_fig,
                fig_dir,
                label_by=label_by,
                dry_run=dry_run,
            ),
            fig_groups,
        )
        # Drain the whole iterator so every figure is produced and any worker
        # exception is re-raised here. `all(...)` must not be used for this:
        # `_plot_fig` returns None, so `all` would stop after the first result.
        for _ in plot_results:
            pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class SweepPlotParetoArgs:
|
||||
output_dir: Path
|
||||
user_count_var: str | None
|
||||
gpu_count_var: str | None
|
||||
label_by: list[str]
|
||||
dry_run: bool
|
||||
|
||||
parser_name: ClassVar[str] = "plot_pareto"
|
||||
parser_help: ClassVar[str] = (
|
||||
"Plot Pareto frontier between tokens/s/user and tokens/s/GPU "
|
||||
"from parameter sweep results."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
output_dir = Path(args.OUTPUT_DIR)
|
||||
if not output_dir.exists():
|
||||
raise ValueError(f"No parameter sweep results under {output_dir}")
|
||||
|
||||
label_by = [] if not args.label_by else args.label_by.split(",")
|
||||
|
||||
return cls(
|
||||
output_dir=output_dir,
|
||||
user_count_var=args.user_count_var,
|
||||
gpu_count_var=args.gpu_count_var,
|
||||
label_by=label_by,
|
||||
dry_run=args.dry_run,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def add_cli_args(cls, parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"OUTPUT_DIR",
|
||||
type=str,
|
||||
default="results",
|
||||
help="The directory containing the sweep results to plot.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--user-count-var",
|
||||
type=str,
|
||||
default="max_concurrency",
|
||||
help="Result key that stores concurrent user count. "
|
||||
"Falls back to max_concurrent_requests if missing.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gpu-count-var",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Result key that stores GPU count. "
|
||||
"If not provided, falls back to num_gpus/gpu_count "
|
||||
"or tensor_parallel_size * pipeline_parallel_size.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--label-by",
|
||||
type=str,
|
||||
default="max_concurrency,gpu_count",
|
||||
help="Comma-separated list of fields to annotate on Pareto frontier "
|
||||
"points.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dry-run",
|
||||
action="store_true",
|
||||
help="If set, prints the figures to plot without drawing them.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def run_main(args: SweepPlotParetoArgs):
    """Run ``plot_pareto`` with the fields of ``args`` and return its result."""
    return plot_pareto(
        args.output_dir,
        args.user_count_var,
        args.gpu_count_var,
        args.label_by,
        dry_run=args.dry_run,
    )
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
    """Entry point: convert the raw CLI namespace and run the plot."""
    pareto_args = SweepPlotParetoArgs.from_cli_args(args)
    run_main(pareto_args)
|
||||
|
||||
|
||||
# Script entry point: build the parser, parse the command line, and run.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description=SweepPlotParetoArgs.parser_help)
    SweepPlotParetoArgs.add_cli_args(parser)

    main(parser.parse_args())
|
||||
498
vllm/benchmarks/sweep/serve.py
Normal file
498
vllm/benchmarks/sweep/serve.py
Normal file
@@ -0,0 +1,498 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
import contextlib
|
||||
import json
|
||||
import shlex
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import ClassVar
|
||||
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
from .param_sweep import ParameterSweep, ParameterSweepItem
|
||||
from .server import ServerProcess
|
||||
from .utils import sanitize_filename
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
pd = PlaceholderModule("pandas")
|
||||
|
||||
|
||||
@contextlib.contextmanager
def run_server(
    serve_cmd: list[str],
    after_bench_cmd: list[str],
    *,
    show_stdout: bool,
    serve_overrides: ParameterSweepItem,
    dry_run: bool,
    server_ready_timeout: int = 300,
):
    """Context manager that runs a server for the duration of the block.

    Yields a ``ServerProcess`` once ``wait_until_ready`` has returned, or
    ``None`` in dry-run mode, where no process is started.

    Args:
        serve_cmd: Base server command.
        after_bench_cmd: Command handed to ``ServerProcess``.
        show_stdout: Forwarded to ``ServerProcess``.
        serve_overrides: Parameters applied on top of ``serve_cmd``.
        dry_run: If true, only print the command.
        server_ready_timeout: Timeout passed to ``wait_until_ready``.
    """
    server_cmd = serve_overrides.apply_to_cmd(serve_cmd)

    print("[BEGIN SERVER]")
    print(f"Server overrides: {serve_overrides}")
    print(f"Server command: {server_cmd}")

    if not dry_run:
        with ServerProcess(
            server_cmd,
            after_bench_cmd,
            show_stdout=show_stdout,
        ) as server:
            server.wait_until_ready(timeout=server_ready_timeout)
            yield server
    else:
        yield None

    print("[END SERVER]")
|
||||
|
||||
|
||||
def _update_run_data(
    run_data: dict[str, object],
    serve_overrides: ParameterSweepItem,
    bench_overrides: ParameterSweepItem,
    run_number: int,
):
    """Record the run number and the applied overrides in ``run_data``.

    ``run_data`` is modified in place and also returned. Serve overrides are
    applied before bench overrides, so bench values win on key clashes.
    """
    run_data["run_number"] = run_number

    for overrides in (serve_overrides, bench_overrides):
        run_data.update(overrides)

    return run_data
|
||||
|
||||
|
||||
def run_benchmark(
    server: ServerProcess | None,
    bench_cmd: list[str],
    *,
    serve_overrides: ParameterSweepItem,
    bench_overrides: ParameterSweepItem,
    run_number: int,
    output_path: Path,
    dry_run: bool,
):
    """Run one benchmark iteration, or reuse its saved result.

    Args:
        server: Running server, or ``None`` when no server was started.
        bench_cmd: Base benchmark command.
        serve_overrides: Server parameters recorded in the result.
        bench_overrides: Parameters applied to ``bench_cmd`` and recorded.
        run_number: Index of this repetition.
        output_path: JSON file the benchmark result is saved to.
        dry_run: If true, a missing result without a server is not an error.

    Returns:
        The result record, or ``None`` when there is neither an existing
        result nor a server and ``dry_run`` is set.

    Raises:
        ValueError: If there is neither an existing result nor a server and
            ``dry_run`` is not set.
    """
    benchmark_cmd = [
        *bench_overrides.apply_to_cmd(bench_cmd),
        "--percentile-metrics",
        "ttft,tpot,itl,e2el",
        "--save-result",
        "--result-dir",
        str(output_path.parent),
        "--result-filename",
        output_path.name,
    ]

    print("[BEGIN BENCHMARK]")
    print(f"Benchmark overrides: {bench_overrides}")
    print(f"Run Number: {run_number}")
    print(f"Benchmark command: {benchmark_cmd}")
    print(f"Output file: {output_path}")

    run_data: dict[str, object]

    # An existing result file is reused instead of re-running the benchmark.
    if output_path.exists():
        print("Found existing results.")
        print("[SKIPPED BENCHMARK]")

        with output_path.open("rb") as f:
            run_data = json.load(f)
            return _update_run_data(
                run_data,
                serve_overrides,
                bench_overrides,
                run_number,
            )

    if server is None:
        if not dry_run:
            raise ValueError(f"Cannot find results at {output_path}")

        print("[END BENCHMARK]")
        return None

    output_path.parent.mkdir(parents=True, exist_ok=True)

    server.run_subcommand(benchmark_cmd)
    server.after_bench()

    with output_path.open("rb") as f:
        run_data = json.load(f)

    run_data = _update_run_data(
        run_data,
        serve_overrides,
        bench_overrides,
        run_number,
    )

    # Write the record back so the saved file also carries the overrides.
    with output_path.open("w") as f:
        json.dump(run_data, f, indent=4)

    print("[END BENCHMARK]")

    return run_data
|
||||
|
||||
|
||||
def _get_comb_base_path(
    output_dir: Path,
    serve_comb: ParameterSweepItem,
    bench_comb: ParameterSweepItem,
):
    """Return the result directory for a serve/bench parameter combination.

    The directory name joins a ``SERVE-``/``BENCH-`` tag and the ``name`` of
    each non-empty combination with ``-`` and is passed through
    ``sanitize_filename``.
    """
    name_parts: list[str] = []
    for tag, comb in (("SERVE-", serve_comb), ("BENCH-", bench_comb)):
        # Empty combinations contribute nothing to the directory name.
        if comb:
            name_parts += [tag, comb.name]

    return output_dir / sanitize_filename("-".join(name_parts))
|
||||
|
||||
|
||||
def _get_comb_run_path(base_path: Path, run_number: int | None):
|
||||
if run_number is None:
|
||||
return base_path / "summary.json"
|
||||
|
||||
return base_path / f"run={run_number}.json"
|
||||
|
||||
|
||||
def _comb_needs_server(
    serve_comb: ParameterSweepItem,
    bench_combs: ParameterSweep,
    output_dir: Path,
):
    """Tell whether any bench combination still lacks a ``summary.json``.

    Returns ``True`` as soon as one summary file is missing, ``False`` when
    all of them exist.
    """
    return any(
        not _get_comb_run_path(
            _get_comb_base_path(output_dir, serve_comb, bench_comb),
            run_number=None,
        ).exists()
        for bench_comb in bench_combs
    )
|
||||
|
||||
|
||||
def server_ctx(
    serve_cmd: list[str],
    after_bench_cmd: list[str],
    *,
    show_stdout: bool,
    serve_comb: ParameterSweepItem,
    bench_params: ParameterSweep,
    output_dir: Path,
    dry_run: bool,
    server_ready_timeout: int = 300,
):
    """Return the context manager providing the server for ``serve_comb``.

    A server is only run when ``_comb_needs_server`` reports missing results;
    otherwise a null context (yielding ``None``) is returned.
    """
    if _comb_needs_server(serve_comb, bench_params, output_dir):
        return run_server(
            serve_cmd,
            after_bench_cmd,
            show_stdout=show_stdout,
            serve_overrides=serve_comb,
            dry_run=dry_run,
            server_ready_timeout=server_ready_timeout,
        )

    # Every summary already exists: nothing to run for this combination.
    return contextlib.nullcontext()
|
||||
|
||||
|
||||
def _comb_is_valid(
    serve_comb: ParameterSweepItem,
    bench_comb: ParameterSweepItem,
    link_vars: list[tuple[str, str]],
) -> bool:
    """Check that all linked variables agree between the two combinations.

    For each ``(serve_key, bench_key)`` pair, both keys must be present and
    hold equal values. With no linked variables, the result is ``True``.
    """
    for serve_key, bench_key in link_vars:
        if serve_key not in serve_comb or bench_key not in bench_comb:
            return False
        if serve_comb[serve_key] != bench_comb[bench_key]:
            return False

    return True
|
||||
|
||||
|
||||
def run_comb(
    server: ServerProcess | None,
    bench_cmd: list[str],
    *,
    serve_comb: ParameterSweepItem,
    bench_comb: ParameterSweepItem,
    base_path: Path,
    num_runs: int,
    dry_run: bool,
    link_vars: list[tuple[str, str]],
):
    """Run ``num_runs`` benchmark repetitions for one parameter combination.

    Returns:
        The list of run records, also written to ``summary.json`` under
        ``base_path``. ``None`` is returned when the combination violates
        ``link_vars`` or in dry-run mode.
    """
    if not _comb_is_valid(serve_comb, bench_comb, link_vars):
        return None

    # Lazily evaluated: runs execute one after another, in order.
    run_results = (
        run_benchmark(
            server,
            bench_cmd,
            serve_overrides=serve_comb,
            bench_overrides=bench_comb,
            run_number=index,
            output_path=_get_comb_run_path(base_path, index),
            dry_run=dry_run,
        )
        for index in range(num_runs)
    )
    comb_data = [result for result in run_results if result is not None]

    if dry_run:
        return None

    summary_path = _get_comb_run_path(base_path, run_number=None)
    with summary_path.open("w") as f:
        json.dump(comb_data, f, indent=4)

    return comb_data
|
||||
|
||||
|
||||
def run_combs(
    serve_cmd: list[str],
    bench_cmd: list[str],
    after_bench_cmd: list[str],
    *,
    show_stdout: bool,
    server_ready_timeout: int,
    serve_params: ParameterSweep,
    bench_params: ParameterSweep,
    output_dir: Path,
    num_runs: int,
    dry_run: bool,
    link_vars: list[tuple[str, str]],
):
    """Benchmark every serve/bench parameter combination.

    For each serve combination a server context is obtained from
    ``server_ctx`` and every bench combination is run against it.

    Returns:
        A DataFrame of all run records, also saved to
        ``output_dir / "summary.csv"``; ``None`` in dry-run mode.
    """
    all_data = list[dict[str, object]]()
    for serve_comb in serve_params:
        with server_ctx(
            serve_cmd,
            after_bench_cmd,
            show_stdout=show_stdout,
            serve_comb=serve_comb,
            bench_params=bench_params,
            output_dir=output_dir,
            dry_run=dry_run,
            server_ready_timeout=server_ready_timeout,
        ) as server:
            for bench_comb in bench_params:
                base_path = _get_comb_base_path(output_dir, serve_comb, bench_comb)

                comb_data = run_comb(
                    server,
                    bench_cmd,
                    serve_comb=serve_comb,
                    bench_comb=bench_comb,
                    base_path=base_path,
                    num_runs=num_runs,
                    dry_run=dry_run,
                    link_vars=link_vars,
                )

                if comb_data is not None:
                    all_data.extend(comb_data)

    if dry_run:
        return None

    # `output_dir` may not exist yet when no benchmark ran (e.g. every
    # combination was rejected by `link_vars`); create it so that writing the
    # summary does not fail.
    output_dir.mkdir(parents=True, exist_ok=True)

    combined_df = pd.DataFrame.from_records(all_data)
    combined_df.to_csv(output_dir / "summary.csv")

    return combined_df
|
||||
|
||||
|
||||
@dataclass
class SweepServeArgs:
    """Validated command-line arguments of the ``serve`` sweep sub-command."""

    # Tokenized server command (`vllm serve ...`).
    serve_cmd: list[str]
    # Tokenized benchmark command (`vllm bench serve ...`).
    bench_cmd: list[str]
    # Tokenized command run after each benchmark; empty means the default.
    after_bench_cmd: list[str]
    # Whether to show the standard output of subcommands.
    show_stdout: bool
    # Seconds to wait for the server to become ready.
    server_ready_timeout: int
    # Parameter combinations for the server command.
    serve_params: ParameterSweep
    # Parameter combinations for the benchmark command.
    bench_params: ParameterSweep
    # Directory to which results are written.
    output_dir: Path
    # Number of runs per parameter combination (at least 1).
    num_runs: int
    # Print the commands without executing them.
    dry_run: bool
    # Name of a previous run directory under `output_dir` to resume, if any.
    resume: str | None
    # (serve_key, bench_key) pairs whose values must match.
    link_vars: list[tuple[str, str]]

    parser_name: ClassVar[str] = "serve"
    parser_help: ClassVar[str] = "Run vLLM server benchmark under multiple settings."

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an instance from a namespace produced by ``add_cli_args``.

        Raises:
            ValueError: If ``num_runs`` is below 1 or ``link_vars`` is
                malformed.
        """
        serve_cmd = shlex.split(args.serve_cmd)
        bench_cmd = shlex.split(args.bench_cmd)
        after_bench_cmd = (
            [] if args.after_bench_cmd is None else shlex.split(args.after_bench_cmd)
        )

        if args.serve_params:
            serve_params = ParameterSweep.read_json(args.serve_params)
        else:
            # i.e.: run serve_cmd without any modification
            serve_params = ParameterSweep.from_records([{}])

        if args.bench_params:
            bench_params = ParameterSweep.read_json(args.bench_params)
        else:
            # i.e.: run bench_cmd without any modification
            bench_params = ParameterSweep.from_records([{}])

        link_vars = cls.parse_link_vars(args.link_vars)

        num_runs = args.num_runs
        if num_runs < 1:
            raise ValueError("`num_runs` should be at least 1.")

        return cls(
            serve_cmd=serve_cmd,
            bench_cmd=bench_cmd,
            after_bench_cmd=after_bench_cmd,
            show_stdout=args.show_stdout,
            serve_params=serve_params,
            bench_params=bench_params,
            output_dir=Path(args.output_dir),
            num_runs=num_runs,
            dry_run=args.dry_run,
            resume=args.resume,
            link_vars=link_vars,
            server_ready_timeout=args.server_ready_timeout,
        )

    @classmethod
    def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Register the ``serve`` sweep arguments on ``parser`` and return it."""
        parser.add_argument(
            "--serve-cmd",
            type=str,
            required=True,
            help="The command used to run the server: `vllm serve ...`",
        )
        parser.add_argument(
            "--bench-cmd",
            type=str,
            required=True,
            help="The command used to run the benchmark: `vllm bench serve ...`",
        )
        parser.add_argument(
            "--after-bench-cmd",
            type=str,
            default=None,
            help="After a benchmark run is complete, invoke this command instead of "
            "the default `ServerProcess.reset_caches()`.",
        )
        parser.add_argument(
            "--show-stdout",
            action="store_true",
            help="If set, logs the standard output of subcommands. "
            "Useful for debugging but can be quite spammy.",
        )
        parser.add_argument(
            "--server-ready-timeout",
            type=int,
            default=300,
            help="Timeout in seconds to wait for the server to become ready.",
        )
        parser.add_argument(
            "--serve-params",
            type=str,
            default=None,
            help="Path to JSON file containing parameter combinations "
            "for the `vllm serve` command. Can be either a list of dicts or a dict "
            "where keys are benchmark names. "
            "If both `serve_params` and `bench_params` are given, "
            "this script will iterate over their Cartesian product.",
        )
        parser.add_argument(
            "--bench-params",
            type=str,
            default=None,
            help="Path to JSON file containing parameter combinations "
            "for the `vllm bench serve` command. Can be either a list of dicts or "
            "a dict where keys are benchmark names. "
            "If both `serve_params` and `bench_params` are given, "
            "this script will iterate over their Cartesian product.",
        )
        parser.add_argument(
            "-o",
            "--output-dir",
            type=str,
            default="results",
            help="The directory to which results are written.",
        )
        parser.add_argument(
            "--num-runs",
            type=int,
            default=3,
            help="Number of runs per parameter combination.",
        )
        parser.add_argument(
            "--dry-run",
            action="store_true",
            help="If set, prints the commands to run, "
            "then exits without executing them.",
        )
        parser.add_argument(
            "--resume",
            type=str,
            default=None,
            help="Set this to the name of a directory under `output_dir` (which is a "
            "timestamp) to resume a previous execution of this script, i.e., only run "
            "parameter combinations for which there are still no output files.",
        )

        parser.add_argument(
            "--link-vars",
            type=str,
            default="",
            help=(
                "Comma-separated list of linked variables between serve and bench, "
                "e.g. max_num_seqs=max_concurrency,max_model_len=random_input_len"
            ),
        )

        return parser

    @staticmethod
    def parse_link_vars(s: str) -> list[tuple[str, str]]:
        """Parse ``"a=b,c=d"`` into ``[("a", "b"), ("c", "d")]``.

        Whitespace around names is stripped and empty items (such as the one
        produced by a trailing comma) are ignored.

        Raises:
            ValueError: If an item is not of the form ``serve_var=bench_var``
                with both names non-empty.
        """
        if not s:
            return []

        pairs = []
        for item in s.split(","):
            if not item.strip():
                continue

            serve_key, sep, bench_key = item.partition("=")
            serve_key = serve_key.strip()
            bench_key = bench_key.strip()
            # Reject missing or repeated "=" and empty names up front; an
            # empty name could never match and would reject every combination.
            if not sep or not serve_key or not bench_key or "=" in bench_key:
                raise ValueError(
                    f"Invalid `--link-vars` entry {item!r}: "
                    "expected the form `serve_var=bench_var`."
                )

            pairs.append((serve_key, bench_key))

        return pairs
|
||||
|
||||
|
||||
def run_main(args: SweepServeArgs):
    """Run the full sweep inside a timestamped (or resumed) directory.

    Raises:
        ValueError: If ``args.resume`` names a directory that does not exist.
        RuntimeError: If the sweep stops early for any reason; the message
            tells how to resume and the original exception is chained.
    """
    if args.resume:
        timestamp = args.resume
    else:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    output_dir = args.output_dir / timestamp
    resuming_missing_dir = args.resume and not output_dir.exists()
    if resuming_missing_dir:
        raise ValueError(f"Cannot resume from non-existent directory ({output_dir})")

    sweep_kwargs = dict(
        serve_cmd=args.serve_cmd,
        bench_cmd=args.bench_cmd,
        after_bench_cmd=args.after_bench_cmd,
        show_stdout=args.show_stdout,
        server_ready_timeout=args.server_ready_timeout,
        serve_params=args.serve_params,
        bench_params=args.bench_params,
        output_dir=output_dir,
        num_runs=args.num_runs,
        dry_run=args.dry_run,
        link_vars=args.link_vars,
    )

    try:
        return run_combs(**sweep_kwargs)
    except BaseException as exc:
        # BaseException is caught so interrupts also show the resume hint.
        raise RuntimeError(
            f"The script was terminated early. Use `--resume {timestamp}` "
            f"to continue the script from its last checkpoint."
        ) from exc
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
    """Entry point: convert the raw CLI namespace and run the sweep."""
    sweep_args = SweepServeArgs.from_cli_args(args)
    run_main(sweep_args)
|
||||
|
||||
|
||||
# Script entry point: build the parser, parse the command line, and run.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description=SweepServeArgs.parser_help)
    SweepServeArgs.add_cli_args(parser)

    main(parser.parse_args())
|
||||
305
vllm/benchmarks/sweep/serve_sla.py
Normal file
305
vllm/benchmarks/sweep/serve_sla.py
Normal file
@@ -0,0 +1,305 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
import math
|
||||
from dataclasses import asdict, dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import ClassVar, Literal, get_args
|
||||
|
||||
import numpy as np
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
from .param_sweep import ParameterSweep, ParameterSweepItem
|
||||
from .serve import (
|
||||
SweepServeArgs,
|
||||
_get_comb_base_path,
|
||||
run_comb,
|
||||
server_ctx,
|
||||
)
|
||||
from .server import ServerProcess
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
pd = PlaceholderModule("pandas")
|
||||
|
||||
|
||||
# Names of the benchmark parameters that can be adjusted in SLA mode.
SLAVariable = Literal["request_rate", "max_concurrency"]
|
||||
|
||||
|
||||
def _estimate_sla_value(run_data: dict[str, object], sla_variable: SLAVariable):
    """Estimate the value of ``sla_variable`` achieved by a finished run.

    For ``request_rate`` this is the measured ``request_throughput``. For
    ``max_concurrency`` it is ``request_throughput`` multiplied by
    ``mean_e2el_ms`` and divided by 1000.
    """
    throughput = float(run_data["request_throughput"])  # type: ignore

    if sla_variable == "max_concurrency":
        latency_ms = float(run_data["mean_e2el_ms"])  # type: ignore
        return throughput * latency_ms / 1000
    elif sla_variable == "request_rate":
        return throughput

    assert_never(sla_variable)
|
||||
|
||||
|
||||
def _estimate_sla_avg(runs: list[dict[str, object]], sla_variable: SLAVariable):
    """Return the mean of ``_estimate_sla_value`` over ``runs``."""
    estimates = [_estimate_sla_value(run, sla_variable) for run in runs]
    return sum(estimates) / len(runs)
|
||||
|
||||
|
||||
def run_comb_sla(
    server: ServerProcess | None,
    bench_cmd: list[str],
    *,
    serve_comb: ParameterSweepItem,
    bench_comb: ParameterSweepItem,
    output_dir: Path,
    num_runs: int,
    dry_run: bool,
    link_vars: list[tuple[str, str]],
    sla_variable: SLAVariable,
    sla_value: int,
) -> list[dict[str, object]] | None:
    """Run one combination with ``sla_variable`` set to ``sla_value``.

    The bench combination is extended with the SLA setting, and both the
    benchmark and its result directory are derived from that extended
    combination. The result of ``run_comb`` is returned unchanged.
    """
    sla_override = {sla_variable: sla_value}
    bench_comb_sla = bench_comb | sla_override
    sla_base_path = _get_comb_base_path(output_dir, serve_comb, bench_comb_sla)

    return run_comb(
        server,
        bench_cmd,
        serve_comb=serve_comb,
        bench_comb=bench_comb_sla,
        base_path=sla_base_path,
        num_runs=num_runs,
        dry_run=dry_run,
        link_vars=link_vars,
    )
|
||||
|
||||
|
||||
def explore_sla(
    server: ServerProcess | None,
    bench_cmd: list[str],
    *,
    serve_comb: ParameterSweepItem,
    bench_comb: ParameterSweepItem,
    sla_variable: SLAVariable,
    sla_iters: int,
    output_dir: Path,
    num_runs: int,
    dry_run: bool,
    link_vars: list[tuple[str, str]],
):
    """Explore values of ``sla_variable`` between two boundary runs.

    The benchmark is first run with the variable set to 1 and to
    ``num_prompts`` (default 1000). The values estimated from those two runs
    bound a range that is sampled at ``sla_iters`` evenly spaced points; the
    interior points are rounded, de-duplicated and benchmarked.

    Returns:
        The records of the first boundary run, the intermediate runs and the
        second boundary run, in that order; ``None`` when either boundary run
        produced no data.

    Raises:
        ValueError: If ``sla_iters`` is smaller than 2.
    """

    def _run_with(sla_value: int):
        # Every exploration step only differs in the SLA value.
        return run_comb_sla(
            server,
            bench_cmd,
            serve_comb=serve_comb,
            bench_comb=bench_comb,
            output_dir=output_dir,
            num_runs=num_runs,
            dry_run=dry_run,
            link_vars=link_vars,
            sla_variable=sla_variable,
            sla_value=sla_value,
        )

    print("[SLA START]")
    print(f"Serve parameters: {serve_comb.as_text() or '(None)'}")
    print(f"Bench parameters: {bench_comb.as_text() or '(None)'}")
    print(f"Number of SLA iterations: {sla_iters}")

    if sla_iters < 2:
        raise ValueError("`sla_iters` should be at least 2")

    serial_comb_data = _run_with(1)
    batch_comb_data = _run_with(
        int(bench_comb.get("num_prompts", 1000))  # type: ignore
    )

    if serial_comb_data is None or batch_comb_data is None:
        if dry_run:
            print("Omitting intermediate SLA iterations.")
        print("[SLA END]")

        return

    serial_sla_value = math.ceil(_estimate_sla_avg(serial_comb_data, sla_variable))
    print(f"Serial inference: {sla_variable}={serial_sla_value}")

    batch_sla_value = math.floor(_estimate_sla_avg(batch_comb_data, sla_variable))
    print(f"Batch inference: {sla_variable}={batch_sla_value}")

    # Avoid duplicated runs for intermediate values if the range between
    # `serial_sla_value` and `batch_sla_value` is small
    interior_points = np.linspace(serial_sla_value, batch_sla_value, sla_iters)[1:-1]
    inter_sla_values = sorted({round(point) for point in interior_points})

    inter_combs_data: list[dict[str, object]] = []
    for inter_sla_value in inter_sla_values:
        print(f"Exploring: {sla_variable}={inter_sla_value}")
        inter_comb_data = _run_with(inter_sla_value)
        if inter_comb_data is None:
            continue
        inter_combs_data.extend(inter_comb_data)

    print("[SLA END]")

    return serial_comb_data + inter_combs_data + batch_comb_data
|
||||
|
||||
|
||||
def run_slas(
    serve_cmd: list[str],
    bench_cmd: list[str],
    after_bench_cmd: list[str],
    *,
    show_stdout: bool,
    server_ready_timeout: int,
    serve_params: ParameterSweep,
    bench_params: ParameterSweep,
    sla_variable: SLAVariable,
    sla_iters: int,
    output_dir: Path,
    num_runs: int,
    dry_run: bool,
    link_vars: list[tuple[str, str]],
):
    """Run the SLA exploration for every serve/bench parameter combination.

    Returns:
        A DataFrame of all run records, also saved to
        ``output_dir / "summary.csv"``; ``None`` in dry-run mode.

    Raises:
        ValueError: If any bench combination already sets ``sla_variable``.
    """
    if any(bench_comb.has_param(sla_variable) for bench_comb in bench_params):
        raise ValueError(
            f"You should not override `{sla_variable}` in `bench_params` in SLA mode, "
            "since it is supposed to be determined automatically."
        )

    all_data = list[dict[str, object]]()
    for serve_comb in serve_params:
        with server_ctx(
            serve_cmd,
            after_bench_cmd,
            show_stdout=show_stdout,
            server_ready_timeout=server_ready_timeout,
            serve_comb=serve_comb,
            bench_params=bench_params,
            output_dir=output_dir,
            dry_run=dry_run,
        ) as server:
            for bench_comb in bench_params:
                comb_data = explore_sla(
                    server,
                    bench_cmd,
                    serve_comb=serve_comb,
                    bench_comb=bench_comb,
                    sla_variable=sla_variable,
                    sla_iters=sla_iters,
                    output_dir=output_dir,
                    num_runs=num_runs,
                    dry_run=dry_run,
                    link_vars=link_vars,
                )

                if comb_data is not None:
                    all_data.extend(comb_data)

    if dry_run:
        return None

    # `output_dir` may not exist yet when no benchmark ran (e.g. every
    # combination was rejected by `link_vars`); create it so that writing the
    # summary does not fail.
    output_dir.mkdir(parents=True, exist_ok=True)

    combined_df = pd.DataFrame.from_records(all_data)
    combined_df.to_csv(output_dir / "summary.csv")

    return combined_df
|
||||
|
||||
|
||||
@dataclass
class SweepServeSLAArgs(SweepServeArgs):
    """Arguments of the ``serve_sla`` sub-command (serve arguments plus SLA)."""

    # The variable to adjust in each iteration.
    sla_variable: SLAVariable
    # Number of iterations, including the first two boundary iterations.
    sla_iters: int

    parser_name: ClassVar[str] = "serve_sla"
    parser_help: ClassVar[str] = (
        "Explore the latency-throughput space for determining SLAs."
    )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an instance from a namespace produced by ``add_cli_args``."""
        # NOTE: Don't use super() as `from_cli_args` calls `cls()`
        base_args = SweepServeArgs.from_cli_args(args)

        init_kwargs = asdict(base_args)
        init_kwargs["sla_variable"] = args.sla_variable
        init_kwargs["sla_iters"] = args.sla_iters

        return cls(**init_kwargs)

    @classmethod
    def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Register the serve arguments and the SLA options on ``parser``."""
        parser = super().add_cli_args(parser)

        sla_options = parser.add_argument_group("sla options")
        sla_options.add_argument(
            "--sla-variable",
            type=str,
            choices=get_args(SLAVariable),
            default="request_rate",
            help="The variable to adjust in each iteration.",
        )
        sla_options.add_argument(
            "--sla-iters",
            type=int,
            default=10,
            help="Number of iterations used to explore the latency-throughput space. "
            "This includes the first two iterations used to interpolate the value of "
            "`sla_variable` for remaining iterations.",
        )

        return parser
|
||||
|
||||
|
||||
def run_main(args: SweepServeSLAArgs):
    """Run the SLA exploration inside a timestamped (or resumed) directory.

    Raises:
        ValueError: If ``args.resume`` names a directory that does not exist.
        RuntimeError: If the run stops early for any reason; the message
            tells how to resume and the original exception is chained.
    """
    if args.resume:
        timestamp = args.resume
    else:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    output_dir = args.output_dir / timestamp
    resuming_missing_dir = args.resume and not output_dir.exists()
    if resuming_missing_dir:
        raise ValueError(f"Cannot resume from non-existent directory ({output_dir})")

    sla_kwargs = dict(
        serve_cmd=args.serve_cmd,
        bench_cmd=args.bench_cmd,
        after_bench_cmd=args.after_bench_cmd,
        show_stdout=args.show_stdout,
        server_ready_timeout=args.server_ready_timeout,
        serve_params=args.serve_params,
        bench_params=args.bench_params,
        sla_variable=args.sla_variable,
        sla_iters=args.sla_iters,
        output_dir=output_dir,
        num_runs=args.num_runs,
        dry_run=args.dry_run,
        link_vars=args.link_vars,
    )

    try:
        return run_slas(**sla_kwargs)
    except BaseException as exc:
        # BaseException is caught so interrupts also show the resume hint.
        raise RuntimeError(
            f"The script was terminated early. Use `--resume {timestamp}` "
            f"to continue the script from its last checkpoint."
        ) from exc
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
    """Entry point: convert the raw CLI namespace and run the exploration."""
    sla_args = SweepServeSLAArgs.from_cli_args(args)
    run_main(sla_args)
|
||||
|
||||
|
||||
# Script entry point: build the parser, parse the command line, and run.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description=SweepServeSLAArgs.parser_help)
    SweepServeSLAArgs.add_cli_args(parser)

    main(parser.parse_args())
|
||||
142
vllm/benchmarks/sweep/server.py
Normal file
142
vllm/benchmarks/sweep/server.py
Normal file
@@ -0,0 +1,142 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import time
|
||||
from types import TracebackType
|
||||
|
||||
import requests
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
class ServerProcess:
|
||||
VLLM_RESET_CACHE_ENDPOINTS = [
|
||||
"/reset_prefix_cache",
|
||||
"/reset_mm_cache",
|
||||
"/reset_encoder_cache",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_cmd: list[str],
|
||||
after_bench_cmd: list[str],
|
||||
*,
|
||||
show_stdout: bool,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.server_cmd = server_cmd
|
||||
self.after_bench_cmd = after_bench_cmd
|
||||
self.show_stdout = show_stdout
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
self.start()
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
exc_traceback: TracebackType | None,
|
||||
) -> None:
|
||||
self.stop()
|
||||
|
||||
def start(self):
|
||||
# Create new process for clean termination
|
||||
self._server_process = subprocess.Popen(
|
||||
self.server_cmd,
|
||||
start_new_session=True,
|
||||
stdout=None if self.show_stdout else subprocess.DEVNULL,
|
||||
# Need `VLLM_SERVER_DEV_MODE=1` for `_reset_caches`
|
||||
env=os.environ | {"VLLM_SERVER_DEV_MODE": "1"},
|
||||
)
|
||||
|
||||
def stop(self):
|
||||
server_process = self._server_process
|
||||
|
||||
if server_process.poll() is None:
|
||||
# In case only some processes have been terminated
|
||||
with contextlib.suppress(ProcessLookupError):
|
||||
# We need to kill both API Server and Engine processes
|
||||
os.killpg(os.getpgid(server_process.pid), signal.SIGKILL)
|
||||
|
||||
def run_subcommand(self, cmd: list[str]):
|
||||
return subprocess.run(
|
||||
cmd,
|
||||
stdout=None if self.show_stdout else subprocess.DEVNULL,
|
||||
check=True,
|
||||
)
|
||||
|
||||
def after_bench(self) -> None:
|
||||
if not self.after_bench_cmd:
|
||||
self.reset_caches()
|
||||
return
|
||||
|
||||
self.run_subcommand(self.after_bench_cmd)
|
||||
|
||||
def _get_vllm_server_address(self) -> str:
|
||||
server_cmd = self.server_cmd
|
||||
|
||||
for host_key in ("--host",):
|
||||
if host_key in server_cmd:
|
||||
host = server_cmd[server_cmd.index(host_key) + 1]
|
||||
break
|
||||
else:
|
||||
host = "localhost"
|
||||
|
||||
for port_key in ("-p", "--port"):
|
||||
if port_key in server_cmd:
|
||||
port = int(server_cmd[server_cmd.index(port_key) + 1])
|
||||
break
|
||||
else:
|
||||
port = 8000 # The default value in vllm serve
|
||||
|
||||
return f"http://{host}:{port}"
|
||||
|
||||
def is_server_ready(self) -> bool:
|
||||
server_address = self._get_vllm_server_address()
|
||||
try:
|
||||
response = requests.get(f"{server_address}/health")
|
||||
return response.status_code == 200
|
||||
except requests.RequestException:
|
||||
return False
|
||||
|
||||
def wait_until_ready(self, timeout: int) -> None:
|
||||
start_time = time.monotonic()
|
||||
while not self.is_server_ready():
|
||||
# Check if server process has crashed
|
||||
if self._server_process.poll() is not None:
|
||||
returncode = self._server_process.returncode
|
||||
raise RuntimeError(
|
||||
f"Server process crashed with return code {returncode}"
|
||||
)
|
||||
if time.monotonic() - start_time > timeout:
|
||||
raise TimeoutError(
|
||||
f"Server failed to become ready within {timeout} seconds."
|
||||
)
|
||||
time.sleep(1)
|
||||
|
||||
def reset_caches(self) -> None:
|
||||
server_cmd = self.server_cmd
|
||||
|
||||
# Use `.endswith()` to match `/bin/...`
|
||||
if server_cmd[0].endswith("vllm"):
|
||||
server_address = self._get_vllm_server_address()
|
||||
print(f"Resetting caches at {server_address}")
|
||||
|
||||
for endpoint in self.VLLM_RESET_CACHE_ENDPOINTS:
|
||||
res = requests.post(server_address + endpoint)
|
||||
res.raise_for_status()
|
||||
elif server_cmd[0].endswith("infinity_emb"):
|
||||
if "--vector-disk-cache" in server_cmd:
|
||||
raise NotImplementedError(
|
||||
"Infinity server uses caching but does not expose a method "
|
||||
"to reset the cache"
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"No implementation of `reset_caches` for `{server_cmd[0]}` server. "
|
||||
"Please specify a custom command via `--after-bench-cmd`."
|
||||
)
|
||||
406
vllm/benchmarks/sweep/startup.py
Normal file
406
vllm/benchmarks/sweep/startup.py
Normal file
@@ -0,0 +1,406 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
import json
|
||||
import shlex
|
||||
import subprocess
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import ClassVar
|
||||
|
||||
from vllm.benchmarks.startup import add_cli_args as add_startup_cli_args
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
from .param_sweep import ParameterSweep, ParameterSweepItem
|
||||
from .utils import sanitize_filename
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
pd = PlaceholderModule("pandas")
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
def _get_supported_startup_keys() -> set[str]:
    """Collect the parameter names accepted by the startup benchmark CLI.

    A throwaway parser is populated with the startup benchmark's arguments and
    then inspected: every non-suppressed ``dest`` and every long option (with
    leading dashes removed and ``-`` turned into ``_``) is recorded. ``"config"``
    is always included. The result is cached after the first call.
    """
    probe = FlexibleArgumentParser(add_help=False)
    add_startup_cli_args(probe)

    keys: set[str] = {"config"}
    for action in probe._actions:
        dest = action.dest
        if dest and dest is not argparse.SUPPRESS:
            keys.add(dest)
        keys.update(
            flag.lstrip("-").replace("-", "_")
            for flag in action.option_strings
            if flag.startswith("--")
        )

    return keys
|
||||
|
||||
|
||||
def _is_supported_param(param_key: str, supported: set[str]) -> bool:
|
||||
if param_key == "_benchmark_name":
|
||||
return True
|
||||
prefix = param_key.split(".", 1)[0]
|
||||
normalized = prefix.replace("-", "_")
|
||||
return normalized in supported
|
||||
|
||||
|
||||
def _filter_params(
    params: ParameterSweep, *, supported: set[str], strict: bool
) -> ParameterSweep:
    """Drop sweep parameters that the startup benchmark does not accept.

    For each sweep item the unsupported keys are reported: with ``strict`` a
    ``ValueError`` is raised, otherwise the message is printed and the item is
    rebuilt from the remaining keys.
    """
    result = []
    for item in params:
        kept: dict[str, object] = {
            key: value
            for key, value in item.items()
            if _is_supported_param(key, supported)
        }
        dropped: list[str] = [key for key, _ in item.items() if key not in kept]

        if dropped:
            label = item.get("_benchmark_name") or item.as_text()
            target = " for " + str(label) if label else ""
            names = ", ".join(sorted(dropped))
            message = f"Ignoring unsupported startup params{target}: {names}"
            if strict:
                raise ValueError(message)
            print(message)

        result.append(ParameterSweepItem.from_record(kept))

    return ParameterSweep(result)
|
||||
|
||||
|
||||
def _update_run_data(
|
||||
run_data: dict[str, object],
|
||||
serve_overrides: ParameterSweepItem,
|
||||
startup_overrides: ParameterSweepItem,
|
||||
run_number: int,
|
||||
) -> dict[str, object]:
|
||||
run_data["run_number"] = run_number
|
||||
run_data.update(serve_overrides)
|
||||
run_data.update(startup_overrides)
|
||||
return run_data
|
||||
|
||||
|
||||
def _strip_arg(cmd: list[str], keys: tuple[str, ...]) -> list[str]:
|
||||
stripped: list[str] = []
|
||||
skip_next = False
|
||||
for arg in cmd:
|
||||
if skip_next:
|
||||
skip_next = False
|
||||
continue
|
||||
if arg in keys:
|
||||
skip_next = True
|
||||
continue
|
||||
if any(arg.startswith(f"{key}=") for key in keys):
|
||||
continue
|
||||
stripped.append(arg)
|
||||
return stripped
|
||||
|
||||
|
||||
def _apply_output_json(cmd: list[str], output_path: Path) -> list[str]:
    """Force the command to write its JSON result to ``output_path``.

    Any existing ``--output-json`` / ``--output_json`` option is removed and a
    single ``--output-json <path>`` pair is appended.
    """
    canonical, alias = "--output-json", "--output_json"
    base_cmd = _strip_arg(cmd, (canonical, alias))
    return base_cmd + [canonical, str(output_path)]
|
||||
|
||||
|
||||
def _get_comb_base_path(
    output_dir: Path,
    serve_comb: ParameterSweepItem,
    startup_comb: ParameterSweepItem,
) -> Path:
    """Build the result directory for one (serve, startup) combination.

    Each non-empty combination contributes its tag and its ``name``; the pieces
    are joined with ``-`` and passed through ``sanitize_filename``.
    """
    tagged = (("SERVE-", serve_comb), ("STARTUP-", startup_comb))

    parts: list[str] = []
    for tag, comb in tagged:
        if comb:
            parts += [tag, comb.name]

    dirname = sanitize_filename("-".join(parts))
    return output_dir / dirname
|
||||
|
||||
|
||||
def _get_comb_run_path(base_path: Path, run_number: int | None) -> Path:
|
||||
if run_number is None:
|
||||
return base_path / "summary.json"
|
||||
return base_path / f"run={run_number}.json"
|
||||
|
||||
|
||||
def run_benchmark(
    startup_cmd: list[str],
    *,
    serve_overrides: ParameterSweepItem,
    startup_overrides: ParameterSweepItem,
    run_number: int,
    output_path: Path,
    show_stdout: bool,
    dry_run: bool,
) -> dict[str, object] | None:
    """Execute one startup benchmark run, or reuse its existing result file.

    The command is ``startup_cmd`` with both override sets applied and its JSON
    output redirected to ``output_path``.

    Returns:
        The run data annotated with the run number and overrides, or ``None``
        for a dry run with no existing result. An existing result file is
        loaded and returned without being rewritten.

    Raises:
        subprocess.CalledProcessError: the benchmark command failed.
    """
    cmd = _apply_output_json(
        startup_overrides.apply_to_cmd(serve_overrides.apply_to_cmd(startup_cmd)),
        output_path,
    )

    print("[BEGIN BENCHMARK]")
    for description in (
        f"Serve overrides: {serve_overrides}",
        f"Startup overrides: {startup_overrides}",
        f"Run Number: {run_number}",
        f"Benchmark command: {cmd}",
        f"Output file: {output_path}",
    ):
        print(description)

    if output_path.exists():
        print("Found existing results.")
        print("[SKIPPED BENCHMARK]")

        cached = json.loads(output_path.read_text(encoding="utf-8"))
        return _update_run_data(
            cached, serve_overrides, startup_overrides, run_number
        )

    if dry_run:
        print("[END BENCHMARK]")
        return None

    output_path.parent.mkdir(parents=True, exist_ok=True)
    child_stdout = None if show_stdout else subprocess.DEVNULL
    subprocess.run(cmd, stdout=child_stdout, check=True)

    fresh = json.loads(output_path.read_text(encoding="utf-8"))
    fresh = _update_run_data(fresh, serve_overrides, startup_overrides, run_number)

    # Persist the annotated record so a resumed sweep sees the same data.
    with output_path.open("w", encoding="utf-8") as sink:
        json.dump(fresh, sink, indent=4)

    print("[END BENCHMARK]")
    return fresh
|
||||
|
||||
|
||||
def run_comb(
    startup_cmd: list[str],
    *,
    serve_comb: ParameterSweepItem,
    startup_comb: ParameterSweepItem,
    base_path: Path,
    num_runs: int,
    show_stdout: bool,
    dry_run: bool,
) -> list[dict[str, object]] | None:
    """Run one parameter combination ``num_runs`` times and write its summary.

    Returns the per-run records (also dumped to ``summary.json`` under
    ``base_path``), or ``None`` when ``dry_run`` is set.
    """
    records: list[dict[str, object]] = []
    for run_idx in range(num_runs):
        record = run_benchmark(
            startup_cmd,
            serve_overrides=serve_comb,
            startup_overrides=startup_comb,
            run_number=run_idx,
            output_path=_get_comb_run_path(base_path, run_idx),
            show_stdout=show_stdout,
            dry_run=dry_run,
        )
        if record is None:
            continue
        records.append(record)

    if dry_run:
        return None

    summary_path = _get_comb_run_path(base_path, run_number=None)
    with summary_path.open("w", encoding="utf-8") as sink:
        json.dump(records, sink, indent=4)

    return records
|
||||
|
||||
|
||||
def run_combs(
    startup_cmd: list[str],
    *,
    serve_params: ParameterSweep,
    startup_params: ParameterSweep,
    output_dir: Path,
    num_runs: int,
    show_stdout: bool,
    dry_run: bool,
) -> "pd.DataFrame | None":
    """Run every (serve, startup) combination and aggregate the results.

    Returns a DataFrame built from all run records (also written to
    ``summary.csv`` in ``output_dir``), or ``None`` when ``dry_run`` is set.
    """
    collected: list[dict[str, object]] = []
    combinations = (
        (serve_comb, startup_comb)
        for serve_comb in serve_params
        for startup_comb in startup_params
    )
    for serve_comb, startup_comb in combinations:
        records = run_comb(
            startup_cmd,
            serve_comb=serve_comb,
            startup_comb=startup_comb,
            base_path=_get_comb_base_path(output_dir, serve_comb, startup_comb),
            num_runs=num_runs,
            show_stdout=show_stdout,
            dry_run=dry_run,
        )
        if records is not None:
            collected.extend(records)

    if dry_run:
        return None

    summary = pd.DataFrame.from_records(collected)
    summary.to_csv(output_dir / "summary.csv")
    return summary
|
||||
|
||||
|
||||
@dataclass
class SweepStartupArgs:
    """Validated arguments of the startup sweep command."""

    startup_cmd: list[str]
    serve_params: ParameterSweep
    startup_params: ParameterSweep
    output_dir: Path
    num_runs: int
    show_stdout: bool
    dry_run: bool
    resume: str | None
    strict_params: bool

    parser_name: ClassVar[str] = "startup"
    parser_help: ClassVar[str] = (
        "Benchmark vLLM startup time over parameter combinations."
    )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build an instance from a parsed CLI namespace.

        Sweep files are loaded (a single empty combination is used when a file
        is not given), then filtered down to parameters supported by the
        startup benchmark.

        Raises:
            ValueError: ``num_runs`` is below 1, or ``strict_params`` is set
                and a sweep file contains an unsupported parameter.
        """

        def load_sweep(path):
            # No file means "one combination with no overrides".
            if path:
                return ParameterSweep.read_json(path)
            return ParameterSweep.from_records([{}])

        startup_cmd = shlex.split(args.startup_cmd)
        raw_serve = load_sweep(args.serve_params)
        raw_startup = load_sweep(args.startup_params)

        supported = _get_supported_startup_keys()
        strict = args.strict_params
        serve_params = _filter_params(raw_serve, supported=supported, strict=strict)
        startup_params = _filter_params(
            raw_startup, supported=supported, strict=strict
        )

        if args.num_runs < 1:
            raise ValueError("`num_runs` should be at least 1.")

        return cls(
            startup_cmd=startup_cmd,
            serve_params=serve_params,
            startup_params=startup_params,
            output_dir=Path(args.output_dir),
            num_runs=args.num_runs,
            show_stdout=args.show_stdout,
            dry_run=args.dry_run,
            resume=args.resume,
            strict_params=args.strict_params,
        )

    @classmethod
    def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Register the sweep options on ``parser`` and return it."""
        # (flags, keyword arguments) in the order they appear in `--help`.
        option_specs = (
            (
                ("--startup-cmd",),
                dict(
                    type=str,
                    default="vllm bench startup",
                    help="The command used to run the startup benchmark.",
                ),
            ),
            (
                ("--serve-params",),
                dict(
                    type=str,
                    default=None,
                    help="Path to JSON file containing parameter combinations "
                    "for the `vllm serve` command. Only parameters supported by "
                    "`vllm bench startup` will be applied.",
                ),
            ),
            (
                ("--startup-params",),
                dict(
                    type=str,
                    default=None,
                    help="Path to JSON file containing parameter combinations "
                    "for the `vllm bench startup` command.",
                ),
            ),
            (
                ("-o", "--output-dir"),
                dict(
                    type=str,
                    default="results",
                    help="The directory to which results are written.",
                ),
            ),
            (
                ("--num-runs",),
                dict(
                    type=int,
                    default=1,
                    help="Number of runs per parameter combination.",
                ),
            ),
            (
                ("--show-stdout",),
                dict(
                    action="store_true",
                    help="If set, logs the standard output of subcommands.",
                ),
            ),
            (
                ("--dry-run",),
                dict(
                    action="store_true",
                    help="If set, prints the commands to run, "
                    "then exits without executing them.",
                ),
            ),
            (
                ("--resume",),
                dict(
                    type=str,
                    default=None,
                    help="Set this to the name of a directory under `output_dir` (which is a "
                    "timestamp) to resume a previous execution of this script, i.e., only run "
                    "parameter combinations for which there are still no output files.",
                ),
            ),
            (
                ("--strict-params",),
                dict(
                    action="store_true",
                    help="If set, unknown parameters in sweep files raise an error "
                    "instead of being ignored.",
                ),
            ),
        )
        for flags, options in option_specs:
            parser.add_argument(*flags, **options)
        return parser
|
||||
|
||||
|
||||
def run_main(args: SweepStartupArgs):
    """Run the startup sweep inside a timestamped sub-directory of ``args.output_dir``.

    The sub-directory is named after ``args.resume`` when that is set, otherwise
    after the current local time formatted as ``%Y%m%d_%H%M%S``.

    Raises:
        ValueError: ``args.resume`` is set but the directory does not exist.
        RuntimeError: ``run_combs`` raised anything at all (``BaseException`` is
            caught, so this includes ``KeyboardInterrupt``). The message carries
            the ``--resume`` hint and the original exception is chained.
    """
    if args.resume:
        timestamp = args.resume
    else:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = args.output_dir / timestamp

    resuming_missing_dir = bool(args.resume) and not output_dir.exists()
    if resuming_missing_dir:
        raise ValueError(f"Cannot resume from non-existent directory ({output_dir})")

    try:
        # Attribute lookups stay inside the `try` so that any failure is
        # reported together with the resume hint.
        sweep_kwargs = dict(
            startup_cmd=args.startup_cmd,
            serve_params=args.serve_params,
            startup_params=args.startup_params,
            output_dir=output_dir,
            num_runs=args.num_runs,
            show_stdout=args.show_stdout,
            dry_run=args.dry_run,
        )
        return run_combs(**sweep_kwargs)
    except BaseException as exc:
        hint = (
            f"The script was terminated early. Use `--resume {timestamp}` "
            f"to continue the script from its last checkpoint."
        )
        raise RuntimeError(hint) from exc
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
    """CLI entry point: convert the parsed namespace and run the startup sweep."""
    sweep_args = SweepStartupArgs.from_cli_args(args)
    run_main(sweep_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script mode: build a standalone parser, register the sweep options on
    # it, then hand the parsed namespace to `main`.
    parser = argparse.ArgumentParser(description=SweepStartupArgs.parser_help)
    SweepStartupArgs.add_cli_args(parser)

    cli_args = parser.parse_args()
    main(cli_args)
|
||||
4
vllm/benchmarks/sweep/utils.py
Normal file
4
vllm/benchmarks/sweep/utils.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
def sanitize_filename(filename: str) -> str:
    """Make ``filename`` usable as a single path component.

    ``/`` becomes ``_`` and ``..`` becomes ``__``; afterwards surrounding single
    quotes, then surrounding double quotes, are stripped.
    """
    cleaned = filename
    for old, new in (("/", "_"), ("..", "__")):
        cleaned = cleaned.replace(old, new)
    return cleaned.strip("'").strip('"')
|
||||
946
vllm/benchmarks/throughput.py
Normal file
946
vllm/benchmarks/throughput.py
Normal file
@@ -0,0 +1,946 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Benchmark offline inference throughput."""
|
||||
|
||||
import argparse
|
||||
import dataclasses
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import warnings
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import uvloop
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
|
||||
|
||||
from vllm.benchmarks.datasets import (
|
||||
AIMODataset,
|
||||
BurstGPTDataset,
|
||||
ConversationDataset,
|
||||
InstructCoderDataset,
|
||||
MultiModalConversationDataset,
|
||||
PrefixRepetitionRandomDataset,
|
||||
RandomDataset,
|
||||
RandomDatasetForReranking,
|
||||
RandomMultiModalDataset,
|
||||
SampleRequest,
|
||||
ShareGPTDataset,
|
||||
SonnetDataset,
|
||||
VisionArenaDataset,
|
||||
add_random_dataset_base_args,
|
||||
add_random_multimodal_dataset_args,
|
||||
)
|
||||
from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
|
||||
from vllm.inputs import TextPrompt, TokensPrompt
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import BeamSearchParams
|
||||
from vllm.tokenizers import TokenizerLike, get_tokenizer
|
||||
from vllm.utils.async_utils import merge_async_iterators
|
||||
|
||||
|
||||
def run_vllm(
    requests: list[SampleRequest],
    n: int,
    engine_args: EngineArgs,
    do_profile: bool,
    disable_detokenize: bool = False,
) -> tuple[float, list[RequestOutput] | None]:
    """Time offline generation for ``requests`` with the synchronous ``LLM`` API.

    Returns:
        ``(elapsed_seconds, outputs)``. ``outputs`` is the result of
        ``llm.generate``; it stays ``None`` on the beam-search path, which is
        currently disabled by a hard-coded flag.
    """
    from vllm import LLM, SamplingParams

    llm = LLM(**dataclasses.asdict(engine_args))
    assert all(
        llm.llm_engine.model_config.max_model_len
        >= (request.prompt_len + request.expected_output_len)
        for request in requests
    ), (
        "Please ensure that max_model_len is greater than the sum of"
        " prompt_len and expected_output_len for all requests."
    )

    # Translate each sample into an engine prompt plus its sampling settings.
    prompts: list[TextPrompt | TokensPrompt] = []
    sampling_params: list[SamplingParams] = []
    for req in requests:
        if "prompt_token_ids" in req.prompt:
            engine_prompt = TokensPrompt(
                prompt_token_ids=req.prompt["prompt_token_ids"]
            )
        else:
            engine_prompt = TextPrompt(prompt=req.prompt)

        if req.multi_modal_data:
            assert isinstance(req.multi_modal_data, dict)
            engine_prompt["multi_modal_data"] = req.multi_modal_data
        prompts.append(engine_prompt)

        sampling_params.append(
            SamplingParams(
                n=n,
                temperature=1.0,
                top_p=1.0,
                ignore_eos=True,
                max_tokens=req.expected_output_len,
                detokenize=not disable_detokenize,
            )
        )

    lora_requests: list[LoRARequest] | None = (
        [req.lora_request for req in requests] if engine_args.enable_lora else None
    )

    use_beam_search = False

    outputs = None
    if use_beam_search:
        assert lora_requests is None, "BeamSearch API does not support LoRA"
        prompts = [req.prompt for req in requests]
        # output_len should be the same for all requests.
        output_len = requests[0].expected_output_len
        for req in requests:
            assert req.expected_output_len == output_len

        beam_params = BeamSearchParams(
            beam_width=n,
            max_tokens=output_len,
            ignore_eos=True,
        )
        start = time.perf_counter()
        if do_profile:
            llm.start_profile()
        llm.beam_search(prompts, beam_params)
        if do_profile:
            llm.stop_profile()
        end = time.perf_counter()
    else:
        start = time.perf_counter()
        if do_profile:
            llm.start_profile()
        outputs = llm.generate(
            prompts, sampling_params, lora_request=lora_requests, use_tqdm=True
        )
        if do_profile:
            llm.stop_profile()
        end = time.perf_counter()

    elapsed = end - start
    return elapsed, outputs
|
||||
|
||||
|
||||
def run_vllm_chat(
    requests: list[SampleRequest],
    n: int,
    engine_args: EngineArgs,
    do_profile: bool,
    disable_detokenize: bool = False,
) -> tuple[float, list[RequestOutput]]:
    """
    Run vLLM chat benchmark. This function is recommended ONLY for benchmarking
    multimodal models as it properly handles multimodal inputs and chat
    formatting. For non-multimodal models, use run_vllm() instead.

    Returns ``(elapsed_seconds, outputs)`` where ``outputs`` comes from
    ``llm.chat``.
    """
    from vllm import LLM, SamplingParams

    llm = LLM(**dataclasses.asdict(engine_args))

    assert all(
        llm.llm_engine.model_config.max_model_len
        >= (request.prompt_len + request.expected_output_len)
        for request in requests
    ), (
        "Please ensure that max_model_len is greater than the sum of "
        "prompt_len and expected_output_len for all requests."
    )

    prompts = [req.prompt for req in requests]
    sampling_params: list[SamplingParams] = [
        SamplingParams(
            n=n,
            temperature=1.0,
            top_p=1.0,
            ignore_eos=True,
            max_tokens=req.expected_output_len,
            detokenize=not disable_detokenize,
        )
        for req in requests
    ]

    start = time.perf_counter()
    if do_profile:
        llm.start_profile()
    outputs = llm.chat(prompts, sampling_params, use_tqdm=True)
    if do_profile:
        llm.stop_profile()
    elapsed = time.perf_counter() - start
    return elapsed, outputs
|
||||
|
||||
|
||||
async def run_vllm_async(
    requests: list[SampleRequest],
    n: int,
    engine_args: AsyncEngineArgs,
    do_profile: bool,
    disable_frontend_multiprocessing: bool = False,
    disable_detokenize: bool = False,
) -> float:
    """Time generation for ``requests`` through the async engine client.

    All requests are submitted up front (request ids ``test0``, ``test1``, ...)
    and the merged output stream is drained. Returns the elapsed seconds.
    """
    from vllm import SamplingParams
    from vllm.entrypoints.openai.api_server import (
        build_async_engine_client_from_engine_args,
    )

    async with build_async_engine_client_from_engine_args(
        engine_args,
        disable_frontend_multiprocessing=disable_frontend_multiprocessing,
    ) as llm:
        model_config = llm.model_config
        assert all(
            model_config.max_model_len
            >= (request.prompt_len + request.expected_output_len)
            for request in requests
        ), (
            "Please ensure that max_model_len is greater than the sum of"
            " prompt_len and expected_output_len for all requests."
        )

        # Translate each sample into an engine prompt, its sampling settings
        # and its (possibly absent) LoRA request.
        prompts: list[TextPrompt | TokensPrompt] = []
        sampling_params: list[SamplingParams] = []
        lora_requests: list[LoRARequest | None] = []
        for req in requests:
            if "prompt_token_ids" in req.prompt:
                engine_prompt = TokensPrompt(
                    prompt_token_ids=req.prompt["prompt_token_ids"]
                )
            else:
                engine_prompt = TextPrompt(prompt=req.prompt)

            if req.multi_modal_data:
                assert isinstance(req.multi_modal_data, dict)
                engine_prompt["multi_modal_data"] = req.multi_modal_data

            sampling_params.append(
                SamplingParams(
                    n=n,
                    temperature=1.0,
                    top_p=1.0,
                    ignore_eos=True,
                    max_tokens=req.expected_output_len,
                    detokenize=not disable_detokenize,
                )
            )
            prompts.append(engine_prompt)
            lora_requests.append(req.lora_request)

        start = time.perf_counter()
        if do_profile:
            await llm.start_profile()

        submitted = zip(prompts, sampling_params, lora_requests)
        generators = [
            llm.generate(
                engine_prompt, params, lora_request=lora, request_id=f"test{idx}"
            )
            for idx, (engine_prompt, params, lora) in enumerate(submitted)
        ]

        # Drain every stream; the outputs themselves are not needed.
        async for _idx, _res in merge_async_iterators(*generators):
            pass

        if do_profile:
            await llm.stop_profile()
        end = time.perf_counter()
        return end - start
|
||||
|
||||
|
||||
def run_hf(
    requests: list[SampleRequest],
    model: str,
    tokenizer: TokenizerLike,
    n: int,
    max_batch_size: int,
    trust_remote_code: bool,
    disable_detokenize: bool = False,
) -> float:
    """Time generation for ``requests`` with a Hugging Face causal LM on CUDA.

    Consecutive requests are greedily batched while the batch holds fewer than
    ``max_batch_size`` prompts and the longest prompt plus the longest output
    stays within 2048 tokens. Returns the elapsed seconds.
    """
    assert isinstance(tokenizer, PreTrainedTokenizerBase), (
        "the hf backend only supports HF tokenizers"
    )
    llm = AutoModelForCausalLM.from_pretrained(
        model, dtype=torch.float16, trust_remote_code=trust_remote_code
    )
    if llm.config.model_type == "llama":
        # To enable padding in the HF backend.
        tokenizer.pad_token = tokenizer.eos_token
    llm = llm.cuda()

    pbar = tqdm(total=len(requests))
    start = time.perf_counter()
    batch: list[str] = []
    max_prompt_len = 0
    max_output_len = 0
    last_index = len(requests) - 1
    for idx, req in enumerate(requests):
        # Add the prompt to the batch.
        batch.append(req.prompt)
        max_prompt_len = max(max_prompt_len, req.prompt_len)
        max_output_len = max(max_output_len, req.expected_output_len)

        if len(batch) < max_batch_size and idx != last_index:
            # Check if we can add more requests to the batch.
            upcoming = requests[idx + 1]
            grown_total = max(max_prompt_len, upcoming.prompt_len) + max(
                max_output_len, upcoming.expected_output_len
            )
            if grown_total <= 2048:
                # We can add more requests to the batch.
                continue

        # Generate the sequences.
        encoded = tokenizer(batch, return_tensors="pt", padding=True)
        input_ids = encoded.input_ids
        llm_outputs = llm.generate(
            input_ids=input_ids.cuda(),
            do_sample=True,
            num_return_sequences=n,
            temperature=1.0,
            top_p=1.0,
            use_cache=True,
            max_new_tokens=max_output_len,
        )
        if not disable_detokenize:
            # Include the decoding time.
            tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
        pbar.update(len(batch))

        # Clear the batch.
        batch = []
        max_prompt_len = 0
        max_output_len = 0

    elapsed = time.perf_counter() - start
    return elapsed
|
||||
|
||||
|
||||
def save_to_pytorch_benchmark_format(
    args: argparse.Namespace, results: dict[str, Any]
) -> None:
    """Write ``results`` in PyTorch benchmark format next to ``args.output_json``.

    Nothing is written when the conversion yields no records.
    """
    metrics = {
        name: [results[name]]
        for name in ("requests_per_second", "tokens_per_second")
    }
    extra_info = {
        name: results[name]
        for name in ["elapsed_time", "num_requests", "total_num_tokens"]
    }
    pt_records = convert_to_pytorch_benchmark_format(
        args=args,
        metrics=metrics,
        extra_info=extra_info,
    )
    if not pt_records:
        return

    # Don't use json suffix here as we don't want CI to pick it up
    stem, _ext = os.path.splitext(args.output_json)
    pt_file = f"{stem}.pytorch.json"
    write_to_json(pt_file, pt_records)
|
||||
|
||||
|
||||
def get_requests(args, tokenizer):
    """Build the dataset selected by ``args`` and sample benchmark requests.

    The dataset class and its keyword arguments are chosen from
    ``args.dataset_name`` (and, for ``hf``, from ``args.dataset_path``).
    ``None``-valued sampling options are dropped before sampling, and the
    sampled requests are finally sharded with ``filter_requests_for_dp``.

    Raises:
        ValueError: ``args.dataset_name`` is unknown, or ``dataset_name`` is
            ``hf`` and ``args.dataset_path`` matches none of the supported
            dataset classes.
    """
    # Common parameters for all dataset types.
    common_kwargs = {
        "dataset_path": args.dataset_path,
        "random_seed": args.seed,
    }
    sample_kwargs = {
        "tokenizer": tokenizer,
        "lora_path": args.lora_path,
        "max_loras": args.max_loras,
        "num_requests": args.num_prompts,
    }

    if args.dataset_name == "random" or (
        args.dataset_path is None
        and args.dataset_name not in {"prefix_repetition", "random-mm", "random-rerank"}
    ):
        sample_kwargs["range_ratio"] = args.random_range_ratio
        # prefer random_* arguments, fall back to regular arguments
        random_prefix_len = getattr(args, "random_prefix_len", None)
        sample_kwargs["prefix_len"] = (
            random_prefix_len if random_prefix_len is not None else args.prefix_len
        )
        random_input_len = getattr(args, "random_input_len", None)
        sample_kwargs["input_len"] = (
            random_input_len if random_input_len is not None else args.input_len
        )
        random_output_len = getattr(args, "random_output_len", None)
        sample_kwargs["output_len"] = (
            random_output_len if random_output_len is not None else args.output_len
        )
        dataset_cls = RandomDataset
    elif args.dataset_name == "sharegpt":
        dataset_cls = ShareGPTDataset
        if args.backend == "vllm-chat":
            sample_kwargs["enable_multimodal_chat"] = True
        if args.output_len is not None:
            sample_kwargs["output_len"] = args.output_len
    elif args.dataset_name == "sonnet":
        assert tokenizer.chat_template or tokenizer.default_chat_template, (
            "Tokenizer/model must have chat template for sonnet dataset."
        )
        dataset_cls = SonnetDataset
        sample_kwargs["prefix_len"] = args.prefix_len
        sample_kwargs["return_prompt_formatted"] = True
        if args.input_len is not None:
            sample_kwargs["input_len"] = args.input_len
        if args.output_len is not None:
            sample_kwargs["output_len"] = args.output_len
    elif args.dataset_name == "burstgpt":
        dataset_cls = BurstGPTDataset
    elif args.dataset_name == "hf":
        if args.output_len is not None:
            sample_kwargs["output_len"] = args.output_len
        if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = VisionArenaDataset
            common_kwargs["dataset_subset"] = None
            common_kwargs["dataset_split"] = "train"
            sample_kwargs["enable_multimodal_chat"] = True
        elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = InstructCoderDataset
            common_kwargs["dataset_split"] = "train"
        elif args.dataset_path in MultiModalConversationDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = MultiModalConversationDataset
            common_kwargs["dataset_subset"] = args.hf_subset
            common_kwargs["dataset_split"] = args.hf_split
            sample_kwargs["enable_multimodal_chat"] = True
        elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = ConversationDataset
            common_kwargs["dataset_subset"] = args.hf_subset
            common_kwargs["dataset_split"] = args.hf_split
            sample_kwargs["enable_multimodal_chat"] = True
        elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = AIMODataset
            common_kwargs["dataset_subset"] = None
            common_kwargs["dataset_split"] = "train"
        else:
            # Without this branch `dataset_cls` would stay unbound and the
            # call below would fail with an opaque UnboundLocalError.
            raise ValueError(
                f"{args.dataset_path} is not supported by hf dataset."
            )
    elif args.dataset_name == "prefix_repetition":
        dataset_cls = PrefixRepetitionRandomDataset
        sample_kwargs["prefix_len"] = args.prefix_repetition_prefix_len
        sample_kwargs["suffix_len"] = args.prefix_repetition_suffix_len
        sample_kwargs["num_prefixes"] = args.prefix_repetition_num_prefixes
        sample_kwargs["output_len"] = args.prefix_repetition_output_len
    elif args.dataset_name == "random-mm":
        dataset_cls = RandomMultiModalDataset
        # prefer random_* arguments, fall back to regular arguments
        random_input_len = getattr(args, "random_input_len", None)
        sample_kwargs["input_len"] = (
            random_input_len
            if random_input_len is not None
            else getattr(args, "input_len", None)
        )
        random_output_len = getattr(args, "random_output_len", None)
        sample_kwargs["output_len"] = (
            random_output_len
            if random_output_len is not None
            else getattr(args, "output_len", None)
        )
        sample_kwargs["base_items_per_request"] = getattr(
            args, "random_mm_base_items_per_request", None
        )
        sample_kwargs["num_mm_items_range_ratio"] = getattr(
            args, "random_mm_num_mm_items_range_ratio", None
        )
        sample_kwargs["limit_mm_per_prompt"] = getattr(
            args, "random_mm_limit_mm_per_prompt", None
        )
        sample_kwargs["bucket_config"] = getattr(args, "random_mm_bucket_config", None)
        sample_kwargs["enable_multimodal_chat"] = True
        random_prefix_len = getattr(args, "random_prefix_len", None)
        prefix_len = getattr(args, "prefix_len", None)
        sample_kwargs["prefix_len"] = (
            random_prefix_len if random_prefix_len is not None else prefix_len
        )
        sample_kwargs["range_ratio"] = args.random_range_ratio
    elif args.dataset_name == "random-rerank":
        dataset_cls = RandomDatasetForReranking
        # prefer random_* arguments, fall back to regular arguments
        random_input_len = getattr(args, "random_input_len", None)
        sample_kwargs["input_len"] = (
            random_input_len
            if random_input_len is not None
            else getattr(args, "input_len", None)
        )
        random_output_len = getattr(args, "random_output_len", None)
        sample_kwargs["output_len"] = (
            random_output_len
            if random_output_len is not None
            else getattr(args, "output_len", None)
        )
        sample_kwargs["batchsize"] = getattr(args, "random_batch_size", 1)
        sample_kwargs["is_reranker"] = not getattr(args, "no_reranker", False)
        sample_kwargs["range_ratio"] = args.random_range_ratio
    else:
        raise ValueError(f"Unknown dataset name: {args.dataset_name}")
    # Remove None values
    sample_kwargs = {k: v for k, v in sample_kwargs.items() if v is not None}
    requests = dataset_cls(**common_kwargs).sample(**sample_kwargs)
    requests = filter_requests_for_dp(requests, args.data_parallel_size)
    return requests
|
||||
|
||||
|
||||
def filter_requests_for_dp(requests, data_parallel_size):
    """Select the subset of ``requests`` handled by this data-parallel rank.

    With ``data_parallel_size == 1`` the input is returned unchanged.
    Otherwise the rank is derived from the ``RANK`` and ``WORLD_SIZE``
    environment variables and every ``data_parallel_size``-th request
    (round-robin by index) is kept.

    Note(zhuohan): The way we get data_parallel_rank is hacky and only
    works for external launcher mode. Should be cleaned up and deprecated
    in the future with a better vLLM distributed process design.
    """
    if data_parallel_size == 1:
        return requests

    rank = int(os.environ["RANK"])
    total_ranks = int(os.environ["WORLD_SIZE"])
    # Number of processes that make up one data-parallel replica.
    ranks_per_replica = total_ranks // data_parallel_size
    dp_rank = rank // ranks_per_replica

    selected = []
    for index, request in enumerate(requests):
        if index % data_parallel_size == dp_rank:
            selected.append(request)
    return selected
|
||||
|
||||
|
||||
def validate_args(args):
    """
    Validate command-line arguments.

    Mutates ``args`` in place where defaults have to be filled in
    (``dataset_path``, ``tokenizer``, ``dataset_name``), emits warnings for
    options that will be ignored with the selected dataset, and raises
    ``ValueError`` (or ``AssertionError`` for the hf dataset/backend pairing)
    for combinations that cannot be run.
    """

    # === Deprecation and Defaulting ===
    if args.dataset is not None:
        warnings.warn(
            "The '--dataset' argument will be deprecated in the next release. "
            "Please use '--dataset-name' and '--dataset-path' instead.",
            stacklevel=2,
        )
        args.dataset_path = args.dataset

    if not getattr(args, "tokenizer", None):
        args.tokenizer = args.model

    # === Backend Validation ===
    valid_backends = {"vllm", "hf", "mii", "vllm-chat"}
    if args.backend not in valid_backends:
        raise ValueError(f"Unsupported backend: {args.backend}")

    # === Dataset Configuration ===
    # NOTE(review): this also overrides an explicit 'random-mm' or
    # 'random-rerank' selection when no dataset path is given; confirm that
    # is intended before changing it.
    if (
        not args.dataset
        and not args.dataset_path
        and args.dataset_name not in {"prefix_repetition"}
    ):
        print("When dataset path is not set, it will default to random dataset")
        args.dataset_name = "random"
        random_input_len = getattr(args, "random_input_len", None)
        if args.input_len is None and random_input_len is None:
            raise ValueError(
                "Either --input-len or --random-input-len must be provided "
                "for a random dataset"
            )

    # === Dataset Name Specific Checks ===
    # --hf-subset and --hf-split: only used
    # when dataset_name is 'hf'
    if args.dataset_name != "hf" and (
        getattr(args, "hf_subset", None) is not None
        or getattr(args, "hf_split", None) is not None
    ):
        # Implicit string concatenation (instead of a backslash inside the
        # literal) keeps source indentation out of the message.
        warnings.warn(
            "--hf-subset and --hf-split will be ignored "
            "since --dataset-name is not 'hf'.",
            stacklevel=2,
        )
    elif args.dataset_name == "hf":
        if args.dataset_path in (
            VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
            | MultiModalConversationDataset.SUPPORTED_DATASET_PATHS
            | ConversationDataset.SUPPORTED_DATASET_PATHS
        ):
            assert args.backend == "vllm-chat", (
                f"{args.dataset_path} needs to use vllm-chat as the backend."
            )
        elif args.dataset_path in (
            InstructCoderDataset.SUPPORTED_DATASET_PATHS
            | AIMODataset.SUPPORTED_DATASET_PATHS
        ):
            assert args.backend == "vllm", (
                f"{args.dataset_path} needs to use vllm as the backend."
            )
        else:
            raise ValueError(f"{args.dataset_path} is not supported by hf dataset.")

    # --random-range-ratio: only used when dataset_name is 'random',
    # 'random-mm', or 'random-rerank'
    if (
        args.dataset_name not in {"random", "random-mm", "random-rerank"}
        and args.random_range_ratio is not None
    ):
        warnings.warn(
            "--random-range-ratio will be ignored since "
            "--dataset-name is not 'random', 'random-mm', or 'random-rerank'.",
            stacklevel=2,
        )

    # --random-batch-size: only used when dataset_name is 'random-rerank'
    if (
        args.dataset_name != "random-rerank"
        and getattr(args, "random_batch_size", None) is not None
    ) and args.random_batch_size != 1:
        warnings.warn(
            "--random-batch-size will be ignored since "
            "--dataset-name is not 'random-rerank'.",
            stacklevel=2,
        )

    # --no-reranker: only used when dataset_name is 'random-rerank'
    if args.dataset_name != "random-rerank" and getattr(args, "no_reranker", False):
        warnings.warn(
            "--no-reranker will be ignored since "
            "--dataset-name is not 'random-rerank'.",
            stacklevel=2,
        )

    # --prefix-len defaults to 0 rather than None on the CLI, so a plain
    # "is not None" test would treat it as always specified. Only a non-zero
    # value counts as an explicit request for a prefix.
    prefix_len_given = args.prefix_len not in (None, 0)

    # --prefix-len: only used when dataset_name is 'random', 'random-mm',
    # 'sonnet', or not set.
    if (
        args.dataset_name not in {"random", "random-mm", "sonnet", None}
        and prefix_len_given
    ):
        warnings.warn(
            "--prefix-len will be ignored since --dataset-name "
            "is not 'random', 'random-mm', 'sonnet', or not set.",
            stacklevel=2,
        )

    # === Random Dataset Argument Conflict Detection ===
    # Check for conflicts between regular and random arguments when using
    # random datasets
    if args.dataset_name in {"random", "random-mm", "random-rerank"}:
        random_input_len = getattr(args, "random_input_len", None)
        random_output_len = getattr(args, "random_output_len", None)
        random_prefix_len = getattr(args, "random_prefix_len", None)

        if args.input_len is not None and random_input_len is not None:
            warnings.warn(
                "Both --input-len and --random-input-len are specified. "
                "The random version (--random-input-len) will be preferred "
                "in this run.",
                stacklevel=2,
            )
        if args.output_len is not None and random_output_len is not None:
            warnings.warn(
                "Both --output-len and --random-output-len are specified. "
                "The random version (--random-output-len) will be preferred "
                "in this run.",
                stacklevel=2,
            )
        if prefix_len_given and random_prefix_len is not None:
            warnings.warn(
                "Both --prefix-len and --random-prefix-len are specified. "
                "The random version (--random-prefix-len) will be preferred "
                "in this run.",
                stacklevel=2,
            )

    # === LoRA Settings ===
    if getattr(args, "enable_lora", False) and args.backend != "vllm":
        raise ValueError("LoRA benchmarking is only supported for vLLM backend")
    if getattr(args, "enable_lora", False) and args.lora_path is None:
        raise ValueError("LoRA path must be provided when enable_lora is True")

    # === Backend-specific Validations ===
    if args.backend == "hf" and args.hf_max_batch_size is None:
        raise ValueError("HF max batch size is required for HF backend")
    if args.backend != "hf" and args.hf_max_batch_size is not None:
        raise ValueError("HF max batch size is only for HF backend.")

    if (
        args.backend in {"hf", "mii"}
        and getattr(args, "quantization", None) is not None
    ):
        raise ValueError("Quantization is only for vLLM backend.")

    if args.backend == "mii" and args.dtype != "auto":
        raise ValueError("dtype must be auto for MII backend.")
    if args.backend == "mii" and args.n != 1:
        raise ValueError("n must be 1 for MII backend.")
    if args.backend == "mii" and args.tokenizer != args.model:
        raise ValueError("Tokenizer must be the same as the model for MII backend.")

    if args.data_parallel_size > 1 and (
        args.distributed_executor_backend != "external_launcher" or args.async_engine
    ):
        # --data-parallel is not supported fully.
        # Old issue: https://github.com/vllm-project/vllm/issues/16222
        # Currently we only support data parallel with external launcher
        # mode (i.e., launch with torchrun).
        raise ValueError(
            "Data parallel is only supported with external launcher mode "
            "with synchronous engine in offline benchmark, "
            "please use benchmark serving instead"
        )
|
||||
|
||||
|
||||
def add_cli_args(parser: argparse.ArgumentParser):
    """Register the throughput-benchmark command-line options on ``parser``.

    Adds the benchmark-specific flags defined here, then the random-dataset
    flags (``add_random_dataset_base_args`` and
    ``add_random_multimodal_dataset_args``) and finally the engine flags via
    ``AsyncEngineArgs.add_cli_args``. Nothing is returned.
    """
    parser.add_argument(
        "--backend",
        type=str,
        choices=["vllm", "hf", "mii", "vllm-chat"],
        default="vllm",
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        choices=[
            "sharegpt",
            "random",
            "sonnet",
            "burstgpt",
            "hf",
            "prefix_repetition",
            "random-mm",
            "random-rerank",
        ],
        help="Name of the dataset to benchmark on.",
        default="sharegpt",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default=None,
        # Implicit concatenation only: a backslash continuation inside the
        # literal would leak source indentation into the help text.
        help="Path to the ShareGPT dataset, will be deprecated in "
        "the next release. The dataset is expected to "
        "be a json in form of list[dict[..., conversations: "
        "list[dict[..., value: <prompt_or_response>]]]]",
    )
    parser.add_argument(
        "--dataset-path", type=str, default=None, help="Path to the dataset"
    )
    parser.add_argument(
        "--input-len",
        type=int,
        default=None,
        help="Input prompt length for each request",
    )
    parser.add_argument(
        "--output-len",
        type=int,
        default=None,
        help="Output length for each request. Overrides the "
        "output length from the dataset.",
    )
    parser.add_argument(
        "--n", type=int, default=1, help="Number of generated sequences per prompt."
    )
    parser.add_argument(
        "--num-prompts", type=int, default=1000, help="Number of prompts to process."
    )
    parser.add_argument(
        "--hf-max-batch-size",
        type=int,
        default=None,
        help="Maximum batch size for HF backend.",
    )
    parser.add_argument(
        "--output-json",
        type=str,
        default=None,
        help="Path to save the throughput results in JSON format.",
    )
    parser.add_argument(
        "--async-engine",
        action="store_true",
        default=False,
        help="Use vLLM async engine rather than LLM class.",
    )
    parser.add_argument(
        "--disable-frontend-multiprocessing",
        action="store_true",
        default=False,
        help="Disable decoupled async engine frontend.",
    )
    parser.add_argument(
        "--disable-detokenize",
        action="store_true",
        help=(
            "Do not detokenize the response (i.e. do not include "
            "detokenization time in the measurement)"
        ),
    )
    # LoRA
    parser.add_argument(
        "--lora-path",
        type=str,
        default=None,
        help="Path to the lora adapters to use. This can be an absolute path, "
        "a relative path, or a Hugging Face model identifier.",
    )
    parser.add_argument(
        "--prefix-len",
        type=int,
        default=0,
        help="Number of fixed prefix tokens before the random "
        "context in a request (default: 0).",
    )

    # hf dataset
    parser.add_argument(
        "--hf-subset",
        type=str,
        default=None,
        help="Subset of the HF dataset.",
    )
    parser.add_argument(
        "--hf-split",
        type=str,
        default=None,
        help="Split of the HF dataset.",
    )
    parser.add_argument(
        "--profile",
        action="store_true",
        default=False,
        help="Use vLLM Profiling. --profiler-config must be provided on the server.",
    )

    # prefix repetition dataset
    parser.add_argument(
        "--prefix-repetition-prefix-len",
        type=int,
        default=None,
        help="Number of prefix tokens per request, used only for prefix "
        "repetition dataset.",
    )
    parser.add_argument(
        "--prefix-repetition-suffix-len",
        type=int,
        default=None,
        help="Number of suffix tokens per request, used only for prefix "
        "repetition dataset. Total input length is prefix_len + suffix_len.",
    )
    parser.add_argument(
        "--prefix-repetition-num-prefixes",
        type=int,
        default=None,
        help="Number of prefixes to generate, used only for prefix repetition "
        "dataset. Prompts per prefix is num_requests // num_prefixes.",
    )
    parser.add_argument(
        "--prefix-repetition-output-len",
        type=int,
        default=None,
        help="Number of output tokens per request, used only for prefix "
        "repetition dataset.",
    )

    # (random, random-mm, random-rerank)
    add_random_dataset_base_args(parser)
    add_random_multimodal_dataset_args(parser)

    parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
    """Run the offline throughput benchmark described by ``args``.

    Validates the arguments, seeds ``random``, samples the requests, runs
    them on the selected backend, prints throughput and token totals, and
    optionally writes the results to ``args.output_json``.
    """
    validate_args(args)
    if args.seed is None:
        args.seed = 0
    random.seed(args.seed)

    # Sample the requests.
    # mistral_common tokenizer is only supported on vllm and vllm-chat backends;
    # for hf and mii backends, we use hf tokenizer
    if args.backend in ("hf", "mii") and args.tokenizer_mode == "auto":
        args.tokenizer_mode = "hf"
    tokenizer = get_tokenizer(
        args.tokenizer,
        tokenizer_mode=args.tokenizer_mode,
        trust_remote_code=args.trust_remote_code,
    )
    requests = get_requests(args, tokenizer)
    is_multi_modal = any(req.multi_modal_data is not None for req in requests)

    request_outputs: list[RequestOutput] | None = None
    if args.backend == "vllm":
        if args.async_engine:
            async_run = run_vllm_async(
                requests,
                args.n,
                AsyncEngineArgs.from_cli_args(args),
                disable_frontend_multiprocessing=args.disable_frontend_multiprocessing,
                disable_detokenize=args.disable_detokenize,
                do_profile=args.profile,
            )
            elapsed_time = uvloop.run(async_run)
        else:
            elapsed_time, request_outputs = run_vllm(
                requests,
                args.n,
                EngineArgs.from_cli_args(args),
                disable_detokenize=args.disable_detokenize,
                do_profile=args.profile,
            )
    elif args.backend == "hf":
        assert args.tensor_parallel_size == 1
        if args.profile:
            raise NotImplementedError("Profiling not implemented yet for backend='hf'.")
        elapsed_time = run_hf(
            requests,
            args.model,
            tokenizer,
            args.n,
            args.hf_max_batch_size,
            args.trust_remote_code,
            args.disable_detokenize,
        )
    elif args.backend == "vllm-chat":
        elapsed_time, request_outputs = run_vllm_chat(
            requests,
            args.n,
            EngineArgs.from_cli_args(args),
            disable_detokenize=args.disable_detokenize,
            do_profile=args.profile,
        )
    else:
        raise ValueError(f"Unknown backend: {args.backend}")

    if request_outputs:
        # Note: with the vllm and vllm-chat backends,
        # we have request_outputs, which we use to count tokens.
        total_prompt_tokens = 0
        total_output_tokens = 0
        for result in request_outputs:
            if not isinstance(result, RequestOutput):
                continue
            if result.prompt_token_ids:
                total_prompt_tokens += len(result.prompt_token_ids)
            else:
                total_prompt_tokens += 0
            total_output_tokens += sum(
                len(completion.token_ids)
                for completion in result.outputs
                if completion
            )
        total_num_tokens = total_prompt_tokens + total_output_tokens
    else:
        # No engine outputs: fall back to the lengths recorded on the requests.
        total_num_tokens = sum(r.prompt_len + r.expected_output_len for r in requests)
        total_output_tokens = sum(r.expected_output_len for r in requests)
        total_prompt_tokens = total_num_tokens - total_output_tokens

    if is_multi_modal and args.backend != "vllm-chat":
        print(
            "\033[91mWARNING\033[0m: Multi-modal request with "
            f"{args.backend} backend detected. The "
            "following metrics are not accurate because image tokens are not"
            " counted. See vllm-project/vllm/issues/9778 for details."
        )
        # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
        # vllm-chat backend counts the image tokens now

    requests_per_second = len(requests) / elapsed_time
    tokens_per_second = total_num_tokens / elapsed_time
    print(
        f"Throughput: {requests_per_second:.2f} requests/s, "
        f"{tokens_per_second:.2f} total tokens/s, "
        f"{total_output_tokens / elapsed_time:.2f} output tokens/s"
    )
    print(f"Total num prompt tokens: {total_prompt_tokens}")
    print(f"Total num output tokens: {total_output_tokens}")

    # Output JSON results if specified
    if args.output_json:
        results = {
            "elapsed_time": elapsed_time,
            "num_requests": len(requests),
            "total_num_tokens": total_num_tokens,
            "requests_per_second": requests_per_second,
            "tokens_per_second": tokens_per_second,
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)
        save_to_pytorch_benchmark_format(args, results)
|
||||
851
vllm/collect_env.py
Normal file
851
vllm/collect_env.py
Normal file
@@ -0,0 +1,851 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# ruff: noqa
|
||||
# code borrowed from https://github.com/pytorch/pytorch/blob/main/torch/utils/collect_env.py
|
||||
|
||||
import datetime
|
||||
import locale
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
# Unlike the rest of PyTorch, the upstream file was required to stay Python 2
# compliant. NOTE: this copy uses f-strings, so that no longer holds here.
|
||||
# This script outputs relevant system environment info
|
||||
# Run it with `python collect_env.py` or `python -m torch.utils.collect_env`
|
||||
from collections import namedtuple
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.envs import environment_variables
|
||||
|
||||
try:
|
||||
import torch
|
||||
|
||||
TORCH_AVAILABLE = True
|
||||
except (ImportError, NameError, AttributeError, OSError):
|
||||
TORCH_AVAILABLE = False
|
||||
|
||||
# System Environment Information
|
||||
# Record type holding one value per item of the environment report.
# Field order is significant: it is the positional order of the tuple.
SystemEnv = namedtuple(
    "SystemEnv",
    (
        "torch_version",
        "is_debug_build",
        "cuda_compiled_version",
        # toolchain / OS
        "gcc_version",
        "clang_version",
        "cmake_version",
        "os",
        "libc_version",
        "python_version",
        "python_platform",
        # CUDA runtime and driver
        "is_cuda_available",
        "cuda_runtime_version",
        "cuda_module_loading",
        "nvidia_driver_version",
        "nvidia_gpu_models",
        "cudnn_version",
        # installed packages
        "pip_version",  # 'pip' or 'pip3'
        "pip_packages",
        "conda_packages",
        # HIP / misc
        "hip_compiled_version",
        "hip_runtime_version",
        "miopen_runtime_version",
        "caching_allocator_config",
        "is_xnnpack_available",
        "cpu_info",
        # vllm specific fields
        "rocm_version",
        "vllm_version",
        "vllm_build_flags",
        "gpu_topo",
        "env_vars",
    ),
)
|
||||
|
||||
# Substrings used to pick the interesting lines out of `conda list` output.
DEFAULT_CONDA_PATTERNS = {
    "torch", "numpy", "cudatoolkit", "soumith", "mkl",
    "magma", "triton", "optree", "nccl", "transformers",
    "zmq", "nvidia", "pynvml", "flashinfer-python", "helion",
}
|
||||
|
||||
# Default substring patterns for pip packages.
DEFAULT_PIP_PATTERNS = {
    "torch", "numpy", "mypy", "flake8", "triton",
    "optree", "onnx", "nccl", "transformers", "zmq",
    "nvidia", "pynvml", "flashinfer-python", "helion",
}
|
||||
|
||||
|
||||
def run(command):
    """Execute ``command`` and return ``(return-code, stdout, stderr)``.

    A string command is run through the shell; any other value is passed to
    ``subprocess.Popen`` as-is without a shell. Both streams are decoded and
    stripped. If the executable cannot be found, ``(127, "", message)`` is
    returned instead of raising.
    """
    use_shell = type(command) is str
    try:
        proc = subprocess.Popen(
            command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=use_shell
        )
        stdout_bytes, stderr_bytes = proc.communicate()
        returncode = proc.returncode
        encoding = "oem" if get_platform() == "win32" else locale.getpreferredencoding()
        stdout_text = stdout_bytes.decode(encoding)
        # don't remove the leading whitespace of `nvidia-smi topo -m`
        # because they are meaningful
        stdout_text = (
            stdout_text.rstrip()
            if command == "nvidia-smi topo -m"
            else stdout_text.strip()
        )
        stderr_text = stderr_bytes.decode(encoding)
        return returncode, stdout_text, stderr_text.strip()

    except FileNotFoundError:
        if isinstance(command, str):
            missing = command
        else:
            missing = command[0]
        return 127, "", f"Command not found: {missing}"
|
||||
|
||||
|
||||
def run_and_read_all(run_lambda, command):
    """Run ``command`` via ``run_lambda`` and return its full stdout.

    Returns ``None`` when the command exits with a non-zero return code.
    """
    status, stdout, _ = run_lambda(command)
    return stdout if status == 0 else None
|
||||
|
||||
|
||||
def run_and_parse_first_match(run_lambda, command, regex):
    """Run ``command`` and return group 1 of the first ``regex`` match.

    Returns ``None`` if the command exits non-zero or its stdout contains no
    match.
    """
    status, stdout, _ = run_lambda(command)
    if status == 0:
        found = re.search(regex, stdout)
        if found is not None:
            return found.group(1)
    return None
|
||||
|
||||
|
||||
def get_conda_packages(run_lambda, patterns=None):
    """Return the `conda list` lines that mention any of ``patterns``.

    ``patterns`` defaults to ``DEFAULT_CONDA_PATTERNS``. The conda executable
    is taken from ``$CONDA_EXE`` (falling back to ``conda``). Comment lines
    starting with ``#`` are dropped. Returns ``None`` if the command fails.
    """
    wanted = DEFAULT_CONDA_PATTERNS if patterns is None else patterns
    conda_exe = os.environ.get("CONDA_EXE", "conda")
    listing = run_and_read_all(run_lambda, [conda_exe, "list"])
    if listing is None:
        return listing

    kept = []
    for line in listing.splitlines():
        if line.startswith("#"):
            continue
        if any(name in line for name in wanted):
            kept.append(line)
    return "\n".join(kept)
|
||||
|
||||
|
||||
def get_gcc_version(run_lambda):
    """Return the text following ``gcc `` in `gcc --version`, or None."""
    command = "gcc --version"
    pattern = r"gcc (.*)"
    return run_and_parse_first_match(run_lambda, command, pattern)
|
||||
|
||||
|
||||
def get_clang_version(run_lambda):
    """Return the version reported by `clang --version`, or None."""
    command = "clang --version"
    pattern = r"clang version (.*)"
    return run_and_parse_first_match(run_lambda, command, pattern)
|
||||
|
||||
|
||||
def get_cmake_version(run_lambda):
    """Return the text following ``cmake `` in `cmake --version`, or None."""
    command = "cmake --version"
    pattern = r"cmake (.*)"
    return run_and_parse_first_match(run_lambda, command, pattern)
|
||||
|
||||
|
||||
def get_nvidia_driver_version(run_lambda):
    """Return the NVIDIA driver version string, or None if it cannot be read.

    On macOS the version is parsed from `kextstat`; elsewhere it is parsed
    from the output of the nvidia-smi executable.
    """
    if get_platform() == "darwin":
        command = "kextstat | grep -i cuda"
        pattern = r"com[.]nvidia[.]CUDA [(](.*?)[)]"
    else:
        command = get_nvidia_smi()
        pattern = r"Driver Version: (.*?) "
    return run_and_parse_first_match(run_lambda, command, pattern)
|
||||
|
||||
|
||||
def get_gpu_info(run_lambda):
    """Return a description of the visible GPU(s), or None.

    On macOS, or when torch was built with HIP, the device name comes from
    ``torch.cuda`` (with the gcnArchName appended for HIP builds). Otherwise
    the output of ``<nvidia-smi> -L`` is returned with the UUIDs removed.
    """
    torch_path = get_platform() == "darwin" or (
        TORCH_AVAILABLE
        and hasattr(torch.version, "hip")
        and torch.version.hip is not None
    )

    if not torch_path:
        smi = get_nvidia_smi()
        uuid_regex = re.compile(r" \(UUID: .+?\)")
        status, listing, _ = run_lambda(smi + " -L")
        if status != 0:
            return None
        # Anonymize GPUs by removing their UUID
        return re.sub(uuid_regex, "", listing)

    if not (TORCH_AVAILABLE and torch.cuda.is_available()):
        return None

    if torch.version.hip is None:
        arch_suffix = ""
    else:
        props = torch.cuda.get_device_properties(0)
        if hasattr(props, "gcnArchName"):
            arch_suffix = " ({})".format(props.gcnArchName)
        else:
            arch_suffix = "NoGCNArchNameOnOldPyTorch"
    return torch.cuda.get_device_name(None) + arch_suffix
|
||||
|
||||
|
||||
def get_running_cuda_version(run_lambda):
    """Return the ``V...`` version printed by `nvcc --version`, or None."""
    command = "nvcc --version"
    pattern = r"release .+ V(.*)"
    return run_and_parse_first_match(run_lambda, command, pattern)
|
||||
|
||||
|
||||
def get_cudnn_version(run_lambda):
    """Return the path(s) of libcudnn found on this machine, or None.

    Several copies may be installed and it is hard to tell which one is in
    use, so all resolved candidates are reported. If the lookup command
    yields nothing usable, ``$CUDNN_LIBRARY`` is consulted instead.
    """
    platform_name = get_platform()
    if platform_name == "win32":
        system_root = os.environ.get("SYSTEMROOT", "C:\\Windows")
        cuda_path = os.environ.get("CUDA_PATH", "%CUDA_PATH%")
        where_cmd = os.path.join(system_root, "System32", "where")
        cudnn_cmd = '{} /R "{}\\bin" cudnn*.dll'.format(where_cmd, cuda_path)
    elif platform_name == "darwin":
        # CUDA libraries and drivers can be found in /usr/local/cuda/. See
        # https://docs.nvidia.com/cuda/cuda-installation-guide-mac-os-x/index.html#install
        # https://docs.nvidia.com/deeplearning/sdk/cudnn-install/index.html#installmac
        # Use CUDNN_LIBRARY when cudnn library is installed elsewhere.
        cudnn_cmd = "ls /usr/local/cuda/lib/libcudnn*"
    else:
        cudnn_cmd = 'ldconfig -p | grep libcudnn | rev | cut -d" " -f1 | rev'

    status, listing, _ = run_lambda(cudnn_cmd)
    # find will return 1 if there are permission errors or if not found
    if len(listing) == 0 or status not in (0, 1):
        env_library = os.environ.get("CUDNN_LIBRARY")
        if env_library is not None and os.path.isfile(env_library):
            return os.path.realpath(env_library)
        return None

    # realpath eliminates symbolic links so duplicates collapse in the set.
    resolved = {os.path.realpath(entry) for entry in listing.split("\n")}
    existing = {path for path in resolved if os.path.isfile(path)}
    if not existing:
        return None

    # Alphabetize the result because the order is non-deterministic otherwise
    ordered = sorted(existing)
    if len(ordered) == 1:
        return ordered[0]
    return "Probably one of the following:\n{}".format("\n".join(ordered))
|
||||
|
||||
|
||||
def get_nvidia_smi():
    """Return the command used to invoke nvidia-smi on this platform.

    On Windows the first existing candidate path (System32, then the legacy
    NVSMI directory) is returned wrapped in double quotes; everywhere else,
    and when neither path exists, the bare name ``nvidia-smi`` is returned.
    """
    # Note: nvidia-smi is currently available only on Windows and Linux
    smi = "nvidia-smi"
    if get_platform() != "win32":
        return smi

    system_root = os.environ.get("SYSTEMROOT", "C:\\Windows")
    program_files_root = os.environ.get("PROGRAMFILES", "C:\\Program Files")
    legacy_path = os.path.join(program_files_root, "NVIDIA Corporation", "NVSMI", smi)
    new_path = os.path.join(system_root, "System32", smi)
    for candidate in (new_path, legacy_path):
        if os.path.exists(candidate):
            return '"{}"'.format(candidate)
    return smi
|
||||
|
||||
|
||||
def get_rocm_version(run_lambda):
    """Return the HIP version reported by `hipcc --version`.

    Returns ``None`` (not the string 'N/A') when `hipcc` fails or its output
    has no ``HIP version:`` line; callers that want a placeholder must
    substitute it themselves.
    """
    return run_and_parse_first_match(
        run_lambda, "hipcc --version", r"HIP version: (\S+)"
    )
|
||||
|
||||
|
||||
def get_vllm_version():
    """Return a human-readable vLLM version string.

    ``"N/A (dev)"`` is returned for the literal version ``"dev"``. When the
    last element of ``__version_tuple__`` is a string starting with ``g`` the
    build is treated as a dev build and the git sha (plus the date, if a
    dotted suffix is present) is appended to the version.
    """
    from vllm import __version__, __version_tuple__

    if __version__ == "dev":
        return "N/A (dev)"

    local_part = __version_tuple__[-1]
    if not (isinstance(local_part, str) and local_part.startswith("g")):
        return __version__

    # it's a dev build
    pieces = local_part.split(".")
    git_sha = pieces[0][1:]
    if len(pieces) == 1:
        # it's a dev build without local changes
        return f"{__version__} (git sha: {git_sha})"
    # it's a dev build containing local changes
    date = pieces[-1][1:]
    return f"{__version__} (git sha: {git_sha}, date: {date})"
|
||||
|
||||
|
||||
def summarize_vllm_build_flags():
    """Summarize build-related environment variables in one line.

    Reports ``$TORCH_CUDA_ARCH_LIST`` (or "Not Set") and whether
    ``$ROCM_HOME`` is set to a non-empty value.
    """
    # This could be a static method if the flags are constant, or dynamic if you need to check environment variables, etc.
    cuda_archs = os.environ.get("TORCH_CUDA_ARCH_LIST", "Not Set")
    if os.environ.get("ROCM_HOME"):
        rocm_state = "Enabled"
    else:
        rocm_state = "Disabled"
    return "CUDA Archs: {}; ROCm: {}".format(cuda_archs, rocm_state)
|
||||
|
||||
|
||||
def get_gpu_topo(run_lambda):
    """Return the GPU topology report on Linux, or None.

    Tries `nvidia-smi topo -m` first and falls back to
    `rocm-smi --showtopo` if that fails. Other platforms yield None.
    """
    if get_platform() != "linux":
        return None

    topo = run_and_read_all(run_lambda, "nvidia-smi topo -m")
    if topo is not None:
        return topo
    return run_and_read_all(run_lambda, "rocm-smi --showtopo")
|
||||
|
||||
|
||||
# example outputs of CPU infos
|
||||
# * linux
|
||||
# Architecture: x86_64
|
||||
# CPU op-mode(s): 32-bit, 64-bit
|
||||
# Address sizes: 46 bits physical, 48 bits virtual
|
||||
# Byte Order: Little Endian
|
||||
# CPU(s): 128
|
||||
# On-line CPU(s) list: 0-127
|
||||
# Vendor ID: GenuineIntel
|
||||
# Model name: Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz
|
||||
# CPU family: 6
|
||||
# Model: 106
|
||||
# Thread(s) per core: 2
|
||||
# Core(s) per socket: 32
|
||||
# Socket(s): 2
|
||||
# Stepping: 6
|
||||
# BogoMIPS: 5799.78
|
||||
# Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr
|
||||
# sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl
|
||||
# xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16
|
||||
# pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand
|
||||
# hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced
|
||||
# fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap
|
||||
# avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1
|
||||
# xsaves wbnoinvd ida arat avx512vbmi pku ospke avx512_vbmi2 gfni vaes vpclmulqdq
|
||||
# avx512_vnni avx512_bitalg tme avx512_vpopcntdq rdpid md_clear flush_l1d arch_capabilities
|
||||
# Virtualization features:
|
||||
# Hypervisor vendor: KVM
|
||||
# Virtualization type: full
|
||||
# Caches (sum of all):
|
||||
# L1d: 3 MiB (64 instances)
|
||||
# L1i: 2 MiB (64 instances)
|
||||
# L2: 80 MiB (64 instances)
|
||||
# L3: 108 MiB (2 instances)
|
||||
# NUMA:
|
||||
# NUMA node(s): 2
|
||||
# NUMA node0 CPU(s): 0-31,64-95
|
||||
# NUMA node1 CPU(s): 32-63,96-127
|
||||
# Vulnerabilities:
|
||||
# Itlb multihit: Not affected
|
||||
# L1tf: Not affected
|
||||
# Mds: Not affected
|
||||
# Meltdown: Not affected
|
||||
# Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
|
||||
# Retbleed: Not affected
|
||||
# Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
|
||||
# Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
|
||||
# Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
|
||||
# Srbds: Not affected
|
||||
# Tsx async abort: Not affected
|
||||
# * win32
|
||||
# Architecture=9
|
||||
# CurrentClockSpeed=2900
|
||||
# DeviceID=CPU0
|
||||
# Family=179
|
||||
# L2CacheSize=40960
|
||||
# L2CacheSpeed=
|
||||
# Manufacturer=GenuineIntel
|
||||
# MaxClockSpeed=2900
|
||||
# Name=Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz
|
||||
# ProcessorType=3
|
||||
# Revision=27142
|
||||
#
|
||||
# Architecture=9
|
||||
# CurrentClockSpeed=2900
|
||||
# DeviceID=CPU1
|
||||
# Family=179
|
||||
# L2CacheSize=40960
|
||||
# L2CacheSpeed=
|
||||
# Manufacturer=GenuineIntel
|
||||
# MaxClockSpeed=2900
|
||||
# Name=Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz
|
||||
# ProcessorType=3
|
||||
# Revision=27142
|
||||
|
||||
|
||||
def get_cpu_info(run_lambda):
    """Return CPU information as text, or the command's stderr on failure.

    Uses `lscpu` on Linux, `wmic cpu get ...` on Windows and
    `sysctl -n machdep.cpu.brand_string` on macOS. On any other platform no
    command is run and an empty string is returned.
    """
    rc, out, err = 0, "", ""
    platform_name = get_platform()
    if platform_name == "linux":
        rc, out, err = run_lambda("lscpu")
    elif platform_name == "win32":
        # Built with implicit string concatenation: a backslash continuation
        # inside the literal would embed the source indentation in the
        # property list passed to wmic.
        rc, out, err = run_lambda(
            "wmic cpu get Name,Manufacturer,Family,Architecture,ProcessorType,"
            "DeviceID,CurrentClockSpeed,MaxClockSpeed,L2CacheSize,L2CacheSpeed,"
            "Revision /VALUE"
        )
    elif platform_name == "darwin":
        rc, out, err = run_lambda("sysctl -n machdep.cpu.brand_string")
    # Surface stderr when the command failed so the report shows why.
    return out if rc == 0 else err
|
||||
|
||||
|
||||
def get_platform():
    """Return a normalized platform name.

    ``sys.platform`` values starting with ``linux``, ``win32``, ``cygwin`` or
    ``darwin`` are collapsed to exactly that prefix; anything else is
    returned unchanged.
    """
    for known in ("linux", "win32", "cygwin", "darwin"):
        if sys.platform.startswith(known):
            return known
    return sys.platform
|
||||
|
||||
|
||||
def get_mac_version(run_lambda):
    """Return the first line printed by `sw_vers -productVersion`, or None."""
    command = "sw_vers -productVersion"
    pattern = r"(.*)"
    return run_and_parse_first_match(run_lambda, command, pattern)
|
||||
|
||||
|
||||
def get_windows_version(run_lambda):
    """Return the Windows OS caption reported by wmic, or None on failure.

    Runs ``wmic os get Caption`` and pipes it through ``findstr /v Caption``
    to drop the header line; both tools are resolved under ``%SYSTEMROOT%``.
    """
    system_root = os.environ.get("SYSTEMROOT", "C:\\Windows")
    system32 = (system_root, "System32")
    wmic_cmd = os.path.join(*system32, "Wbem", "wmic")
    findstr_cmd = os.path.join(*system32, "findstr")
    command = "{} os get Caption | {} /v Caption".format(wmic_cmd, findstr_cmd)
    return run_and_read_all(run_lambda, command)
|
||||
|
||||
|
||||
def get_lsb_version(run_lambda):
    """Return the ``Description:`` value from the output of ``lsb_release -a``."""
    description_pattern = r"Description:\t(.*)"
    return run_and_parse_first_match(run_lambda, "lsb_release -a", description_pattern)
|
||||
|
||||
|
||||
def check_release_file(run_lambda):
    """Return the ``PRETTY_NAME`` value found by ``cat /etc/*-release``."""
    pretty_name_pattern = r'PRETTY_NAME="(.*)"'
    return run_and_parse_first_match(
        run_lambda, "cat /etc/*-release", pretty_name_pattern
    )
|
||||
|
||||
|
||||
def get_os(run_lambda):
    """Describe the operating system, with the machine type where known.

    Args:
        run_lambda: Command runner forwarded to the platform-specific probes.

    Returns:
        On Windows/Cygwin, the result of ``get_windows_version``. On macOS,
        ``"macOS <version> (<machine>)"`` or ``None`` if the version could
        not be read. On Linux, the first description found (``lsb_release``,
        then ``/etc/*-release``) followed by the machine type, falling back
        to ``"linux (<machine>)"``. On any other platform, the platform name.
    """
    from platform import machine

    platform = get_platform()

    if platform in ("win32", "cygwin"):
        return get_windows_version(run_lambda)

    if platform == "darwin":
        version = get_mac_version(run_lambda)
        return None if version is None else f"macOS {version} ({machine()})"

    if platform == "linux":
        # Ubuntu/Debian style first, then /etc/*-release; later probes run
        # only if the earlier ones found nothing.
        for probe in (get_lsb_version, check_release_file):
            desc = probe(run_lambda)
            if desc is not None:
                return f"{desc} ({machine()})"
        return f"{platform} ({machine()})"

    # Unknown platform
    return platform
|
||||
|
||||
|
||||
def get_python_platform():
    """Return the platform string reported by ``platform.platform()``."""
    from platform import platform as describe_platform

    return describe_platform()
|
||||
|
||||
|
||||
def get_libc_version():
    """Return ``platform.libc_ver()`` joined with ``-``, or ``"N/A"`` off Linux."""
    import platform

    if get_platform() == "linux":
        return "-".join(platform.libc_ver())
    return "N/A"
|
||||
|
||||
|
||||
def is_uv_venv():
    """Tell whether the interpreter appears to run inside a uv environment.

    Returns:
        ``True`` if the ``UV`` environment variable is set to a non-empty
        value, or if ``<sys.prefix>/pyvenv.cfg`` exists and has a line
        starting with ``"uv = "``; ``False`` otherwise.
    """
    if os.environ.get("UV"):
        return True

    cfg_file = os.path.join(sys.prefix, "pyvenv.cfg")
    if not os.path.exists(cfg_file):
        return False

    with open(cfg_file, "r") as handle:
        for cfg_line in handle:
            if cfg_line.startswith("uv = "):
                return True
    return False
|
||||
|
||||
|
||||
def get_pip_packages(run_lambda, patterns=None):
    """Return `pip list` output. Note: will also find conda-installed pytorch and numpy packages.

    Args:
        run_lambda: Command runner used to execute the listing command.
        patterns: Substrings a line must contain (any one of them) to be
            kept. Defaults to ``DEFAULT_PIP_PATTERNS``.

    Returns:
        ``(pip_version, filtered_output)`` where ``pip_version`` is
        ``"pip3"`` on Python 3 and ``"pip"`` otherwise.

    Raises:
        RuntimeError: If the ``pip`` module is not importable and the
            interpreter is not in a uv environment.
    """
    if patterns is None:
        patterns = DEFAULT_PIP_PATTERNS

    def run_with_pip():
        try:
            from importlib.util import find_spec

            pip_available = find_spec("pip") is not None
        except ImportError:
            pip_available = False

        if pip_available:
            cmd = [sys.executable, "-mpip", "list", "--format=freeze"]
        elif is_uv_venv():
            print("uv is set")
            cmd = ["uv", "pip", "list", "--format=freeze"]
        else:
            raise RuntimeError(
                "Could not collect pip list output (pip or uv module not available)"
            )

        listing = run_and_read_all(run_lambda, cmd)
        kept = []
        for line in listing.splitlines():
            if any(name in line for name in patterns):
                kept.append(line)
        return "\n".join(kept)

    pip_version = "pip3" if sys.version[0] == "3" else "pip"
    return pip_version, run_with_pip()
|
||||
|
||||
|
||||
def get_cachingallocator_config():
    """Return ``PYTORCH_CUDA_ALLOC_CONF`` from the environment, or ``""`` if unset."""
    return os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "")
|
||||
|
||||
|
||||
def get_cuda_module_loading_config():
    """Return the ``CUDA_MODULE_LOADING`` setting after initializing CUDA.

    Returns:
        ``"N/A"`` when torch is unavailable or CUDA is not available.
        Otherwise ``torch.cuda.init()`` is called first and the value of the
        ``CUDA_MODULE_LOADING`` environment variable (``""`` if unset) is
        returned.
    """
    if not (TORCH_AVAILABLE and torch.cuda.is_available()):
        return "N/A"

    torch.cuda.init()
    return os.environ.get("CUDA_MODULE_LOADING", "")
|
||||
|
||||
|
||||
def is_xnnpack_available():
    """Return ``str(torch.backends.xnnpack.enabled)``, or ``"N/A"`` without torch."""
    if not TORCH_AVAILABLE:
        return "N/A"

    import torch.backends.xnnpack

    xnnpack_enabled = torch.backends.xnnpack.enabled  # type: ignore[attr-defined]
    return str(xnnpack_enabled)
|
||||
|
||||
|
||||
def get_env_vars():
    """Collect the environment variables relevant to a vLLM bug report.

    A variable is reported when its name is listed in
    ``environment_variables`` or starts with one of the reported prefixes.
    Variables whose lower-cased name contains a secret-looking term are
    always skipped.

    Returns:
        A string with one ``NAME=value`` line per reported variable, each
        terminated by a newline; ``""`` when nothing qualifies.
    """
    secret_terms = ("secret", "token", "api", "access", "password")
    report_prefix = (
        "TORCH",
        "NCCL",
        "PYTORCH",
        "CUDA",
        "CUBLAS",
        "CUDNN",
        "OMP_",
        "MKL_",
        "NVIDIA",
    )
    reported = []
    for k, v in os.environ.items():
        if any(term in k.lower() for term in secret_terms):
            continue
        # A name can satisfy both conditions (known variable AND reported
        # prefix); testing them with a single `or` emits it only once.
        if k in environment_variables or k.startswith(report_prefix):
            reported.append("{}={}\n".format(k, v))

    return "".join(reported)
|
||||
|
||||
|
||||
def get_env_info():
    """Gather every report field and bundle them into a ``SystemEnv``.

    Torch-derived fields are set to ``"N/A"`` when torch is not available.
    For a HIP build of torch (``torch.version.hip`` is set) the CUDA compiled
    version is reported as ``"N/A"`` and the HIP / MIOpen runtime versions are
    read from ``torch._C._show_config()``.

    Returns:
        A populated ``SystemEnv`` instance.
    """
    run_lambda = run
    pip_version, pip_list_output = get_pip_packages(run_lambda)

    not_available = "N/A"
    if not TORCH_AVAILABLE:
        version_str = debug_mode_str = cuda_available_str = cuda_version_str = (
            not_available
        )
        hip_compiled_version = hip_runtime_version = miopen_runtime_version = (
            not_available
        )
    else:
        version_str = torch.__version__
        debug_mode_str = str(torch.version.debug)
        cuda_available_str = str(torch.cuda.is_available())
        cuda_version_str = torch.version.cuda
        if getattr(torch.version, "hip", None) is None:  # cuda version
            hip_compiled_version = hip_runtime_version = miopen_runtime_version = (
                not_available
            )
        else:  # HIP version

            def get_version_or_na(cfg, prefix):
                # Last whitespace-separated token of the first line that
                # mentions `prefix`.
                return next(
                    (s.rsplit(None, 1)[-1] for s in cfg if prefix in s), "N/A"
                )

            cfg = torch._C._show_config().split("\n")
            hip_runtime_version = get_version_or_na(cfg, "HIP Runtime")
            miopen_runtime_version = get_version_or_na(cfg, "MIOpen")
            cuda_version_str = "N/A"
            hip_compiled_version = torch.version.hip

    sys_version = sys.version.replace("\n", " ")
    python_version = f"{sys_version} ({sys.maxsize.bit_length() + 1}-bit runtime)"

    conda_packages = get_conda_packages(run_lambda)

    rocm_version = get_rocm_version(run_lambda)
    vllm_version = get_vllm_version()
    vllm_build_flags = summarize_vllm_build_flags()
    gpu_topo = get_gpu_topo(run_lambda)

    return SystemEnv(
        torch_version=version_str,
        is_debug_build=debug_mode_str,
        python_version=python_version,
        python_platform=get_python_platform(),
        is_cuda_available=cuda_available_str,
        cuda_compiled_version=cuda_version_str,
        cuda_runtime_version=get_running_cuda_version(run_lambda),
        cuda_module_loading=get_cuda_module_loading_config(),
        nvidia_gpu_models=get_gpu_info(run_lambda),
        nvidia_driver_version=get_nvidia_driver_version(run_lambda),
        cudnn_version=get_cudnn_version(run_lambda),
        hip_compiled_version=hip_compiled_version,
        hip_runtime_version=hip_runtime_version,
        miopen_runtime_version=miopen_runtime_version,
        pip_version=pip_version,
        pip_packages=pip_list_output,
        conda_packages=conda_packages,
        os=get_os(run_lambda),
        libc_version=get_libc_version(),
        gcc_version=get_gcc_version(run_lambda),
        clang_version=get_clang_version(run_lambda),
        cmake_version=get_cmake_version(run_lambda),
        caching_allocator_config=get_cachingallocator_config(),
        is_xnnpack_available=is_xnnpack_available(),
        cpu_info=get_cpu_info(run_lambda),
        rocm_version=rocm_version,
        vllm_version=vllm_version,
        vllm_build_flags=vllm_build_flags,
        gpu_topo=gpu_topo,
        env_vars=get_env_vars(),
    )
|
||||
|
||||
|
||||
# Report template, built in two parts. Each `{name}` placeholder is filled by
# `env_info_fmt.format(**mutable_dict)` in `pretty_str`, so every placeholder
# needs a matching key in that dict.
# NOTE(review): column alignment inside these templates may have been
# collapsed in this copy of the file — confirm against the original.
env_info_fmt = """
==============================
System Info
==============================
OS : {os}
GCC version : {gcc_version}
Clang version : {clang_version}
CMake version : {cmake_version}
Libc version : {libc_version}

==============================
PyTorch Info
==============================
PyTorch version : {torch_version}
Is debug build : {is_debug_build}
CUDA used to build PyTorch : {cuda_compiled_version}
ROCM used to build PyTorch : {hip_compiled_version}

==============================
Python Environment
==============================
Python version : {python_version}
Python platform : {python_platform}

==============================
CUDA / GPU Info
==============================
Is CUDA available : {is_cuda_available}
CUDA runtime version : {cuda_runtime_version}
CUDA_MODULE_LOADING set to : {cuda_module_loading}
GPU models and configuration : {nvidia_gpu_models}
Nvidia driver version : {nvidia_driver_version}
cuDNN version : {cudnn_version}
HIP runtime version : {hip_runtime_version}
MIOpen runtime version : {miopen_runtime_version}
Is XNNPACK available : {is_xnnpack_available}

==============================
CPU Info
==============================
{cpu_info}

==============================
Versions of relevant libraries
==============================
{pip_packages}
{conda_packages}
""".strip()

# Both the template above and the one below are passed through `strip()`,
# which removes their leading/trailing whitespace, so a blank line is added
# here explicitly to keep the two sections separated.
env_info_fmt += "\n\n"

# Second part of the template: vLLM-specific details and environment variables.
env_info_fmt += """
==============================
vLLM Info
==============================
ROCM Version : {rocm_version}
vLLM Version : {vllm_version}
vLLM Build Flags:
{vllm_build_flags}
GPU Topology:
{gpu_topo}

==============================
Environment Variables
==============================
{env_vars}
""".strip()
|
||||
|
||||
|
||||
def pretty_str(envinfo):
    """Render collected environment info as the human-readable report.

    Args:
        envinfo: The collected info; it must provide ``_asdict()`` and the
            attributes ``nvidia_gpu_models``, ``cuda_compiled_version``,
            ``pip_version`` and ``cpu_info`` (as ``SystemEnv`` does).

    Returns:
        ``env_info_fmt`` formatted with the post-processed field values.
    """

    def replace_nones(dct, replacement="Could not collect"):
        # In-place: every None value becomes the placeholder text.
        for key, value in dct.items():
            if value is None:
                dct[key] = replacement
        return dct

    def replace_bools(dct, true="Yes", false="No"):
        # In-place: the literal True/False objects become Yes/No strings.
        for key, value in dct.items():
            if value is True:
                dct[key] = true
            elif value is False:
                dct[key] = false
        return dct

    def prepend(text, tag="[prepend]"):
        # Prefix every line of `text` with `tag`.
        return "\n".join(tag + line for line in text.split("\n"))

    def replace_if_empty(text, replacement="No relevant packages"):
        if text is None or len(text) != 0:
            return text
        return replacement

    def maybe_start_on_next_line(string):
        # If `string` is multiline, prepend a \n to it.
        if string is None or "\n" not in string:
            return string
        return "\n{}\n".format(string)

    mutable_dict = envinfo._asdict()

    # If nvidia_gpu_models is multiline, start on the next line
    mutable_dict["nvidia_gpu_models"] = maybe_start_on_next_line(
        envinfo.nvidia_gpu_models
    )

    # If the machine doesn't have CUDA, report some fields as 'No CUDA'
    dynamic_cuda_fields = [
        "cuda_runtime_version",
        "nvidia_gpu_models",
        "nvidia_driver_version",
    ]
    all_cuda_fields = dynamic_cuda_fields + ["cudnn_version"]
    all_dynamic_cuda_fields_missing = all(
        mutable_dict[name] is None for name in dynamic_cuda_fields
    )
    no_cuda_detected = (
        TORCH_AVAILABLE
        and not torch.cuda.is_available()
        and all_dynamic_cuda_fields_missing
    )
    if no_cuda_detected:
        mutable_dict.update({name: "No CUDA" for name in all_cuda_fields})
        if envinfo.cuda_compiled_version is None:
            mutable_dict["cuda_compiled_version"] = "None"

    # Replace True with Yes, False with No
    mutable_dict = replace_bools(mutable_dict)

    # Replace all None objects with 'Could not collect'
    mutable_dict = replace_nones(mutable_dict)

    # If either of these are '', replace with 'No relevant packages'
    for packages_key in ("pip_packages", "conda_packages"):
        mutable_dict[packages_key] = replace_if_empty(mutable_dict[packages_key])

    # Tag conda and pip packages with a prefix
    # If they were previously None, they'll show up as ie '[conda] Could not collect'
    if mutable_dict["pip_packages"]:
        pip_tag = "[{}] ".format(envinfo.pip_version)
        mutable_dict["pip_packages"] = prepend(mutable_dict["pip_packages"], pip_tag)
    if mutable_dict["conda_packages"]:
        mutable_dict["conda_packages"] = prepend(
            mutable_dict["conda_packages"], "[conda] "
        )
    mutable_dict["cpu_info"] = envinfo.cpu_info
    return env_info_fmt.format(**mutable_dict)
|
||||
|
||||
|
||||
def get_pretty_env_info():
    """Collect the environment info and return it as formatted report text."""
    collected = get_env_info()
    return pretty_str(collected)
|
||||
|
||||
|
||||
def main():
    """Print the environment report, then point out the newest minidump, if any.

    The minidump notice is written to stderr. It is only attempted when
    torch is available and exposes ``torch.utils._crash_handler``, the
    platform is exactly ``"linux"``, and the default minidump directory
    exists and contains at least one entry.
    """
    print("Collecting environment information...")
    output = get_pretty_env_info()
    print(output)

    if (
        TORCH_AVAILABLE
        and hasattr(torch, "utils")
        and hasattr(torch.utils, "_crash_handler")
    ):
        minidump_dir = torch.utils._crash_handler.DEFAULT_MINIDUMP_DIR
        if sys.platform == "linux" and os.path.exists(minidump_dir):
            dumps = [
                os.path.join(minidump_dir, dump) for dump in os.listdir(minidump_dir)
            ]
            if not dumps:
                # The directory exists but is empty: max() on an empty list
                # would raise ValueError, and there is nothing to report.
                return
            latest = max(dumps, key=os.path.getctime)
            ctime = os.path.getctime(latest)
            creation_time = datetime.datetime.fromtimestamp(ctime).strftime(
                "%Y-%m-%d %H:%M:%S"
            )
            msg = (
                "\n*** Detected a minidump at {} created on {}, ".format(
                    latest, creation_time
                )
                + "if this is related to your bug please include it when you file a report ***"
            )
            print(msg, file=sys.stderr)
|
||||
|
||||
|
||||
# Produce the report only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
||||
0
vllm/compilation/__init__.py
Normal file
0
vllm/compilation/__init__.py
Normal file
1131
vllm/compilation/backends.py
Normal file
1131
vllm/compilation/backends.py
Normal file
File diff suppressed because it is too large
Load Diff
57
vllm/compilation/base_static_graph.py
Normal file
57
vllm/compilation/base_static_graph.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Protocol
|
||||
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
|
||||
|
||||
class AbstractStaticGraphWrapper(Protocol):
    """
    Protocol for platform-provided wrappers that let a callable be captured
    as a static graph.
    """

    def __init__(
        self,
        runnable: Callable[..., Any],
        vllm_config: VllmConfig,
        runtime_mode: CUDAGraphMode,
        **kwargs: Any,
    ) -> None:
        """
        Set up the wrapper with the configuration needed for graph capture
        and execution.

        Args:
            runnable (Callable): The callable to wrap and capture.
            vllm_config (VllmConfig): Global vLLM configuration.
            runtime_mode (CUDAGraphMode): The style of the static graph
                runtime. See CUDAGraphMode in vllm/config.py. Only the
                subset `NONE`, `PIECEWISE` and `FULL` is used as a concrete
                runtime mode for cudagraph dispatching.
        Keyword Args:
            kwargs: Extra keyword arguments for platform-specific
                configuration.
        """
        raise NotImplementedError

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """
        Run the wrapped callable.

        When the runtime mode currently set in the ForwardContext matches
        this instance's runtime mode, the CUDAGraph is replayed, or captured
        from the callable first if no capture exists yet. In every other
        case the original callable is invoked directly.

        Args:
            *args: Positional arguments forwarded to the callable.
            **kwargs: Keyword arguments forwarded to the callable.

        Returns:
            Any: Whatever the executed callable returns.
        """
        raise NotImplementedError
|
||||
516
vllm/compilation/caching.py
Normal file
516
vllm/compilation/caching.py
Normal file
@@ -0,0 +1,516 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import hashlib
|
||||
import inspect
|
||||
import os
|
||||
import pickle
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any, Literal
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
from torch.utils import _pytree as pytree
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.compiler_interface import get_inductor_factors
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.config.utils import hash_factors
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
try:
|
||||
from torch._dynamo.aot_compile import SerializableCallable
|
||||
except ImportError:
|
||||
SerializableCallable = object
|
||||
|
||||
assert isinstance(SerializableCallable, type)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class StandaloneCompiledArtifacts:
|
||||
"""Storage for standalone compiled artifacts with content-based deduplication.
|
||||
|
||||
Deduplication works via a two-level indirection:
|
||||
1. `submodule_bytes` maps "{submod_name}_{shape}" -> SHA256 hash
|
||||
2. `submodule_bytes_store` maps SHA256 hash -> actual bytes
|
||||
|
||||
When inserting, we compute the SHA256 hash of the bytes. If the hash
|
||||
already exists in `submodule_bytes_store`, we reuse the existing entry
|
||||
rather than storing duplicate bytes. This is common because submodules
|
||||
often compile to identical artifacts (e.g., identical transformer layers
|
||||
split on attn)
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# dict from submodule name to byte hash
|
||||
self.submodule_bytes: dict[str, str] = {}
|
||||
# dict from byte hash to bytes
|
||||
self.submodule_bytes_store: dict[str, bytes] = {}
|
||||
# dict from byte hash to loaded module
|
||||
self.loaded_submodule_store: dict[str, Any] = {}
|
||||
|
||||
def insert(self, submod_name: str, shape: str, entry: bytes) -> None:
|
||||
hasher = hashlib.sha256()
|
||||
hasher.update(entry)
|
||||
hex_digest = hasher.hexdigest()
|
||||
self.submodule_bytes[f"{submod_name}_{shape}"] = hex_digest
|
||||
if hex_digest not in self.submodule_bytes_store:
|
||||
self.submodule_bytes_store[hex_digest] = entry
|
||||
logger.debug(
|
||||
"inserting new artifact for submod %s with shape %s "
|
||||
"(%s bytes) at hash %s",
|
||||
submod_name,
|
||||
shape,
|
||||
len(entry),
|
||||
hex_digest,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"reusing existing cache artifact for submod %s "
|
||||
"with shape %s (%s bytes) at hash %s",
|
||||
submod_name,
|
||||
shape,
|
||||
len(entry),
|
||||
hex_digest,
|
||||
)
|
||||
|
||||
def get(self, submod_name: str, shape: str) -> bytes:
|
||||
logger.debug(
|
||||
"getting artifact for submod %s with shape %s",
|
||||
submod_name,
|
||||
shape,
|
||||
)
|
||||
return self.submodule_bytes_store[
|
||||
self.submodule_bytes[f"{submod_name}_{shape}"]
|
||||
]
|
||||
|
||||
def get_loaded(self, submod_name: str, shape: str) -> Any:
|
||||
logger.debug(
|
||||
"getting artifact for submod %s with shape %s",
|
||||
submod_name,
|
||||
shape,
|
||||
)
|
||||
return self.loaded_submodule_store[
|
||||
self.submodule_bytes[f"{submod_name}_{shape}"]
|
||||
]
|
||||
|
||||
def size_bytes(self) -> int:
|
||||
return sum(len(entry) for entry in self.submodule_bytes_store.values())
|
||||
|
||||
def num_artifacts(self) -> int:
|
||||
return len(self.submodule_bytes_store)
|
||||
|
||||
def num_entries(self) -> int:
|
||||
return len(self.submodule_bytes)
|
||||
|
||||
def submodule_names(self) -> list[str]:
|
||||
# get unique "{submod_name}" from "{submod_name}_{shape}", preserving order
|
||||
names = [cache_key.rsplit("_", 1)[0] for cache_key in self.submodule_bytes]
|
||||
return list(dict.fromkeys(names))
|
||||
|
||||
def load_all(self) -> None:
|
||||
import concurrent.futures
|
||||
|
||||
# check already loaded
|
||||
if len(self.loaded_submodule_store) == len(self.submodule_bytes_store):
|
||||
return
|
||||
|
||||
from torch._inductor.standalone_compile import AOTCompiledArtifact
|
||||
|
||||
def _load_entry(entry_bytes: bytes) -> AOTCompiledArtifact:
|
||||
entry = pickle.loads(entry_bytes)
|
||||
return AOTCompiledArtifact.deserialize(entry)
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
entries = list(self.submodule_bytes_store.values())
|
||||
loaded_entries = list(executor.map(_load_entry, entries))
|
||||
|
||||
for i, k in enumerate(self.submodule_bytes_store.keys()):
|
||||
self.loaded_submodule_store[k] = loaded_entries[i]
|
||||
|
||||
logger.debug("loaded all %s submodules", self.num_artifacts())
|
||||
|
||||
def __getstate__(self) -> dict[str, dict[str, str] | dict[str, bytes]]:
|
||||
return {
|
||||
"submodule_bytes": self.submodule_bytes,
|
||||
"submodule_bytes_store": self.submodule_bytes_store,
|
||||
}
|
||||
|
||||
def __setstate__(self, state: dict[str, dict[str, Any]]) -> None:
|
||||
self.submodule_bytes = state["submodule_bytes"]
|
||||
self.submodule_bytes_store = state["submodule_bytes_store"]
|
||||
self.loaded_submodule_store = {}
|
||||
|
||||
|
||||
class VllmSerializableFunction(SerializableCallable):  # type: ignore[misc]
    """
    A wrapper around a compiled function by vllm. It will forward the tensor
    inputs to the compiled function and return the result.
    It also implements a serialization interface to support PyTorch's precompile
    with custom backend, so that we can save and load the compiled function on
    disk. There's no need to wrap around the compiled function if we don't want
    to serialize them in particular cases.
    Right now serialization for the custom backend is done via
    serializing the Dynamo fx graph plus example inputs.
    """

    def __init__(
        self,
        graph_module: torch.fx.GraphModule,
        example_inputs: Sequence[Any],
        prefix: str,
        optimized_call: Callable[..., Any],
        is_encoder: bool = False,
        vllm_backend: Any | None = None,
        sym_tensor_indices: list[int] | None = None,
    ) -> None:
        """
        Store the graph, its example inputs and the compiled callable.

        `shape_env` is taken from the first `torch.SymInt` found in
        `example_inputs`; it stays None when there is no symbolic input.
        """
        assert isinstance(graph_module, torch.fx.GraphModule)
        self.graph_module = graph_module
        self.example_inputs = example_inputs
        self.prefix = prefix
        self.optimized_call = optimized_call
        self.is_encoder = is_encoder
        self.shape_env = None
        self.vllm_backend = vllm_backend
        self.sym_tensor_indices = sym_tensor_indices
        sym_input = next(
            (i for i in self.example_inputs if isinstance(i, torch.SymInt)), None
        )
        if sym_input is not None:
            self.shape_env = sym_input.node.shape_env

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Plain delegation to whatever `optimized_call` currently is.
        return self.optimized_call(*args, **kwargs)

    @classmethod
    def serialize_compile_artifacts(
        cls, compiled_fn: "VllmSerializableFunction"
    ) -> bytes:
        """
        Serialize `compiled_fn` to bytes.

        The pickled state is a copy of the instance attributes without
        `optimized_call`, `shape_env` and `vllm_backend`, with the graph
        module and example inputs encoded by `GraphPickler`. When
        `compiled_fn.vllm_backend` is set, the standalone compile artifacts
        and the two maps it returns are stored in the state as well.
        """
        import sympy
        from torch._subclasses import FakeTensorMode
        from torch.fx._graph_pickler import GraphPickler, Options

        # Shallow copy: the graph module below is the same object as
        # compiled_fn.graph_module, so the meta entries popped from its nodes
        # are removed from the live graph too.
        state = compiled_fn.__dict__.copy()
        state.pop("optimized_call")
        state.pop("shape_env")
        state.pop("vllm_backend", None)
        for node in state["graph_module"].graph.nodes:
            node.meta.pop("source_fn_stack", None)
            node.meta.pop("nn_module_stack", None)
        # Same cleanup for direct child modules that carry a graph.
        for name, submod in state["graph_module"].named_children():
            if hasattr(submod, "graph"):
                for node in submod.graph.nodes:
                    node.meta.pop("source_fn_stack", None)
                    node.meta.pop("nn_module_stack", None)

        # Keep the stock reducer so the override can fall back to it.
        graph_reducer_override = GraphPickler.reducer_override

        def _graph_reducer_override(
            self: GraphPickler, obj: Any
        ) -> tuple[Callable[..., Any], tuple[Any, ...]] | Any:
            # sympy.Function subclasses that expose `_torch_unpickler` are
            # rebuilt by calling it with their `_torch_handler_name`.
            if (
                inspect.isclass(obj)
                and issubclass(obj, sympy.Function)
                and hasattr(obj, "_torch_unpickler")
            ):
                return obj._torch_unpickler, (obj._torch_handler_name,)
            # A FakeTensorMode is reduced to `type(None)()`, i.e. it
            # unpickles as None.
            if isinstance(obj, FakeTensorMode):
                return type(None), ()
            return graph_reducer_override(self, obj)

        # NOTE(review): both branches below run the same code, which does not
        # match the "mask off" wording of the second comment — confirm that
        # this is intended.
        if state.get("sym_tensor_indices"):
            # put tensor inputs on meta device since their data
            # isn't needed, yet we need the meta for make_copy_and_call
            state["example_inputs"] = pytree.tree_map_only(
                torch.Tensor,
                lambda inp: torch.empty_like(inp, device="meta"),
                state["example_inputs"],
            )
        else:
            # mask off all tensor inputs since they are large and not needed.
            state["example_inputs"] = pytree.tree_map_only(
                torch.Tensor,
                lambda inp: torch.empty_like(inp, device="meta"),
                state["example_inputs"],
            )
        # The custom reducer is only active while these two dumps run.
        with patch.object(GraphPickler, "reducer_override", _graph_reducer_override):
            state["graph_module"] = GraphPickler.dumps(
                state["graph_module"], Options(ops_filter=None)
            )
            state["example_inputs"] = GraphPickler.dumps(state["example_inputs"])

        if compiled_fn.vllm_backend:
            (
                standalone_compile_artifacts,
                sym_shape_indices_map,
                returns_tuple_map,
            ) = compiled_fn.vllm_backend.collect_standalone_compile_artifacts()
            state["standalone_compile_artifacts"] = standalone_compile_artifacts
            state["sym_shape_indices_map"] = sym_shape_indices_map
            state["returns_tuple_map"] = returns_tuple_map
        return pickle.dumps(state)

    @classmethod
    def deserialize_compile_artifacts(cls, data: bytes) -> "VllmSerializableFunction":
        """
        Rebuild a VllmSerializableFunction from bytes produced by
        `serialize_compile_artifacts`.

        With `envs.VLLM_USE_MEGA_AOT_ARTIFACT` set, the function is rebuilt
        from the stored standalone compile artifacts (which must be present).
        Otherwise a `VllmBackend` is created and compilation is deferred to
        the first call of the returned function.
        """
        from torch._guards import TracingContext, tracing
        from torch._subclasses import FakeTensorMode
        from torch.fx._graph_pickler import GraphPickler
        from torch.fx.experimental.symbolic_shapes import ShapeEnv

        # NOTE(review): pickle.loads can execute arbitrary code; `data` must
        # come from a trusted source — confirm against callers.
        state = pickle.loads(data)
        fake_mode = FakeTensorMode(shape_env=ShapeEnv())
        state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode)
        state["graph_module"].recompile()
        state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode)

        # Popped so that the `cls(**state, ...)` call below only receives
        # constructor arguments.
        standalone_compile_artifacts = state.pop("standalone_compile_artifacts", None)
        sym_shape_indices_map = state.pop("sym_shape_indices_map", {})
        returns_tuple_map = state.pop("returns_tuple_map", {})

        if envs.VLLM_USE_MEGA_AOT_ARTIFACT:
            assert standalone_compile_artifacts is not None
            submod_names = standalone_compile_artifacts.submodule_names()
            num_submods = len(submod_names)
            num_artifacts = standalone_compile_artifacts.num_artifacts()

            logger.info(
                "reconstructing serializable fn from standalone compile "
                "artifacts. num_artifacts=%d num_submods=%d",
                num_artifacts,
                num_submods,
            )

            fn = reconstruct_serializable_fn_from_mega_artifact(
                state=state,
                standalone_compile_artifacts=standalone_compile_artifacts,
                vllm_config=get_current_vllm_config(),
                sym_shape_indices_map=sym_shape_indices_map,
                returns_tuple_map=returns_tuple_map,
            )

            logger.info(
                "reconstructed serializable fn from standalone compile artifacts"
            )

            return fn

        # Fall back to standard VllmBackend
        from vllm.compilation.backends import VllmBackend

        is_encoder = state.get("is_encoder", False)
        vllm_backend: VllmBackend = VllmBackend(
            get_current_vllm_config(), state["prefix"], is_encoder
        )

        def optimized_call(*example_inputs: Any) -> Any:
            """
            On the first run of the optimized call, we rerun the compiler
            backend which should result in a cache hit. After the backend
            call returns, we just do a one-time replacement of the optimized
            call with the compiled function, so that subsequent calls are on
            the AOT compiled path.
            """
            # `fn` is assigned after this closure is defined; it is only
            # looked up when the closure is called.
            # Stored inputs that are None are replaced by the runtime
            # argument at the same position.
            compile_inputs = [
                inp if inp is not None else example_inputs[i]
                for i, inp in enumerate(fn.example_inputs)
            ]
            with tracing(TracingContext(fake_mode)):
                fn.optimized_call = vllm_backend(
                    state["graph_module"], compile_inputs
                ).optimized_call
            return fn.optimized_call(*example_inputs)

        fn = cls(**state, optimized_call=optimized_call)
        return fn

    @property
    def co_name(self) -> Literal["VllmSerializableFunction"]:
        """
        Used for depyf debugging.
        """
        return "VllmSerializableFunction"
|
||||
|
||||
|
||||
def reconstruct_serializable_fn_from_mega_artifact(
|
||||
state: dict[str, Any],
|
||||
standalone_compile_artifacts: "StandaloneCompiledArtifacts",
|
||||
vllm_config: VllmConfig,
|
||||
sym_shape_indices_map: dict[str, list[int]],
|
||||
returns_tuple_map: dict[str, bool],
|
||||
) -> "VllmSerializableFunction":
|
||||
"""Construct a VllmSerializableFunction from cached inductor artifacts.
|
||||
|
||||
This function reconstructs a callable model from pre-compiled inductor
|
||||
artifacts without re-running the compilation. It:
|
||||
1. Loads all cached artifacts
|
||||
2. Builds compiled callables for each submodule/shape
|
||||
3. Creates PiecewiseBackend instances that dispatch to cached artifacts
|
||||
4. Wraps with cudagraph if needed
|
||||
5. Returns the final VllmSerializableFunction
|
||||
|
||||
Note: This function shares similar logic with PiecewiseCompileInterpreter
|
||||
in backends.py. Both create PiecewiseBackend instances and wrap them with
|
||||
cudagraph. The key difference is:
|
||||
- this function: PiecewiseBackend receives pre-compiled runnables
|
||||
(compiled_runnables is set, graph is None)
|
||||
- PiecewiseCompileInterpreter: PiecewiseBackend receives the FX graph
|
||||
to compile (graph is set, compiled_runnables is None)
|
||||
|
||||
If modifying the backend creation/wrapping logic, consider updating both.
|
||||
|
||||
Args:
|
||||
state: Deserialized state dict containing graph_module, example_inputs,
|
||||
prefix, sym_tensor_indices, is_encoder, etc.
|
||||
standalone_compile_artifacts: The StandaloneCompiledArtifacts containing
|
||||
pre-compiled artifacts for each submodule/shape combination.
|
||||
vllm_config: The vLLM configuration.
|
||||
sym_shape_indices_map: Mapping from submod_name to sym_shape_indices.
|
||||
returns_tuple_map: Mapping from submod_name to returns_tuple.
|
||||
|
||||
Returns:
|
||||
A VllmSerializableFunction that can be called directly.
|
||||
"""
|
||||
from vllm.compilation.backends import (
|
||||
VllmBackend,
|
||||
make_copy_and_call,
|
||||
wrap_with_cudagraph_if_needed,
|
||||
)
|
||||
from vllm.compilation.piecewise_backend import PiecewiseBackend
|
||||
|
||||
prefix = state["prefix"]
|
||||
is_encoder = state.get("is_encoder", False)
|
||||
split_gm = state["graph_module"]
|
||||
compilation_config = vllm_config.compilation_config
|
||||
|
||||
standalone_compile_artifacts.load_all()
|
||||
|
||||
submod_names = standalone_compile_artifacts.submodule_names()
|
||||
compiled_callables: dict[str, dict[str, Callable[..., Any]]] = {}
|
||||
|
||||
for cache_key in standalone_compile_artifacts.submodule_bytes:
|
||||
submod_name, shape_str = cache_key.rsplit("_", 1)
|
||||
compiled_callables.setdefault(submod_name, {})[shape_str] = (
|
||||
standalone_compile_artifacts.get_loaded(submod_name, shape_str)
|
||||
)
|
||||
|
||||
vllm_backend = VllmBackend(vllm_config, prefix, is_encoder)
|
||||
dummy_cache_dir = os.path.join(envs.VLLM_CACHE_ROOT, "dummy_cache")
|
||||
os.makedirs(dummy_cache_dir, exist_ok=True)
|
||||
vllm_backend.compiler_manager.initialize_cache(
|
||||
cache_dir=dummy_cache_dir,
|
||||
disable_cache=True,
|
||||
prefix=prefix,
|
||||
)
|
||||
|
||||
# spot check that cached submodules exist in the graph structure
|
||||
graph_children = {name for name, _ in split_gm.named_children()}
|
||||
missing = set(submod_names) - graph_children
|
||||
assert not missing, (
|
||||
f"artifacts reference submodules not in graph: {missing}. "
|
||||
f"graph has: {sorted(graph_children)}"
|
||||
)
|
||||
|
||||
for i, submod_name in enumerate(submod_names):
|
||||
assert submod_name in sym_shape_indices_map and submod_name in returns_tuple_map
|
||||
|
||||
sym_shape_indices = sym_shape_indices_map[submod_name]
|
||||
returns_tuple = returns_tuple_map[submod_name]
|
||||
runnables = compiled_callables[submod_name]
|
||||
|
||||
piecewise_backend = PiecewiseBackend(
|
||||
graph=None, # not needed for cached artifacts
|
||||
vllm_config=vllm_config,
|
||||
piecewise_compile_index=i,
|
||||
total_piecewise_compiles=len(submod_names),
|
||||
sym_shape_indices=sym_shape_indices,
|
||||
vllm_backend=vllm_backend,
|
||||
returns_tuple=returns_tuple,
|
||||
compiled_runnables=runnables,
|
||||
)
|
||||
|
||||
is_first = i == 0
|
||||
is_last = i == len(submod_names) - 1
|
||||
wrapped_backend = wrap_with_cudagraph_if_needed(
|
||||
piecewise_backend,
|
||||
vllm_config,
|
||||
compilation_config,
|
||||
is_first,
|
||||
is_last,
|
||||
)
|
||||
|
||||
split_gm.__dict__[submod_name] = wrapped_backend
|
||||
logger.debug(
|
||||
"Replaced submodule %s with piecewise backend from cache",
|
||||
submod_name,
|
||||
)
|
||||
|
||||
if compilation_config.cudagraph_copy_inputs:
|
||||
sym_tensor_indices = state["sym_tensor_indices"]
|
||||
input_buffers = [
|
||||
torch.empty_like(
|
||||
state["example_inputs"][idx], device=vllm_config.device_config.device
|
||||
)
|
||||
for idx in sym_tensor_indices
|
||||
]
|
||||
optimized_call = make_copy_and_call(sym_tensor_indices, input_buffers, split_gm)
|
||||
else:
|
||||
optimized_call = split_gm
|
||||
|
||||
fn = VllmSerializableFunction(
|
||||
**state,
|
||||
optimized_call=optimized_call,
|
||||
vllm_backend=None,
|
||||
)
|
||||
return fn
|
||||
|
||||
|
||||
def aot_compile_hash_factors(vllm_config: VllmConfig) -> list[str]:
    """Collect the factors that identify an AOT compilation.

    The returned list contains, in order:

    0. a hash of the environment-derived compile factors (for example, the
       value of VLLM_PP_LAYER_PARTITION affects the computation graph),
    1. the hash reported by ``vllm_config.compute_hash()`` (it mainly
       summarizes how the model is created),
    2. the inductor factors, appended only when
       ``envs.VLLM_USE_MEGA_AOT_ARTIFACT`` is set.

    Args:
        vllm_config: The vLLM configuration to hash.

    Returns:
        The list of factors described above.
    """
    collected = [
        hash_factors(envs.compile_factors()),
        vllm_config.compute_hash(),
    ]

    if envs.VLLM_USE_MEGA_AOT_ARTIFACT:
        collected.extend(get_inductor_factors())

    return collected
|
||||
|
||||
|
||||
def _compute_code_hash_with_content(file_contents: dict[str, str]) -> str:
    """Hash a mapping of file path -> file content into a hex digest.

    Entries are processed in ascending path order so the digest does not
    depend on dict ordering. For every entry the path is hashed, followed by
    its content.

    Args:
        file_contents: Mapping from file path to the text read from it.

    Returns:
        The hex digest produced by ``safe_hash`` over the newline-joined
        paths and contents.
    """
    pieces: list[str] = []
    for path in sorted(file_contents):
        pieces.append(path)
        # "<string>" means the function was dynamically generated (e.g. with
        # exec()). Its content cannot be checked, so only the marker is hashed.
        if path != "<string>":
            pieces.append(file_contents[path])

    payload = "\n".join(pieces).encode()
    digest: str = safe_hash(payload, usedforsecurity=False).hexdigest()
    return digest
|
||||
|
||||
|
||||
def _compute_code_hash(files: set[str]) -> str:
    """Hash the contents of the traced source files.

    Each path in ``files`` is read from disk and handed to
    ``_compute_code_hash_with_content``. Paths that are not regular files
    (e.g. ``<string>``, ``<frozen modules>``) contribute an empty content
    string so that they still take part in the hash through their name.

    Files are decoded as UTF-8 explicitly rather than with the locale's
    default encoding, so the same sources produce the same hash (and do not
    fail to decode) regardless of the locale the process runs under.

    Args:
        files: Paths of the files to be considered for the compilation cache.

    Returns:
        The hex digest computed by ``_compute_code_hash_with_content``.
    """
    logger.debug(
        "Traced files (to be considered for compilation cache):\n%s",
        # sorted so the log output is stable; set iteration order is not
        "\n".join(sorted(files)),
    )
    file_contents = {}
    for filepath in files:
        # Skip files that don't exist (e.g., <string>, <frozen modules>, etc.)
        if not os.path.isfile(filepath):
            file_contents[filepath] = ""
        else:
            with open(filepath, encoding="utf-8") as f:
                file_contents[filepath] = f.read()
    return _compute_code_hash_with_content(file_contents)
|
||||
660
vllm/compilation/compiler_interface.py
Normal file
660
vllm/compilation/compiler_interface.py
Normal file
@@ -0,0 +1,660 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
import copy
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from contextlib import ExitStack
|
||||
from typing import Any, Literal
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch._inductor.compile_fx
|
||||
import torch.fx as fx
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.utils import Range
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.hashing import safe_hash
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class CompilerInterface:
    """
    Base class describing what vLLM expects from a compiler.

    The defaults implemented here do nothing useful: `initialize_cache` is a
    no-op, `compute_hash` returns an empty string, `compile` returns
    `(None, None)` and `load` raises. Concrete adaptors override the parts
    they support.
    """

    # Short identifier of the compiler, e.g. "inductor".
    # Declared at class level; subclasses assign the actual value.
    name: str

    def initialize_cache(
        self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
    ) -> None:
        """
        Let the compiler set itself up for a vLLM process that uses
        `cache_dir` as its cache directory, e.g. by redirecting the
        compiler's own cache to a sub-directory of it.

        `prefix` can be combined with `cache_dir` to recover the base cache
        directory. This matters when several parts of a model are compiled
        separately but should share one cache directory, e.g.

            cache_dir = "/path/to/dir/backbone", prefix = "backbone"
            cache_dir = "/path/to/dir/eagle_head", prefix = "eagle_head"
        """

    def compute_hash(self, vllm_config: VllmConfig) -> str:
        """
        Return a hash of the compiler-specific information that should be
        part of the cache key for a compiled model.

        See [`VllmConfig.compute_hash`][vllm.config.VllmConfig.compute_hash]
        for what is already covered by default; only information specific
        to this compiler belongs here.
        """
        return ""

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        compiler_config: dict[str, Any],
        compile_range: Range,
        key: str | None = None,
    ) -> tuple[Callable[..., Any] | None, Any | None]:
        """
        Compile `graph` for `example_inputs` using `compiler_config`, for
        the given `compile_range`.

        `compile_range` describes the range of the inputs: either a concrete
        size (when compile_sizes is provided), e.g. [4, 4], or a range such
        as [5, 8]. At the moment only a single variable is supported across
        all inputs, namely the batch size (number of tokens) at inference.

        Dynamo guarantees that `graph(*example_inputs)` is valid.

        Returns a pair `(compiled_callable, handle)`:

        - `handle` can later be passed to `load` to obtain the compiled
          function directly. It should be a plain Python object, preferably
          a string or a file path for readability. Compilers without caching
          support return None for it.
        - `compiled_callable` is None as well when the compiler fails to
          compile the graph.

        `key` is required by InductorStandaloneAdaptor: it names where the
        compiled artifact is saved, i.e. `cache_dir/key`.
        """
        return None, None

    def load(
        self,
        handle: Any,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        graph_index: int,
        compile_range: Range,
    ) -> Callable[..., Any]:
        """
        Return the compiled function identified by `handle`, where `handle`
        is the second value returned by `compile`.

        Implementations raise an error when the handle is invalid. This base
        implementation always raises because it supports no caching.
        """
        raise NotImplementedError("caching is not supported")
|
||||
|
||||
|
||||
class AlwaysHitShapeEnv:
    """
    A stand-in shape environment whose guard checks always succeed.

    Background:

    With plain `torch.compile`, each compilation consists of one Dynamo
    bytecode compilation and one Inductor compilation. The Inductor
    compilation runs inside the context of the Dynamo bytecode compilation,
    and that context supplies the dynamic shape information, etc.

    vLLM runs the Dynamo bytecode compilation only once and then runs
    Inductor several times: for specific shapes plus one general shape. The
    compilations for specific shapes happen outside of the Dynamo context,
    so there is no shape environment to hand to Inductor and its code cache
    lookup fails.

    Supplying this dummy environment makes the Inductor code cache lookup
    always hit, so the graph can be compiled for different shapes as needed.

    The set of methods below was arrived at by trial-and-error until the
    lookup worked.
    """

    def __init__(self) -> None:
        # No guards are ever recorded.
        self.guards: list[Any] = []

    def evaluate_guards_expression(self, *args: Any, **kwargs: Any) -> Literal[True]:
        """Ignore all arguments and report that the guards hold."""
        return True

    def get_pruned_guards(self, *args: Any, **kwargs: Any) -> list[Any]:
        """Ignore all arguments and return a new empty list."""
        return []

    def produce_guards_expression(self, *args: Any, **kwargs: Any) -> Literal[""]:
        """Ignore all arguments and return an empty expression string."""
        return ""
|
||||
|
||||
|
||||
def get_inductor_factors() -> list[Any]:
|
||||
factors: list[Any] = []
|
||||
# summarize system state
|
||||
from torch._inductor.codecache import CacheBase
|
||||
|
||||
system_factors = CacheBase.get_system()
|
||||
factors.append(system_factors)
|
||||
|
||||
# summarize pytorch state
|
||||
from torch._inductor.codecache import torch_key
|
||||
|
||||
torch_factors = torch_key()
|
||||
factors.append(torch_factors)
|
||||
return factors
|
||||
|
||||
|
||||
def is_compile_cache_enabled(
    vllm_additional_inductor_config: dict[str, Any],
) -> bool:
    """Tell whether vLLM's compilation cache should be used.

    The cache counts as enabled only when none of the following is set:

    - ``envs.VLLM_DISABLE_COMPILE_CACHE``
    - ``torch._inductor.config.force_disable_caches``
    - the ``"force_disable_caches"`` entry of
      ``vllm_additional_inductor_config`` (treated as False when absent)

    Args:
        vllm_additional_inductor_config: Additional inductor config dict
            that may carry a ``"force_disable_caches"`` entry.

    Returns:
        True when caching is enabled, False otherwise.
    """
    # Read the per-config switch first, before any of the global switches.
    disabled_by_vllm_config = vllm_additional_inductor_config.get(
        "force_disable_caches", False
    )

    # TODO(gmagogsfm): Replace torch._inductor.config.force_disable_caches
    # with torch.compiler.config.force_disable_caches when minimum PyTorch
    # version reaches 2.10
    if (
        envs.VLLM_DISABLE_COMPILE_CACHE
        or torch._inductor.config.force_disable_caches
    ):
        return False
    return not disabled_by_vllm_config
|
||||
|
||||
|
||||
class InductorStandaloneAdaptor(CompilerInterface):
    """
    Adaptor that drives Inductor through `standalone_compile`.

    Requires PyTorch 2.8+. It is not enabled by default yet; the plan is to
    make it the default for PyTorch 2.8.

    Toggle it with VLLM_USE_STANDALONE_COMPILE.
    """

    name = "inductor_standalone"

    def __init__(self, save_format: Literal["binary", "unpacked"]) -> None:
        # Format passed to CompiledArtifact.save / CompiledArtifact.load.
        self.save_format = save_format

    def compute_hash(self, vllm_config: VllmConfig) -> str:
        """Return the first 10 hex characters of the inductor factors hash."""
        digest = safe_hash(
            str(get_inductor_factors()).encode(), usedforsecurity=False
        )
        short_hash: str = digest.hexdigest()[:10]
        return short_hash

    def initialize_cache(
        self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
    ) -> None:
        """Remember `cache_dir`; artifacts are later saved beneath it."""
        self.cache_dir = cache_dir

    @staticmethod
    def _is_saveable_2_10(compiled_artifact: Any) -> bool:
        """Tell whether `compiled_artifact` can be saved.

        can just use compiled_artifact.is_saveable in 2.11
        """
        artifacts = compiled_artifact._artifacts
        if artifacts is None:
            return False
        _, cache_info = artifacts
        return len(cache_info.aot_autograd_artifacts) == 1

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        compiler_config: dict[str, Any],
        compile_range: Range,
        key: str | None = None,
    ) -> tuple[Callable[..., Any] | None, Any | None]:
        """Compile `graph` with `torch._inductor.standalone_compile`.

        In AOT mode (PyTorch 2.10.0+ with VLLM_USE_MEGA_AOT_ARTIFACT set) the
        compiled artifact is returned with a None handle. Otherwise `key`
        must be given; the artifact is saved to `cache_dir/key` when the
        compile cache is enabled and `(key, path)` is returned as the handle.

        Raises:
            RuntimeError: if the compile cache is enabled but the compiled
                artifact is not serializable.
        """
        compilation_counter.num_inductor_compiles += 1

        # Build a private config dict instead of mutating the caller's.
        inductor_options: dict[str, Any] = {}
        if compiler_config is not None:
            inductor_options.update(compiler_config)
        set_inductor_config(inductor_options, compile_range)
        set_functorch_config()

        shape_mode = (
            "from_example_inputs" if compile_range.is_single_size() else "from_graph"
        )

        from torch._inductor import standalone_compile

        aot_supported = is_torch_equal_or_newer("2.10.0")

        if not aot_supported and envs.VLLM_USE_MEGA_AOT_ARTIFACT:
            logger.error(
                "CRITICAL: VLLM_USE_MEGA_AOT_ARTIFACT "
                "is enabled but PyTorch version does not support 'aot' "
                "parameter in standalone_compile. This requires PyTorch "
                "2.10.0+. Falling back to non-AOT mode."
            )

        standalone_kwargs: dict[str, Any] = {
            "dynamic_shapes": shape_mode,
            "options": {
                "config_patches": inductor_options,
            },
        }

        use_aot: bool = aot_supported and envs.VLLM_USE_MEGA_AOT_ARTIFACT
        # only add 'aot' parameter if both supported and enabled...
        # this will set bundled_autograd_cache
        # https://github.com/pytorch/pytorch/blob/9bbc5b2905c260adf41bc866a732f9c121a2828a/torch/_inductor/standalone_compile.py#L359 # noqa
        if use_aot:
            standalone_kwargs["aot"] = True

        # Inductor's pre-grad passes don't do anything for vLLM.
        # The pre-grad passes get run even on cache-hit and negatively impact
        # vllm cold compile times by O(1s)
        # Can remove this after the following issue gets fixed
        # https://github.com/pytorch/pytorch/issues/174502
        pregrad_ctx: Any = (
            contextlib.nullcontext()
            if envs.VLLM_ENABLE_PREGRAD_PASSES
            else patch(
                "torch._inductor.compile_fx._recursive_pre_grad_passes",
                lambda gm, _: gm,
            )
        )
        with pregrad_ctx:
            artifact = standalone_compile(graph, example_inputs, **standalone_kwargs)

        if use_aot:
            from torch._inductor.standalone_compile import AOTCompiledArtifact

            assert isinstance(artifact, AOTCompiledArtifact)
            assert hasattr(artifact, "serialize")
            # just return the compiled graph and a key
            # since we can serialize the bytes using to_bytes
            # and reload it using the key when reading
            return artifact, None

        # Save the compiled artifact to disk in the specified path
        assert key is not None
        artifact_path = os.path.join(self.cache_dir, key)

        if is_compile_cache_enabled(compiler_config):
            if not self._is_saveable_2_10(artifact):
                raise RuntimeError(
                    "The compiled artifact is not serializable. This usually means "
                    "that the model code has something that is not serializable "
                    "by torch.compile in it. You can fix this by either "
                    "figuring out what is not serializable and rewriting it, "
                    "filing a bug report, "
                    "or suppressing this error by "
                    "disabling vLLM's compilation cache via "
                    "VLLM_DISABLE_COMPILE_CACHE=1 "
                    "(this will greatly increase vLLM server warm start times)."
                )
            artifact.save(path=artifact_path, format=self.save_format)
            compilation_counter.num_compiled_artifacts_saved += 1
        return artifact, (key, artifact_path)

    def load(
        self,
        handle: Any,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        graph_index: int,
        compile_range: Range,
    ) -> Callable[..., Any]:
        """Load the artifact stored at `handle[1]` and wrap it for calling.

        `handle` must be the `(key, path)` tuple of strings returned by
        `compile`.
        """
        assert isinstance(handle, tuple)
        assert isinstance(handle[0], str)
        assert isinstance(handle[1], str)
        artifact_path = handle[1]
        loaded_artifact = torch._inductor.CompiledArtifact.load(
            path=artifact_path, format=self.save_format
        )
        from torch._inductor.compile_fx import graph_returns_tuple

        returns_tuple = graph_returns_tuple(graph)

        def compiled_graph_wrapper(*args: Any) -> tuple[Any, ...] | Any:
            outputs = loaded_artifact(*args)
            # unpack the tuple if needed
            # TODO(rzou): the implication is that we're not
            # reading the python bytecode correctly in vLLM?
            return outputs if returns_tuple else outputs[0]

        return compiled_graph_wrapper
|
||||
|
||||
|
||||
class InductorAdaptor(CompilerInterface):
    """
    The adaptor for the Inductor compiler, version 2.5, 2.6, 2.7.

    It compiles through `compile_fx` and monkey-patches several Inductor
    cache internals so that the cache key and the generated file path of
    each compiled graph can be captured and later used by `load`.
    """

    name = "inductor"

    def compute_hash(self, vllm_config: VllmConfig) -> str:
        """Return the first 10 hex characters of the inductor factors hash.

        `vllm_config` is not read by this implementation.
        """
        factors = get_inductor_factors()
        hash_str: str = safe_hash(
            str(factors).encode(), usedforsecurity=False
        ).hexdigest()[:10]
        return hash_str

    def initialize_cache(
        self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
    ) -> None:
        """Record the cache locations and redirect Inductor/Triton caches.

        Unless `disable_cache` is set, creates `inductor_cache` and
        `triton_cache` directories under the base cache directory and
        exports them through the TORCHINDUCTOR_CACHE_DIR and
        TRITON_CACHE_DIR environment variables.
        """
        self.cache_dir = cache_dir
        self.prefix = prefix
        # Drop the trailing len(prefix) characters to get the shared base
        # directory. NOTE(review): this assumes cache_dir ends with prefix,
        # as in the CompilerInterface.initialize_cache examples; confirm
        # against callers.
        self.base_cache_dir = cache_dir[: -len(prefix)] if prefix else cache_dir
        if disable_cache:
            return
        # redirect the cache directory to a subdirectory
        # set flags so that Inductor and Triton store their cache
        # in the cache_dir, then users only need to copy the cache_dir
        # to another machine to reuse the cache.
        inductor_cache = os.path.join(self.base_cache_dir, "inductor_cache")
        os.makedirs(inductor_cache, exist_ok=True)
        os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
        triton_cache = os.path.join(self.base_cache_dir, "triton_cache")
        os.makedirs(triton_cache, exist_ok=True)
        os.environ["TRITON_CACHE_DIR"] = triton_cache

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        compiler_config: dict[str, Any],
        compile_range: Range,
        key: str | None = None,
    ) -> tuple[Callable[..., Any] | None, Any | None]:
        """Compile `graph` with `compile_fx` and capture its cache handle.

        Returns `(compiled_graph, (hash_str, file_path))`, where `hash_str`
        is the FX graph cache key and `file_path` the file of the compiled
        callable, both captured by the hooks installed below. `key` is not
        used by this adaptor.

        Raises:
            RuntimeError: if the compile cache is enabled but no cache key
                was captured during compilation.
        """
        compilation_counter.num_inductor_compiles += 1
        from torch._inductor.compile_fx import compile_fx

        # copy the config so the caller's dict is not mutated
        current_config = {}
        if compiler_config is not None:
            current_config.update(compiler_config)

        # enable the local FX graph cache and disable the remote cache
        current_config["fx_graph_cache"] = True
        current_config["fx_graph_remote_cache"] = False

        set_inductor_config(current_config, compile_range)
        set_functorch_config()

        # inductor can inplace modify the graph, so we need to copy it
        # see https://github.com/pytorch/pytorch/issues/138980
        graph = copy.deepcopy(graph)

        # it's the first time we compile this graph
        # the assumption is that we don't have nested Inductor compilation.
        # compiled_fx_graph_hash will only be called once, and we can hook
        # it to get the hash of the compiled graph directly.

        # Both are filled in by the hooks below via `nonlocal`.
        hash_str, file_path = None, None
        from torch._inductor.codecache import compiled_fx_graph_hash

        def hijacked_compile_fx_inner(*args: Any, **kwargs: Any) -> Any:
            # Run the real inner compile, then record the cache key and the
            # file that holds the compiled callable.
            output = torch._inductor.compile_fx.compile_fx_inner(*args, **kwargs)
            nonlocal hash_str
            inductor_compiled_graph = output
            if inductor_compiled_graph is not None:
                nonlocal file_path
                compiled_fn = inductor_compiled_graph.current_callable
                file_path = compiled_fn.__code__.co_filename  # noqa
                if (
                    not file_path.startswith(self.base_cache_dir)
                    and compiled_fn.__closure__ is not None
                ):
                    # hooked in the align_inputs_from_check_idxs function
                    # in torch/_inductor/utils.py
                    # The callable is a wrapper: search its closure for a
                    # callable whose code lives under the base cache dir.
                    for cell in compiled_fn.__closure__:
                        if not callable(cell.cell_contents):
                            continue
                        code = cell.cell_contents.__code__
                        if code.co_filename.startswith(self.base_cache_dir):
                            # this is the real file path
                            # compiled from Inductor
                            file_path = code.co_filename
                            break
                hash_str = inductor_compiled_graph._fx_graph_cache_key
            return output

        def hijack_compiled_fx_graph_hash(*args: Any, **kwargs: Any) -> Any:
            # Delegate to the real hash function and remember the first
            # element of its result as the cache key.
            out = compiled_fx_graph_hash(*args, **kwargs)
            nonlocal hash_str
            hash_str = out[0]
            return out

        def _check_can_cache(*args: Any, **kwargs: Any) -> None:
            # no error means it can be cached.
            # Inductor refuses to cache the graph outside of Dynamo
            # tracing context, and also disables caching for graphs
            # with high-order ops.
            # For vLLM, in either case, we want to cache the graph.
            # see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa
            return

        def _get_shape_env() -> AlwaysHitShapeEnv:
            return AlwaysHitShapeEnv()

        # Every patch below is undone when the ExitStack exits.
        with ExitStack() as stack:
            # for hijacking the hash of the compiled graph
            stack.enter_context(
                patch(
                    "torch._inductor.codecache.compiled_fx_graph_hash",
                    hijack_compiled_fx_graph_hash,
                )
            )

            # for providing a dummy shape environment
            stack.enter_context(
                patch(
                    "torch._inductor.codecache.FxGraphCache._get_shape_env",
                    _get_shape_env,
                )
            )

            from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache

            # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
            if hasattr(AOTAutogradCache, "_get_shape_env"):
                stack.enter_context(
                    patch(
                        "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
                        _get_shape_env,
                    )
                )

            # for forcing the graph to be cached
            stack.enter_context(
                patch(
                    "torch._inductor.codecache.FxGraphCache._check_can_cache",
                    _check_can_cache,
                )
            )

            # Dynamo metrics context, see method for more details.
            stack.enter_context(self.metrics_context())

            # Disable remote caching. When these are on, on remote cache-hit,
            # the monkey-patched functions never actually get called.
            # vLLM today assumes and requires the monkey-patched functions to
            # get hit.
            # TODO(zou3519): we're going to replace this all with
            # standalone_compile sometime.
            stack.enter_context(
                torch._inductor.config.patch(fx_graph_remote_cache=False)
            )
            # InductorAdaptor (unfortunately) requires AOTAutogradCache
            # to be turned off to run. It will fail to acquire the hash_str
            # and error if not.
            # StandaloneInductorAdaptor (PyTorch 2.8+) fixes this problem.
            stack.enter_context(
                torch._functorch.config.patch(enable_autograd_cache=False)
            )
            stack.enter_context(
                torch._functorch.config.patch(enable_remote_autograd_cache=False)
            )

            compiled_graph = compile_fx(
                graph,
                example_inputs,
                inner_compile=hijacked_compile_fx_inner,
                config_patches=current_config,
            )

        # Turn off the checks if we disable the compilation cache.
        if is_compile_cache_enabled(compiler_config):
            if hash_str is None:
                raise RuntimeError(
                    "vLLM failed to compile the model. The most "
                    "likely reason for this is that a previous compilation "
                    "failed, leading to a corrupted compilation artifact. "
                    "We recommend trying to "
                    "remove ~/.cache/vllm/torch_compile_cache and try again "
                    "to see the real issue. "
                )
            assert file_path is not None, (
                "failed to get the file path of the compiled graph"
            )
        return compiled_graph, (hash_str, file_path)

    def load(
        self,
        handle: Any,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        graph_index: int,
        compile_range: Range,
    ) -> Callable[..., Any]:
        """Look up a compiled graph in Inductor's FX graph cache.

        `handle` must be the `(hash_str, file_path)` tuple of strings
        returned by `compile`; only `handle[0]` is used for the lookup.
        The returned callable takes positional arguments and adapts them to
        the list-based calling convention of the cached graph.
        """
        assert isinstance(handle, tuple)
        assert isinstance(handle[0], str)
        assert isinstance(handle[1], str)
        hash_str = handle[0]

        from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
        from torch._inductor.codecache import FxGraphCache

        with ExitStack() as exit_stack:
            # provide the dummy shape environment during the cache lookup
            exit_stack.enter_context(
                patch(
                    "torch._inductor.codecache.FxGraphCache._get_shape_env",
                    lambda *args, **kwargs: AlwaysHitShapeEnv(),
                )
            )
            # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
            if hasattr(AOTAutogradCache, "_get_shape_env"):
                exit_stack.enter_context(
                    patch(
                        "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
                        lambda *args, **kwargs: AlwaysHitShapeEnv(),
                    )
                )

            # Dynamo metrics context, see method for more details.
            exit_stack.enter_context(self.metrics_context())

            from torch._inductor.output_code import CompiledFxGraphConstantsWithGm

            constants = CompiledFxGraphConstantsWithGm(graph)
            inductor_compiled_graph, _ = FxGraphCache._lookup_graph(
                hash_str, example_inputs, True, None, constants
            )
            assert inductor_compiled_graph is not None, (
                "Inductor cache lookup failed. Please remove "
                f"the cache directory and try again."  # noqa
            )

        # Inductor calling convention (function signature):
        # f(list) -> tuple
        # Dynamo calling convention (function signature):
        # f(*args) -> Any

        # need to know if the graph returns a tuple
        from torch._inductor.compile_fx import graph_returns_tuple

        returns_tuple = graph_returns_tuple(graph)

        # this is the callable we return to Dynamo to run
        def compiled_graph(*args: Any) -> tuple[Any, ...] | Any:
            # convert args to list
            list_args = list(args)
            graph_output = inductor_compiled_graph(list_args)
            # unpack the tuple if needed
            if returns_tuple:
                return graph_output
            else:
                return graph_output[0]

        return compiled_graph

    def metrics_context(self) -> contextlib.AbstractContextManager[Any]:
        """
        This method returns the Dynamo metrics context (if it exists,
        otherwise a null context). It is used by various compile components.
        Present in torch>=2.6, it's used inside FxGraphCache in
        torch==2.6 (but not after). It might also be used in various other
        torch.compile internal functions.

        Because it is re-entrant, we always set it (even if entering via Dynamo
        and the context was already entered). We might want to revisit if it
        should be set at a different mode of compilation.

        This is likely a bug in PyTorch: public APIs should not rely on
        manually setting up internal contexts. But we also rely on non-public
        APIs which might not provide these guarantees.
        """
        if is_torch_equal_or_newer("2.6"):
            import torch._dynamo.utils

            return torch._dynamo.utils.get_metrics_context()  # type: ignore[no-any-return]
        else:
            return contextlib.nullcontext()
|
||||
|
||||
|
||||
def set_inductor_config(config: dict[str, Any], compile_range: Range) -> None:
    """Add the autotuning options to `config` for single-size ranges.

    `config` is modified in place. Nothing is changed unless
    `compile_range.is_single_size()` is true.
    """
    if not compile_range.is_single_size():
        return

    # for a specific batch size, tuning triton kernel parameters
    # can be beneficial
    config.update(
        {
            "max_autotune": envs.VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE,
            "coordinate_descent_tuning": (
                envs.VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING
            ),
        }
    )
|
||||
|
||||
|
||||
def set_functorch_config() -> None:
    """Turn off functorch's bundled autograd cache unless mega AOT is used.

    When `envs.VLLM_USE_MEGA_AOT_ARTIFACT` is set, the global
    `torch._functorch.config` is left untouched.
    """
    if envs.VLLM_USE_MEGA_AOT_ARTIFACT:
        return
    torch._functorch.config.bundled_autograd_cache = False
|
||||
|
||||
|
||||
class EagerAdaptor(CompilerInterface):
    """Adaptor that performs no compilation and runs the graph as-is."""

    name = "eager"

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        compiler_config: dict[str, Any],
        compile_range: Range,
        key: str | None = None,
    ) -> tuple[Callable[..., Any] | None, Any | None]:
        """Count the call and return `graph` itself with a None handle.

        `example_inputs`, `compiler_config`, `compile_range` and `key` are
        not used.
        """
        compilation_counter.num_eager_compiles += 1
        # we don't need to compile the graph, just return the graph itself.
        # It does not support caching, return None for the handle.
        return graph, None
|
||||
50
vllm/compilation/counter.py
Normal file
50
vllm/compilation/counter.py
Normal file
@@ -0,0 +1,50 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import copy
|
||||
import dataclasses
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclasses.dataclass
class CompilationCounter:
    """Integer counters describing compilation activity.

    All counters start at zero. `expect` can be used to assert by how much
    selected counters change across a block of code.
    """

    num_models_seen: int = 0
    num_graphs_seen: int = 0
    # piecewise graphs, counting the splitting ops as well
    num_piecewise_graphs_seen: int = 0
    # piecewise graphs, leaving out the splitting ops
    num_piecewise_capturable_graphs_seen: int = 0
    num_backend_compilations: int = 0
    # how often gpu_model_runner attempted to trigger CUDAGraphs capture
    num_gpu_runner_capture_triggers: int = 0
    # how many CUDAGraphs were captured
    num_cudagraph_captured: int = 0
    # calls to the compile method of the Inductor adaptors
    num_inductor_compiles: int = 0
    # calls to EagerAdaptor.compile
    num_eager_compiles: int = 0
    # how many times vLLM's compiler cache entry was updated
    num_cache_entries_updated: int = 0
    # how many standalone_compile compiled artifacts were saved
    num_compiled_artifacts_saved: int = 0
    # how many times a model was loaded with
    # CompilationMode.STOCK_TORCH_COMPILE
    stock_torch_compile_count: int = 0

    def clone(self) -> "CompilationCounter":
        """Return an independent deep copy of this counter."""
        return copy.deepcopy(self)

    @contextmanager
    def expect(self, **kwargs: Any) -> Generator[None, None, None]:
        """Assert counter deltas across the body of a `with` block.

        Each keyword names a counter and gives the exact amount by which it
        must have grown when the block finishes. The check only runs if the
        block exits without raising.
        """
        snapshot = self.clone()
        yield
        for counter_name, expected_delta in kwargs.items():
            after = getattr(self, counter_name)
            before = getattr(snapshot, counter_name)
            assert after - before == expected_delta, (
                f"{counter_name} not as expected, before it is {before}"
                f", after it is {after}, "
                f"expected diff is {expected_delta}"
            )


compilation_counter = CompilationCounter()
|
||||
332
vllm/compilation/cuda_graph.py
Normal file
332
vllm/compilation/cuda_graph.py
Normal file
@@ -0,0 +1,332 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import dataclasses
|
||||
from collections import Counter
|
||||
from collections.abc import Callable
|
||||
from contextlib import ExitStack
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.monitor import validate_cudagraph_capturing_enabled
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id
|
||||
from vllm.forward_context import BatchDescriptor, get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.offloader.base import get_offloader
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import current_stream, weak_ref_tensors
|
||||
from vllm.sequence import IntermediateTensors
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class CUDAGraphStat:
    """One observation collected by `CUDAGraphLogging`.

    The dataclass is frozen, so instances are hashable;
    `CUDAGraphLogging.generate_metric_table` relies on that when it counts
    identical observations with a `Counter`.
    """

    # Rendered under the "Unpadded Tokens" column of the metric table.
    num_unpadded_tokens: int
    # Rendered under the "Padded Tokens" column of the metric table.
    num_padded_tokens: int
    # Rendered under the "Num Paddings" column of the metric table.
    num_paddings: int
    # Rendered under the "Runtime Mode" column; presumably the name of a
    # CUDAGraphMode member -- TODO confirm against the code that builds stats.
    runtime_mode: str
|
||||
|
||||
|
||||
class CUDAGraphLogging:
    """Collect `CUDAGraphStat` observations and render them as a text table."""

    COLUMN_HEADERS = [
        "Unpadded Tokens",
        "Padded Tokens",
        "Num Paddings",
        "Runtime Mode",
        "Count",
    ]

    def __init__(
        self, cg_mode: CUDAGraphMode, cg_capture_sizes: list[int] | None
    ) -> None:
        """Start with no observations and pre-render the settings preamble.

        `cg_mode` and `cg_capture_sizes` are stored as strings; a missing
        capture-size list is shown as an empty list.
        """
        self.reset()
        self.cg_mode = str(cg_mode)
        self.cg_capture_sizes = str(cg_capture_sizes or [])

        preamble_parts = [
            "**CUDAGraph Config Settings:**\n\n",
            f"- Mode: {self.cg_mode}\n",
            f"- Capture sizes: {self.cg_capture_sizes}\n\n",
            "**CUDAGraph Stats:**\n\n",
        ]
        self.settings_header = "".join(preamble_parts)

    def reset(self) -> None:
        """Drop every observation collected so far."""
        self.stats: list[CUDAGraphStat] = []

    def observe(self, cudagraph_stat: CUDAGraphStat) -> None:
        """Record one observation."""
        self.stats.append(cudagraph_stat)

    def generate_metric_table(self) -> str:
        """Render the settings preamble followed by a pipe-delimited table.

        Identical observations are merged into one row with a count, and
        rows are ordered from most to least frequent.
        """
        # most_common() sorts by descending count and keeps first-seen order
        # among equal counts, matching a stable reverse sort on the count.
        table_cells = [
            [
                str(stat.num_unpadded_tokens),
                str(stat.num_padded_tokens),
                str(stat.num_paddings),
                stat.runtime_mode,
                str(occurrences),
            ]
            for stat, occurrences in Counter(self.stats).most_common()
        ]

        # Each column is as wide as its longest cell, header included.
        widths = [
            max([len(title), *(len(cells[col]) for cells in table_cells)])
            for col, title in enumerate(self.COLUMN_HEADERS)
        ]

        padded_titles = (
            title.ljust(width) for title, width in zip(self.COLUMN_HEADERS, widths)
        )
        header_line = "| " + " | ".join(padded_titles) + " |\n"
        separator_line = "|" + "|".join("-" * (width + 2) for width in widths) + "|\n"

        # Left-justify every cell to its column width.
        body_lines = [
            "| "
            + " | ".join(str(cell).ljust(width) for cell, width in zip(cells, widths))
            + " |"
            for cells in table_cells
        ]

        return "".join(
            [
                self.settings_header,
                header_line,
                separator_line,
                "\n".join(body_lines),
                "\n",
            ]
        )

    def log(self, log_fn: Callable[..., Any] = logger.info) -> None:
        """Emit the table through `log_fn` and reset; no-op when empty."""
        if self.stats:
            log_fn(self.generate_metric_table())
            self.reset()
|
||||
|
||||
|
||||
@dataclasses.dataclass
class CUDAGraphEntry:
    """Capture state that `CUDAGraphWrapper` keeps for one batch descriptor."""

    # Key under which this entry is stored in
    # `CUDAGraphWrapper.concrete_cudagraph_entries`.
    batch_descriptor: BatchDescriptor
    # The captured graph; stays None until `CUDAGraphWrapper.__call__`
    # captures one for this descriptor.
    cudagraph: torch.cuda.CUDAGraph | None = None
    # Output stored by the wrapper at capture time (as weak tensor
    # references) and returned after every replay.
    output: Any | None = None

    # for cudagraph debugging, track the input addresses
    # during capture, and check if they are the same during replay
    input_addresses: list[int] | None = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
class CUDAGraphOptions:
    """Per-wrapper switches read by `CUDAGraphWrapper.__call__` during capture."""

    # When True, the wrapper emits a debug log line each time it captures.
    debug_log_enable: bool = True
    # When True, `gc.collect` and `torch.cuda.empty_cache` are patched to
    # no-ops for the duration of the capture.
    gc_disable: bool = False
    # When True, the captured output is converted to weak tensor references
    # inside the capture context before being returned.
    weak_ref_output: bool = True
|
||||
|
||||
|
||||
class CUDAGraphWrapper:
    """Wraps a runnable to add CUDA graph capturing and replaying ability. And
    provide attribute access to the underlying `runnable` via `__getattr__`.

    The workflow of this wrapper in the cudagraph dispatching is as follows:
    1. At initialization, a runtime mode is assigned to the wrapper (FULL or
    PIECEWISE).
    2. At runtime, the wrapper receives a runtime_mode and a
    batch_descriptor(key) from the forward context and blindly trust them
    for cudagraph dispatching.
    3. If runtime_mode is NONE or runtime_mode does not match the mode of the
    wrapper, just call the runnable directly.
    4. Otherwise, i.e., the runtime_mode matches the mode of the wrapper,
    the wrapper will perform cudagraph capture(if key does not exist, create
    a new entry and cache it) or replay (if key exists in the cache).

    Note: CUDAGraphWrapper does not store persistent buffers or copy any
    runtime inputs into those buffers for replay. We assume implementing them
    is done outside of the wrapper. That is because we do not make any
    assumption on the dynamic shape (batch size) of the runtime inputs, as a
    trade-off for staying orthogonal to compilation logic. Nevertheless,
    tracing and checking the input addresses to be consistent during replay is
    guaranteed when VLLM_LOGGING_LEVEL == "DEBUG".
    """

    def __init__(
        self,
        runnable: Callable[..., Any],
        vllm_config: VllmConfig,
        runtime_mode: CUDAGraphMode,
        cudagraph_options: CUDAGraphOptions | None = None,
    ) -> None:
        """Store the runnable and configuration; no graph is captured here.

        `runtime_mode` must not be `CUDAGraphMode.NONE`. When
        `cudagraph_options` is omitted, a default `CUDAGraphOptions()` is used.
        """
        self.runnable = runnable
        self.vllm_config = vllm_config
        self.runtime_mode = runtime_mode
        self.compilation_config = vllm_config.compilation_config

        self.first_run_finished = False
        # Enables the input-address consistency check on replay.
        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"

        # assert runtime_mode is not NONE(no cudagraph), otherwise, we don't
        # need to initialize a CUDAGraphWrapper.
        assert self.runtime_mode != CUDAGraphMode.NONE
        # TODO: in the future, if we want to use multiple
        # streams, it might not be safe to share a global pool.
        # only investigate this when we use multiple streams
        self.graph_pool = current_platform.get_global_graph_pool()

        if cudagraph_options is None:
            cudagraph_options = CUDAGraphOptions()
        self.cudagraph_options = cudagraph_options
        # the entries for different batch descriptors that we need to capture
        # cudagraphs for.
        self.concrete_cudagraph_entries: dict[BatchDescriptor, CUDAGraphEntry] = {}

    def __getattr__(self, key: str) -> Any:
        """Forward lookups that fail on the wrapper to the wrapped runnable.

        Raises `AttributeError` when the runnable has no such attribute.
        """
        # NOTE(review): Python only calls __getattr__ after normal lookup
        # fails, so if `self.runnable` itself is not set yet (e.g. an instance
        # created without running __init__), the access below re-enters this
        # method -- confirm that case cannot occur.
        # allow accessing the attributes of the runnable.
        if hasattr(self.runnable, key):
            return getattr(self.runnable, key)
        raise AttributeError(
            f"Attribute {key} not exists in the runnable of "
            f"cudagraph wrapper: {self.runnable}"
        )

    def unwrap(self) -> Callable[..., Any]:
        """Return the wrapped runnable."""
        # in case we need to access the original runnable.
        return self.runnable

    def weak_ref_tensors_with_intermediate(self, output: Any) -> Any:
        """Convert `output` to weak tensor references.

        An `IntermediateTensors` value is rebuilt as a new
        `IntermediateTensors` whose entries have each been passed through
        `weak_ref_tensors`; any other value is passed to `weak_ref_tensors`
        directly.
        """
        if isinstance(output, IntermediateTensors):
            intermediate_states = IntermediateTensors(
                tensors={key: weak_ref_tensors(value) for key, value in output.tensors.items()})
            return intermediate_states
        return weak_ref_tensors(output)

    def __call__(self, *args: Any, **kwargs: Any) -> Any | None:
        """Run the runnable directly, capture a graph for it, or replay one.

        The decision is driven by the current forward context: its
        `cudagraph_runtime_mode` selects between a direct call and graph
        handling, and its `batch_descriptor` is the cache key for the graph.
        """
        forward_context = get_forward_context()
        batch_descriptor = forward_context.batch_descriptor
        cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode

        if (
            cudagraph_runtime_mode == CUDAGraphMode.NONE
            or cudagraph_runtime_mode != self.runtime_mode
        ):
            # CUDAGraphMode.NONE could mean the profile run, a warmup run, or
            # running without cudagraphs.
            # We do not trigger capture/replay if the runtime mode does not
            # match. This enables properly dispatching to the correct
            # CUDAGraphWrapper when nesting multiple instances with different
            # runtime modes.
            return self.runnable(*args, **kwargs)

        assert batch_descriptor is not None
        if batch_descriptor not in self.concrete_cudagraph_entries:
            # create a new entry for this batch descriptor
            self.concrete_cudagraph_entries[batch_descriptor] = CUDAGraphEntry(
                batch_descriptor=batch_descriptor
            )

        entry = self.concrete_cudagraph_entries[batch_descriptor]

        # No graph yet for this descriptor: capture one on this call.
        if entry.cudagraph is None:
            if self.cudagraph_options.debug_log_enable:
                # Since we capture cudagraph for many different shapes and
                # capturing is fast, we don't need to log it for every
                # shape. E.g. we only log it for the first subgraph in
                # piecewise mode.
                logger.debug(
                    "Capturing a cudagraph on (%s,%s)",
                    self.runtime_mode.name,
                    entry.batch_descriptor,
                )
            # validate that cudagraph capturing is legal at this point.
            validate_cudagraph_capturing_enabled()

            # Only positional tensor arguments are tracked; keyword arguments
            # are not inspected.
            input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            entry.input_addresses = input_addresses
            cudagraph = torch.cuda.CUDAGraph()

            with ExitStack() as stack:
                if self.cudagraph_options.gc_disable:
                    # during every model forward for piecewise cudagraph
                    # mode, we will capture many pieces of cudagraphs
                    # (roughly one per layer). running gc again and again
                    # across layers will make the cudagraph capture very slow.
                    # therefore, we only run gc for the first graph,
                    # and disable gc for the rest of the graphs.
                    stack.enter_context(patch("gc.collect", lambda: None))
                    stack.enter_context(patch("torch.cuda.empty_cache", lambda: None))

                if self.graph_pool is not None:
                    set_graph_pool_id(self.graph_pool)
                else:
                    set_graph_pool_id(current_platform.graph_pool_handle())

                # Sync offloader's copy stream before capture.
                # Ensure any pre-capture prefetches from offloader are complete.
                get_offloader().sync_prev_onload()

                # mind-exploding: carefully manage the reference and memory.
                with torch.cuda.graph(
                    cudagraph,
                    pool=self.graph_pool,
                    stream=current_stream(),
                ):
                    # `output` is managed by pytorch's cudagraph pool
                    output = self.runnable(*args, **kwargs)
                    # Join offloader's copy stream after forward to avoid
                    # unjoined stream error. The last layer's start_prefetch
                    # forks copy_stream, but wait_prefetch only happens in
                    # the next forward pass.
                    get_offloader().join_after_forward()
                    if self.cudagraph_options.weak_ref_output:
                        # by converting it to weak ref,
                        # the original `output` will immediately be released
                        # to save memory. It is only safe to do this for
                        # the last graph in piecewise cudagraph mode, because
                        # the output of the last graph will not be used by
                        # any other cuda graph.
                        # output = weak_ref_tensors(output)
                        output = self.weak_ref_tensors_with_intermediate(output)

            # here we always use weak ref for the output
            # to save memory
            # entry.output = weak_ref_tensors(output)
            entry.output = self.weak_ref_tensors_with_intermediate(output)
            entry.cudagraph = cudagraph

            compilation_counter.num_cudagraph_captured += 1

            # important: we need to return the output, rather than
            # the weak ref of the output, so that pytorch can correctly
            # manage the memory during cuda graph capture
            return output

        # A graph already exists for this descriptor: replay it.
        if self.is_debugging_mode:
            # check if the input addresses are the same
            new_input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            assert new_input_addresses == entry.input_addresses, (
                f"Input addresses for cudagraphs are different "
                f"during replay. Expected {entry.input_addresses}, "
                f"got {new_input_addresses}"
            )

        # Sync offloader before replay - ensures any external dependencies
        # from pre-capture prefetches are satisfied.
        get_offloader().sync_prev_onload()
        entry.cudagraph.replay()
        return entry.output
|
||||
657
vllm/compilation/decorators.py
Normal file
657
vllm/compilation/decorators.py
Normal file
@@ -0,0 +1,657 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import contextlib
|
||||
import hashlib
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
from collections.abc import Callable, Generator
|
||||
from typing import TYPE_CHECKING, Any, TypeVar, overload
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch._dynamo.symbolic_convert import InliningInstructionTranslator
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.wrapper import TorchCompileWithNoGuardsWrapper
|
||||
from vllm.config import (
|
||||
CompilationMode,
|
||||
VllmConfig,
|
||||
get_current_vllm_config,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.config.compilation import DynamicShapesType
|
||||
from vllm.forward_context import get_forward_context, is_forward_context_available
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
from .monitor import start_monitoring_torch_compile
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# Only added on nightly/2.10 so wrap
|
||||
try:
|
||||
from torch._dynamo.package import SourceInfo
|
||||
except ImportError:
|
||||
# Fallback for old versions not supporting
|
||||
SourceInfo = Any
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
IGNORE_COMPILE_KEY = "_ignore_compile_vllm"
|
||||
|
||||
_T = TypeVar("_T", bound=nn.Module)
|
||||
|
||||
|
||||
def ignore_torch_compile(cls: type[_T]) -> type[_T]:
    """
    Class decorator that opts `cls` out of `support_torch_compile`.

    Typical use: a parent class carries the `support_torch_compile`
    decorator, but the subclass `cls` should not be compiled. Only the
    forward of the class this decorator is applied to is affected.

    The opt-out does not propagate in either of these cases:

    - a child that is itself decorated with `support_torch_compile` is still
      compiled even if its parent has `ignore_torch_compile`;
    - submodules of `cls` that are decorated with `support_torch_compile`
      are still compiled.
    """
    # Mark the class; the marker is read back by _should_ignore_torch_compile.
    setattr(cls, IGNORE_COMPILE_KEY, True)
    return cls
|
||||
|
||||
|
||||
def _should_ignore_torch_compile(cls: type[_T]) -> bool:
    """
    Report whether `cls` is marked to be skipped by torch.compile.

    Reads the `IGNORE_COMPILE_KEY` attribute through normal attribute lookup
    and falls back to False when the class has no such attribute.
    """
    ignore_marker = getattr(cls, IGNORE_COMPILE_KEY, False)
    return ignore_marker
|
||||
|
||||
|
||||
# Typing-only overloads of `support_torch_compile`. Each stub describes one
# accepted call shape; the undecorated definition that follows them is the
# one that actually runs.
# NOTE(review): none of these stubs accepts `shape_invariants`, or combines
# `enable_if` with the dims arguments, although the implementation's
# signature takes all of them -- confirm whether that is intentional.


# Factory form configured only with `enable_if`; returns a class decorator.
@overload
def support_torch_compile(
    *,
    enable_if: Callable[[VllmConfig], bool] | None = None,
) -> Callable[[type[_T]], type[_T]]: ...


# Factory form configured only with `dynamic_arg_dims`.
@overload
def support_torch_compile(
    *,
    dynamic_arg_dims: dict[str, int | list[int]] | None,
) -> Callable[[type[_T]], type[_T]]: ...


# Factory form configured only with `mark_unbacked_dims`.
@overload
def support_torch_compile(
    *,
    mark_unbacked_dims: dict[str, int | list[int]] | None,
) -> Callable[[type[_T]], type[_T]]: ...


# Factory form configured with both `dynamic_arg_dims` and
# `mark_unbacked_dims`.
@overload
def support_torch_compile(
    *,
    dynamic_arg_dims: dict[str, int | list[int]] | None,
    mark_unbacked_dims: dict[str, int | list[int]] | None,
) -> Callable[[type[_T]], type[_T]]: ...


# Bare form: applied directly to a class, returns that class type.
@overload
def support_torch_compile(cls: type[_T]) -> type[_T]: ...
|
||||
|
||||
|
||||
def support_torch_compile(
    cls: type[_T] | None = None,
    *,
    dynamic_arg_dims: dict[str, int | list[int]] | None = None,
    mark_unbacked_dims: dict[str, int | list[int]] | None = None,
    enable_if: Callable[[VllmConfig], bool] | None = None,
    shape_invariants: Callable[..., None] = lambda *args, **kwargs: None,
) -> Callable[[type[_T]], type[_T]] | type[_T]:
    """
    Class decorator that adds compilation support for a class's forward method.

    It can be applied bare, without arguments:

    ```python
    @support_torch_compile
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...
    ```

    or called with keyword options, in which case it returns the decorator:

    ```python
    @support_torch_compile(dynamic_arg_dims={"x": 0, "y": 0})
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...
    ```

    `dynamic_arg_dims` maps argument names of `forward` to their dynamic
    dimensions, given either as one integer or as a list of integers.

    When `dynamic_arg_dims` is `None`, the mapping is inferred from the type
    annotations of `forward`:

    - an argument annotated `torch.Tensor` or `Optional[torch.Tensor]` gets
      its first dimension marked as dynamic;
    - an argument annotated `IntermediateTensors` gets the first dimension
      of every tensor it holds marked as dynamic.

    A `ValueError` is raised when the resulting mapping is empty or names an
    argument that `forward` does not have; a `TypeError` is raised when the
    class has no `forward` at all.

    At runtime, marking depends on the actual argument value:

    - a single integer (negative allowed) marks that dimension of the
      argument as dynamic;
    - `None` is ignored;
    - an `IntermediateTensors` has all of its tensors marked as dynamic;
    - anything else raises an error.

    NOTE: an argument that is `None` must stay `None` for the whole lifetime
    of the model, otherwise the model cannot be captured as a single
    computation graph.

    `enable_if` takes a `VllmConfig` and returns whether the model should be
    compiled, which allows compiling only under certain conditions.

    `mark_unbacked_dims` maps argument names with a dynamic dim to be
    decorated with `mark_unbacked`. Use it to keep dynamo from specializing
    on 0/1 values when dummy inputs are used, e.g. for vision model
    compilation.

    `shape_invariants` is a function that gets compiled right before
    forward. It should contain the `torch._check` calls that relate the
    sizes of different inputs to each other, for example:
        torch._check(input_ids.size()[0] == inputs_embeds.size()[0])
    This constrains the symbolic shapes without hardcoding specific values,
    and some models need it to avoid data dependent errors.
    """

    def cls_decorator_helper(cls: type[_T]) -> type[_T]:
        # Resolves `dynamic_arg_dims` here so that `_support_torch_compile`
        # does not need another level of nesting.
        if not hasattr(cls, "forward"):
            raise TypeError("decorated class should have a forward method.")
        forward_params = inspect.signature(cls.forward).parameters

        resolved_dims = dynamic_arg_dims
        if resolved_dims is None:
            # Annotations that make an argument dynamic on dimension 0.
            auto_dynamic_annotations = [
                torch.Tensor,
                torch.Tensor | None,
                IntermediateTensors,
                IntermediateTensors | None,
            ]
            resolved_dims = {
                arg_name: 0
                for arg_name, param in forward_params.items()
                if param.annotation in auto_dynamic_annotations
            }

            logger.debug(
                ("Inferred dynamic dimensions for forward method of %s: %s"),
                cls,
                list(resolved_dims.keys()),
            )

        if not resolved_dims:
            raise ValueError(
                "No dynamic dimensions found in the forward method of "
                f"{cls}. Please provide dynamic_arg_dims explicitly."
            )

        for arg_name in resolved_dims:
            if arg_name in forward_params:
                continue
            raise ValueError(
                f"Argument {arg_name} not found in the forward method of {cls}"
            )

        return _support_torch_compile(
            cls,
            resolved_dims,
            mark_unbacked_dims,
            enable_if,
            shape_invariants,
        )

    if cls is None:
        # called with keyword options: hand back the decorator itself
        return cls_decorator_helper

    # applied bare, directly on the class
    assert isinstance(cls, type)
    return cls_decorator_helper(cls)
|
||||
|
||||
|
||||
def _model_hash_key(fn: Callable[..., Any]) -> str:
    """Return a SHA-256 hex digest identifying `fn` for this vLLM version.

    The digest covers, in order: the vLLM version string, the qualified name
    of `fn`, and the first line number of its code object.
    """
    import vllm

    digest = hashlib.sha256()
    key_parts = (
        vllm.__version__,
        fn.__qualname__,
        str(fn.__code__.co_firstlineno),
    )
    for part in key_parts:
        digest.update(part.encode())
    return digest.hexdigest()
|
||||
|
||||
|
||||
def _verify_source_unchanged(
    source_info: "SourceInfo", vllm_config: VllmConfig
) -> None:
    """Raise if the sources recorded in `source_info` no longer match.

    Every inlined source is resolved to the file of its (already imported)
    module, and that file is also added to
    `vllm_config.compilation_config.traced_files`. A checksum over the
    recorded contents is then compared with a checksum computed from the
    file paths alone (presumably over what is currently on disk -- see
    `_compute_code_hash`). A mismatch raises `RuntimeError`.
    """
    from .caching import _compute_code_hash, _compute_code_hash_with_content

    recorded_by_path = {}
    for inlined in source_info.inlined_sources:
        path = inspect.getfile(sys.modules[inlined.module])
        vllm_config.compilation_config.traced_files.add(path)
        recorded_by_path[path] = inlined.content

    recorded_checksum = _compute_code_hash_with_content(recorded_by_path)
    current_checksum = _compute_code_hash(set(recorded_by_path))
    if recorded_checksum == current_checksum:
        return
    raise RuntimeError(
        "Source code has changed since the last compilation. Recompiling the model."
    )
|
||||
|
||||
|
||||
def _support_torch_compile(
|
||||
cls: type[_T],
|
||||
dynamic_arg_dims: dict[str, int | list[int]],
|
||||
mark_unbacked_dims: dict[str, int | list[int]] | None = None,
|
||||
enable_if: Callable[[VllmConfig], bool] | None = None,
|
||||
shape_invariants: Callable[..., None] = lambda *args, **kwargs: None,
|
||||
) -> type[_T]:
|
||||
"""
|
||||
A decorator to add support for compiling the forward method of a class.
|
||||
"""
|
||||
if TorchCompileWithNoGuardsWrapper in cls.__bases__:
|
||||
# support decorating multiple times
|
||||
return cls
|
||||
|
||||
# take care of method resolution order
|
||||
# make sure super().__init__ is called on the base class
|
||||
# other than TorchCompileWithNoGuardsWrapper
|
||||
cls.__bases__ = cls.__bases__ + (TorchCompileWithNoGuardsWrapper,)
|
||||
|
||||
old_init = cls.__init__
|
||||
|
||||
setattr(cls, IGNORE_COMPILE_KEY, False)
|
||||
|
||||
def __init__(
|
||||
self: _T,
|
||||
*,
|
||||
vllm_config: VllmConfig | None = None,
|
||||
prefix: str = "",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
if vllm_config is None:
|
||||
vllm_config = get_current_vllm_config()
|
||||
|
||||
# NOTE: to support multimodal models (such as encoder),
|
||||
# we may not have vllm_config so we may need to patch
|
||||
# it
|
||||
sig = inspect.signature(old_init)
|
||||
if "vllm_config" in sig.parameters:
|
||||
kwargs["vllm_config"] = vllm_config
|
||||
if "prefix" in sig.parameters:
|
||||
kwargs["prefix"] = prefix
|
||||
old_init(self, **kwargs)
|
||||
|
||||
self.vllm_config = vllm_config
|
||||
self.compilation_config = self.vllm_config.compilation_config
|
||||
enable_compile = enable_if is None or enable_if(vllm_config)
|
||||
# for CompilationMode.STOCK_TORCH_COMPILE , the upper level model runner
|
||||
# will handle the compilation, so we don't need to do anything here.
|
||||
self.do_not_compile = (
|
||||
self.compilation_config.mode
|
||||
in [CompilationMode.NONE, CompilationMode.STOCK_TORCH_COMPILE]
|
||||
or _should_ignore_torch_compile(self.__class__)
|
||||
or not enable_compile
|
||||
)
|
||||
if self.do_not_compile:
|
||||
return
|
||||
|
||||
self._check_shape_invariants = shape_invariants
|
||||
self.was_aot_compile_fn_loaded_from_disk = False
|
||||
compilation_counter.num_models_seen += 1
|
||||
self.compiled = False
|
||||
|
||||
# Handled by monkeypatching `TorchCompileWithNoGuardsWrapper` into base class
|
||||
TorchCompileWithNoGuardsWrapper.__init__(self)
|
||||
|
||||
cls.__init__ = __init__
|
||||
|
||||
def _mark_dynamic_inputs(
|
||||
mod: type[_T], ds_type: DynamicShapesType, *args: Any, **kwargs: Any
|
||||
) -> None:
|
||||
def mark_dynamic(arg: torch.Tensor, dims: list[int]) -> None:
|
||||
if ds_type == DynamicShapesType.UNBACKED:
|
||||
if is_torch_equal_or_newer("2.10.0"):
|
||||
for dim in dims:
|
||||
torch._dynamo.decorators.mark_unbacked(
|
||||
arg, dim, hint_override=arg.size()[dim]
|
||||
)
|
||||
else:
|
||||
torch._dynamo.decorators.mark_unbacked(arg, dims)
|
||||
else:
|
||||
torch._dynamo.mark_dynamic(arg, dims)
|
||||
|
||||
sig = inspect.signature(mod.__class__.forward) # type: ignore[attr-defined]
|
||||
bound_args = sig.bind(mod, *args, **kwargs)
|
||||
bound_args.apply_defaults()
|
||||
for k, dims in dynamic_arg_dims.items():
|
||||
arg = bound_args.arguments.get(k)
|
||||
|
||||
if arg is not None:
|
||||
dims = [dims] if isinstance(dims, int) else dims
|
||||
if isinstance(arg, torch.Tensor):
|
||||
# In case dims is specified with negative indexing
|
||||
dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
|
||||
mark_dynamic(arg, dims)
|
||||
elif isinstance(arg, IntermediateTensors):
|
||||
for tensor in arg.tensors.values():
|
||||
# In case dims is specified with negative indexing
|
||||
dims = [tensor.ndim + dim if dim < 0 else dim for dim in dims]
|
||||
mark_dynamic(tensor, dims)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported dynamic dimensions"
|
||||
f" {dims} for argument {k} with type {type(arg)}."
|
||||
)
|
||||
if mark_unbacked_dims:
|
||||
for k, dims in mark_unbacked_dims.items():
|
||||
arg = bound_args.arguments.get(k)
|
||||
if arg is not None:
|
||||
dims = [dims] if isinstance(dims, int) else dims
|
||||
if isinstance(arg, torch.Tensor):
|
||||
# In case dims is specified with negative indexing
|
||||
dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
|
||||
if is_torch_equal_or_newer("2.10.0"):
|
||||
for dim in dims:
|
||||
torch._dynamo.decorators.mark_unbacked(
|
||||
arg, dim, hint_override=arg.size()[dim]
|
||||
)
|
||||
else:
|
||||
torch._dynamo.decorators.mark_unbacked(arg, dims)
|
||||
|
||||
def __call__(self: type[_T], *args: Any, **kwargs: Any) -> Any:
|
||||
# torch.compiler.is_compiling() means we are inside the compilation
|
||||
# e.g. TPU has the compilation logic in model runner, so we don't
|
||||
# need to compile the model inside.
|
||||
if self.do_not_compile or torch.compiler.is_compiling():
|
||||
return self.forward(*args, **kwargs)
|
||||
|
||||
# If skip_compiled is set, bypass compiled model call. This is used e.g. for
|
||||
# enc-dec models where tensor shapes/types vary across invocations, preventing
|
||||
# the capture of a single computational graph.
|
||||
if is_forward_context_available() and get_forward_context().skip_compiled:
|
||||
return self.forward(*args, **kwargs)
|
||||
|
||||
# if aot_compiled_fn is set, call it with partition wrapper context.
|
||||
# The partition wrapper must be active at runtime for CUDA graph
|
||||
# capture to work correctly with inductor graph partitioning.
|
||||
if getattr(self, "aot_compiled_fn", None) is not None:
|
||||
with maybe_use_cudagraph_partition_wrapper(self.vllm_config):
|
||||
return self.aot_compiled_fn(self, *args, **kwargs)
|
||||
|
||||
ds_type = self.compilation_config.dynamic_shapes_config.type
|
||||
cache_dir = None
|
||||
aot_compilation_path = None
|
||||
if envs.VLLM_USE_AOT_COMPILE:
|
||||
"""
|
||||
When using torch.compile in AOT mode, we store the cache artifacts
|
||||
under VLLM_CACHE_ROOT/torch_compile_cache/torch_aot_compile/{hash}
|
||||
The {hash} contains all of the factors except for the source files
|
||||
being traced through, because we don't actually know which source
|
||||
files to check at this point (before dynamo runs).
|
||||
On loading we will actually look at the source files being traced
|
||||
through. If any source file have changed (compared with the
|
||||
serialized backend artifacts), then we need to generate a new AOT
|
||||
compile artifact from scratch.
|
||||
"""
|
||||
from .caching import aot_compile_hash_factors
|
||||
|
||||
factors: list[str] = aot_compile_hash_factors(self.vllm_config)
|
||||
|
||||
factors.append(_model_hash_key(self.forward))
|
||||
hash_key = hashlib.sha256(str(factors).encode()).hexdigest()
|
||||
cache_dir = os.path.join(
|
||||
envs.VLLM_CACHE_ROOT,
|
||||
"torch_compile_cache",
|
||||
"torch_aot_compile",
|
||||
hash_key,
|
||||
)
|
||||
|
||||
rank = self.vllm_config.parallel_config.rank
|
||||
dp_rank = self.vllm_config.parallel_config.data_parallel_index
|
||||
cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")
|
||||
aot_compilation_path = os.path.join(cache_dir, "model")
|
||||
try:
|
||||
with (
|
||||
set_current_vllm_config(self.vllm_config),
|
||||
open(aot_compilation_path, "rb") as f,
|
||||
):
|
||||
start_monitoring_torch_compile(self.vllm_config)
|
||||
loaded_fn = torch.compiler.load_compiled_function(
|
||||
f, f_globals=self.forward.__globals__
|
||||
)
|
||||
_verify_source_unchanged(loaded_fn.source_info(), self.vllm_config)
|
||||
if not self.compilation_config.dynamic_shapes_config.evaluate_guards:
|
||||
loaded_fn.disable_guard_check()
|
||||
self.aot_compiled_fn = loaded_fn
|
||||
self.was_aot_compile_fn_loaded_from_disk = True
|
||||
except Exception as e:
|
||||
if os.path.exists(aot_compilation_path):
|
||||
if isinstance(e, EOFError):
|
||||
message = "Compile cache file corrupted."
|
||||
else:
|
||||
message = str(e)
|
||||
logger.warning(
|
||||
"Compiling model again due to a load failure from %s, "
|
||||
"reason: %s",
|
||||
aot_compilation_path,
|
||||
message,
|
||||
)
|
||||
if envs.VLLM_FORCE_AOT_LOAD:
|
||||
raise e
|
||||
if getattr(self, "aot_compiled_fn", None) is not None:
|
||||
logger.info(
|
||||
"Directly load AOT compilation from path %s", aot_compilation_path
|
||||
)
|
||||
# Apply partition wrapper context for proper CUDA graph capture
|
||||
with maybe_use_cudagraph_partition_wrapper(self.vllm_config):
|
||||
return self.aot_compiled_fn(self, *args, **kwargs)
|
||||
|
||||
if self.compiled:
|
||||
assert (
|
||||
not envs.VLLM_USE_AOT_COMPILE
|
||||
or self.vllm_config.compilation_config.backend == "eager"
|
||||
)
|
||||
return TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs) # type: ignore[arg-type]
|
||||
|
||||
# This is the path for the first compilation.
|
||||
# the first compilation needs to have dynamic shapes marked
|
||||
_mark_dynamic_inputs(
|
||||
self,
|
||||
ds_type,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# here, it is the starting point of the `torch.compile` process
|
||||
start_monitoring_torch_compile(self.vllm_config)
|
||||
original_code_object = self.original_code_object()
|
||||
logger.debug("Start compiling function %s", original_code_object)
|
||||
|
||||
# we do not want tp delete the original code object entries since
|
||||
# we depend on them now to look up cached compiled functions.
|
||||
# torch._dynamo.eval_frame.remove_from_cache(original_code_object)
|
||||
|
||||
# collect all relevant files traced by Dynamo,
|
||||
# so that the compilation cache can trigger re-compilation
|
||||
# properly when any of these files change.
|
||||
|
||||
# 1. the file containing the top-level forward function
|
||||
self.compilation_config.traced_files.add(original_code_object.co_filename)
|
||||
|
||||
# 2. every time Dynamo sees a function call, it will inline
|
||||
# the function by calling InliningInstructionTranslator.inline_call_
|
||||
# we hijack this function to know all the functions called
|
||||
# during Dynamo tracing, and their corresponding files
|
||||
inline_call = InliningInstructionTranslator.inline_call_
|
||||
|
||||
def patched_inline_call(self_: Any) -> Any:
|
||||
code = self_.f_code
|
||||
self.compilation_config.traced_files.add(code.co_filename)
|
||||
return inline_call(self_)
|
||||
|
||||
# Disable the C++ compilation of symbolic shape guards. C++-fication
|
||||
# of symbolic shape guards can improve guard overhead. But, since
|
||||
# vllm skip guards anyways, setting this flag to False can improve
|
||||
# compile time.
|
||||
dynamo_config_patches = {}
|
||||
try:
|
||||
_ = torch._dynamo.config.enable_cpp_symbolic_shape_guards
|
||||
dynamo_config_patches["enable_cpp_symbolic_shape_guards"] = False
|
||||
except AttributeError:
|
||||
# Note: this config is not available in torch 2.6, we can skip
|
||||
# if the config doesn't exist
|
||||
logger.debug("enable_cpp_symbolic_shape_guards config not available")
|
||||
|
||||
# Prepare backed_size_oblivious config patch if needed
|
||||
fx_config_patches = {}
|
||||
if ds_type == DynamicShapesType.BACKED_SIZE_OBLIVIOUS:
|
||||
fx_config_patches["backed_size_oblivious"] = True
|
||||
|
||||
# Prepare inductor config patches
|
||||
# assume_32bit_indexing is only available in torch 2.10.0+
|
||||
inductor_config_patches = {}
|
||||
if is_torch_equal_or_newer("2.10.0"):
|
||||
inductor_config_patches["assume_32bit_indexing"] = (
|
||||
self.compilation_config.dynamic_shapes_config.assume_32_bit_indexing
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
InliningInstructionTranslator, "inline_call_", patched_inline_call
|
||||
),
|
||||
torch._dynamo.config.patch(**dynamo_config_patches),
|
||||
maybe_use_cudagraph_partition_wrapper(self.vllm_config),
|
||||
torch.fx.experimental._config.patch(**fx_config_patches),
|
||||
torch._inductor.config.patch(**inductor_config_patches),
|
||||
):
|
||||
use_aot_compile = envs.VLLM_USE_AOT_COMPILE
|
||||
if self.vllm_config.compilation_config.backend == "eager":
|
||||
logger.warning("Detected eager backend, disabling AOT compile.")
|
||||
use_aot_compile = False
|
||||
if use_aot_compile:
|
||||
from vllm.compilation.backends import set_on_compilation_complete
|
||||
|
||||
# store the path for saving after warmup
|
||||
self._aot_compilation_path = aot_compilation_path
|
||||
self._aot_cache_dir = cache_dir
|
||||
# set callback in context so it's available when compilation completes
|
||||
with set_on_compilation_complete(self.save_aot_compiled_function):
|
||||
self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
|
||||
output = self.aot_compiled_fn(self, *args, **kwargs)
|
||||
else:
|
||||
output = TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs) # type: ignore[arg-type]
|
||||
|
||||
self.compiled = True
|
||||
return output
|
||||
|
||||
# triggers VllmSerializableFunction.serialize()
|
||||
def save_aot_compiled_function(self: type[_T]) -> None:
|
||||
if self.was_aot_compile_fn_loaded_from_disk:
|
||||
logger.debug("AOT compiled function was loaded from cache, skipping save")
|
||||
return
|
||||
|
||||
assert (
|
||||
self.aot_compiled_fn and self._aot_compilation_path and self._aot_cache_dir
|
||||
)
|
||||
|
||||
logger.info("saving AOT compiled function to %s", self._aot_compilation_path)
|
||||
try:
|
||||
os.makedirs(self._aot_cache_dir, exist_ok=True)
|
||||
# File saving should be atomic, so we will save to a temporary location
|
||||
# first. Should be upstreamed to PyTorch 2.12 as well.
|
||||
tmp_file = f"{self._aot_compilation_path}.{os.getpid()}.tmp"
|
||||
self.aot_compiled_fn.save_compiled_function(tmp_file)
|
||||
os.replace(tmp_file, self._aot_compilation_path)
|
||||
logger.info("saved AOT compiled function to %s", self._aot_compilation_path)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"unable to save AOT compiled function to %s: %s",
|
||||
self._aot_compilation_path,
|
||||
e,
|
||||
)
|
||||
|
||||
cls.__call__ = __call__
|
||||
cls.save_aot_compiled_function = save_aot_compiled_function
|
||||
return cls
|
||||
|
||||
|
||||
@contextlib.contextmanager
def maybe_use_cudagraph_partition_wrapper(
    vllm_config: VllmConfig,
) -> Generator[None, None, None]:
    """
    Context manager to set/unset customized cudagraph partition wrappers.

    If we're using Inductor-based graph partitioning, we currently have the
    whole `fx.Graph` before Inductor lowering and the piecewise
    splitting happens after all graph passes and fusions. Here, we add
    a custom hook for Inductor to wrap each partition with our static
    graph wrapper class to maintain more control over static graph
    capture and replay.

    The hook is installed only when the compilation config reports
    piecewise cudagraphs and `use_inductor_graph_partition` is set. When it
    was installed, it is reset to `None` on exit, including when the code
    running inside the `with` block raises.

    Args:
        vllm_config: Config whose `compilation_config` decides whether the
            hook is installed; it is also passed to every wrapper instance.
    """
    from vllm.config import CUDAGraphMode

    compilation_config = vllm_config.compilation_config
    # Evaluate the condition once so that installing and clearing the hook
    # always stay paired, even if the config is touched while we are suspended.
    use_partition_wrapper = (
        compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
        and compilation_config.use_inductor_graph_partition
    )
    if not use_partition_wrapper:
        yield
        return

    from torch._inductor.utils import CUDAGraphWrapperMetadata

    from vllm.compilation.cuda_graph import CUDAGraphOptions
    from vllm.platforms import current_platform

    static_graph_wrapper_class = resolve_obj_by_qualname(
        current_platform.get_static_graph_wrapper_cls()
    )

    def customized_cudagraph_wrapper(
        f: Callable[..., Any], metadata: CUDAGraphWrapperMetadata
    ) -> Any:
        # Wrap one Inductor partition; the options depend on where the
        # partition sits in the sequence (first / not first / last).
        partition_id = metadata.partition_index
        num_partitions = metadata.num_partitions
        return static_graph_wrapper_class(
            runnable=f,
            vllm_config=vllm_config,
            runtime_mode=CUDAGraphMode.PIECEWISE,
            cudagraph_options=CUDAGraphOptions(
                debug_log_enable=partition_id == 0,
                gc_disable=partition_id != 0,
                weak_ref_output=partition_id == num_partitions - 1,
            ),
        )

    torch._inductor.utils.set_customized_partition_wrappers(
        customized_cudagraph_wrapper
    )
    try:
        yield
    finally:
        # Without the finally, an exception raised inside the `with` body
        # would leave the hook registered for later, unrelated compilations.
        torch._inductor.utils.set_customized_partition_wrappers(None)
|
||||
63
vllm/compilation/monitor.py
Normal file
63
vllm/compilation/monitor.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import time
|
||||
|
||||
from vllm.config import CompilationConfig, CompilationMode, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
context_manager = None
|
||||
torch_compile_start_time: float = 0.0
|
||||
|
||||
|
||||
def start_monitoring_torch_compile(vllm_config: VllmConfig) -> None:
    """Record the compile start time and optionally start a depyf debug dump.

    The start timestamp is always stored in the module-level
    `torch_compile_start_time`. When the compilation mode is
    `CompilationMode.VLLM_COMPILE` and `compile_debug_dump_path()` returns a
    truthy path, the directory is created and a `depyf.prepare_debug`
    context is entered and kept in the module-level `context_manager`.
    """
    global torch_compile_start_time, context_manager

    torch_compile_start_time = time.perf_counter()

    compilation_config = vllm_config.compilation_config
    dump_dir = vllm_config.compile_debug_dump_path()
    if compilation_config.mode != CompilationMode.VLLM_COMPILE or not dump_dir:
        return

    # depyf is only needed for the debug dump, so import it lazily.
    import depyf

    dump_dir.mkdir(parents=True, exist_ok=True)
    logger.debug("Dumping depyf output to %s", dump_dir)
    # Entered by hand; the context is kept in the module global so it can be
    # exited after this function has returned.
    context_manager = depyf.prepare_debug(dump_dir.as_posix())
    context_manager.__enter__()
|
||||
|
||||
|
||||
def end_monitoring_torch_compile(vllm_config: VllmConfig) -> None:
    """Log the total compile time and close the depyf context, if one is open.

    Does nothing unless the compilation mode is
    `CompilationMode.VLLM_COMPILE`. The elapsed time is measured against the
    module-level `torch_compile_start_time`.
    """
    global context_manager

    mode = vllm_config.compilation_config.mode
    total_compile_time: float = time.perf_counter() - torch_compile_start_time
    if mode != CompilationMode.VLLM_COMPILE:
        return

    logger.info_once(
        "torch.compile takes %.2f s in total",
        total_compile_time,
        scope="local",
    )

    active_context = context_manager
    if active_context is None:
        return
    # Mirror of the manual __enter__ in start_monitoring_torch_compile.
    active_context.__exit__(None, None, None)
    context_manager = None
|
||||
|
||||
|
||||
# Module-level switch read by validate_cudagraph_capturing_enabled() and
# written by set_cudagraph_capturing_enabled().
cudagraph_capturing_enabled: bool = True


def validate_cudagraph_capturing_enabled() -> None:
    """Raise if CUDA graph capturing is currently switched off.

    Used to monitor whether a cudagraph capturing is legal at runtime;
    should be called before any cudagraph capturing.

    Raises:
        RuntimeError: If `cudagraph_capturing_enabled` is falsy.
    """
    if cudagraph_capturing_enabled:
        return
    raise RuntimeError(
        "CUDA graph capturing detected at an inappropriate "
        "time. This operation is currently disabled."
    )


def set_cudagraph_capturing_enabled(enabled: bool) -> None:
    """Set the module-level `cudagraph_capturing_enabled` switch to `enabled`."""
    global cudagraph_capturing_enabled
    cudagraph_capturing_enabled = enabled
|
||||
75
vllm/compilation/partition_rules.py
Normal file
75
vllm/compilation/partition_rules.py
Normal file
@@ -0,0 +1,75 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import contextlib
|
||||
from collections.abc import Generator
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def should_split(node: torch.fx.Node, splitting_ops: list[str]) -> bool:
    """
    Check if a node should be split for dynamo graph partition.
    It operates on dynamo graph, so the node.target can be anything.
    We need to check and split only on OpOverload and OpOverloadPacket.

    An entry of `splitting_ops` may name a whole overload packet
    (e.g. "aten::add") or one specific overload (e.g. "aten::add.Tensor",
    "aten::relu.default").

    Args:
        node: The FX node to inspect.
        splitting_ops: Qualified operator names to split on.

    Returns:
        True if `node` is a `call_function` node whose target is an
        OpOverloadPacket or OpOverload named in `splitting_ops`.
    """
    if node.op != "call_function":
        return False

    target = node.target

    if isinstance(target, torch._ops.OpOverloadPacket):
        # Example: "aten::add"
        return target._qualified_op_name in splitting_ops

    if isinstance(target, torch._ops.OpOverload):
        # Example: "aten::add". Read from the owning packet, because
        # OpOverload.name() already includes the overload suffix for
        # non-default overloads ("aten::add.Tensor"), which made packet-level
        # entries miss those overloads.
        packet_name = target.overloadpacket._qualified_op_name

        # Example: "aten::add.Tensor" or "aten::relu.default"
        op_overload_name = f"{packet_name}.{target._overloadname}"
        return op_overload_name in splitting_ops or packet_name in splitting_ops

    return False
|
||||
|
||||
|
||||
@contextlib.contextmanager
def inductor_partition_rule_context(
    splitting_ops: list[str] | None,
) -> Generator[None, None, None]:
    """Temporarily install `splitting_ops` as Inductor's custom partition ops.

    While the context is active,
    `torch._inductor.config.custom_should_partition_ops` is set to
    `splitting_ops`; on exit the previously configured list is put back,
    also when the body raises. With an empty or missing `splitting_ops`
    the config is left untouched.

    Args:
        splitting_ops: List of operator names to partition on.
    """
    if splitting_ops:
        inductor_config = torch._inductor.config

        # Keep a copy of the current setting so it can be restored afterwards.
        previous_ops: list[str] = list(inductor_config.custom_should_partition_ops)
        inductor_config.custom_should_partition_ops = splitting_ops

        logger.debug(
            "Registered inductor partition rules for %d operators", len(splitting_ops)
        )

        try:
            yield
        finally:
            inductor_config.custom_should_partition_ops = previous_ops
            logger.debug("Restored previous partition rules state.")
    else:
        logger.debug("No partition ops provided; skipping rule registration.")
        yield
|
||||
0
vllm/compilation/passes/__init__.py
Normal file
0
vllm/compilation/passes/__init__.py
Normal file
0
vllm/compilation/passes/fusion/__init__.py
Normal file
0
vllm/compilation/passes/fusion/__init__.py
Normal file
215
vllm/compilation/passes/fusion/act_quant_fusion.py
Normal file
215
vllm/compilation/passes/fusion/act_quant_fusion.py
Normal file
@@ -0,0 +1,215 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._inductor.pattern_matcher import (
|
||||
PatternMatcherPass,
|
||||
fwd_only,
|
||||
register_replacement,
|
||||
)
|
||||
from torch._ops import OpOverload
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kFp8StaticTensorSym,
|
||||
kNvfp4Dynamic,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .matcher_utils import MatcherQuantFP8, MatcherSiluAndMul
|
||||
from .rms_quant_fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
FP4_DTYPE = torch.uint8
|
||||
|
||||
SILU_MUL_OP = torch.ops._C.silu_and_mul.default
|
||||
|
||||
FUSED_OPS: dict[QuantKey, OpOverload] = {
|
||||
kFp8StaticTensorSym: torch.ops._C.silu_and_mul_quant.default, # noqa: E501
|
||||
}
|
||||
silu_and_mul_nvfp4_quant_supported = current_platform.is_cuda() and hasattr(
|
||||
torch.ops._C, "silu_and_mul_nvfp4_quant"
|
||||
)
|
||||
if silu_and_mul_nvfp4_quant_supported:
|
||||
FUSED_OPS[kNvfp4Dynamic] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501
|
||||
|
||||
|
||||
class ActivationQuantPattern(ABC):
    """
    The base class for Activation+Quant fusions.
    Should not be used directly.

    Looks up the unfused quant op (`QUANT_OPS`) and the fused op
    (`FUSED_OPS`) for the given quant key and creates the SiLU-and-mul
    matcher shared by subclasses.
    """

    def __init__(
        self,
        quant_key: QuantKey,
    ) -> None:
        self.quant_key = quant_key
        self.quant_dtype = quant_key.dtype

        # Both lookup tables must know the key before anything is resolved.
        assert self.quant_key in QUANT_OPS, (
            f"unsupported quantization scheme {self.quant_key}"
        )
        assert self.quant_key in FUSED_OPS, (
            f"unsupported fusion scheme {self.quant_key}"
        )
        self.QUANT_OP = QUANT_OPS[self.quant_key]
        self.FUSED_OP = FUSED_OPS[self.quant_key]

        self.silu_and_mul_matcher = MatcherSiluAndMul()

    def empty_quant(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Return `torch.empty(*args, **kwargs)`, defaulting `dtype` to the
        quant dtype and `device` to "cuda" unless the caller passes them."""
        kwargs.setdefault("dtype", self.quant_dtype)
        kwargs.setdefault("device", "cuda")
        return torch.empty(*args, **kwargs)

    @abstractmethod
    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register this pattern and its replacement on `pm_pass`."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
    """
    Fusion for SiluMul+Fp8StaticQuant Pattern

    Rewrites SiLU-and-mul followed by static FP8 quantization into a single
    call of the fused op selected by the base class.
    """

    def __init__(self) -> None:
        super().__init__(kFp8StaticTensorSym)
        self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)

    def get_inputs(self) -> list[torch.Tensor]:
        """Example inputs for tracing: the SiLU-and-mul matcher inputs
        followed by the scale taken from the quant matcher inputs."""
        quant_scale = self.quant_matcher.inputs()[1]
        activation_inputs = self.silu_and_mul_matcher.inputs()  # input
        return [*activation_inputs, quant_scale]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the unfused pattern and its fused replacement on `pm_pass`."""

        def pattern(
            input: torch.Tensor,
            scale: torch.Tensor,
        ) -> torch.Tensor:
            activated = self.silu_and_mul_matcher(input)
            quantized = self.quant_matcher(activated, scale)
            return quantized[0]

        def replacement(
            input: torch.Tensor,
            scale: torch.Tensor,
        ) -> torch.Tensor:
            # The output keeps every dim of the input except the last, which
            # is halved.
            half_width = input.shape[-1] // 2
            fused_shape = input.shape[:-1] + (half_width,)
            out_buffer = torch.empty(
                fused_shape, device=input.device, dtype=self.quant_dtype
            )
            fused = auto_functionalized(
                self.FUSED_OP, result=out_buffer, input=input, scale=scale
            )
            return fused[1]

        example_inputs = self.get_inputs()
        # Call the pattern once on the example inputs before registering it.
        pattern(*example_inputs)

        register_replacement(pattern, replacement, example_inputs, fwd_only, pm_pass)
|
||||
|
||||
|
||||
class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
    """
    Fusion for SiluMul+Nvfp4Quant Pattern

    Rewrites SiLU-and-mul followed by the NVFP4 quant op into a single call
    of the fused op selected by the base class.
    """

    def __init__(self) -> None:
        super().__init__(kNvfp4Dynamic)

    def get_inputs(self) -> list[torch.Tensor]:
        """Example inputs for tracing, in the argument order of `pattern`:
        result, output_scale, input, scale."""
        return [
            self.empty_quant(5, 32),  # result
            empty_i32(128, 4),  # output_scale
            empty_bf16(5, 64),  # input
            empty_fp32(1, 1),  # scale
        ]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the unfused pattern and its fused replacement on `pm_pass`."""

        def pattern(
            result: torch.Tensor,
            output_scale: torch.Tensor,
            input: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            activated = self.silu_and_mul_matcher(input)
            quantized = auto_functionalized(
                self.QUANT_OP,
                output=result,
                input=activated,
                output_scale=output_scale,
                input_scale=scale,
                is_sf_swizzled_layout=True,
            )
            return quantized[1], quantized[2]

        def replacement(
            result: torch.Tensor,
            output_scale: torch.Tensor,
            input: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            fused = auto_functionalized(
                self.FUSED_OP,
                result=result,
                result_block_scale=output_scale,
                input=input,
                input_global_scale=scale,
            )
            return fused[1], fused[2]

        register_replacement(pattern, replacement, self.get_inputs(), fwd_only, pm_pass)
|
||||
|
||||
|
||||
class ActivationQuantFusionPass(VllmPatternMatcherPass):
    """
    This pass fuses a pre-defined set of custom ops into fused ops.
    It uses the torch pattern matcher to find the patterns and replace them.

    Because patterns can only be registered once, the pass is a singleton.
    This will be addressed in a future version of PyTorch:
    https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="activation_quant_fusion_pass"
        )

        # The FP8 static pattern is always registered.
        SiluMulFp8StaticQuantPattern().register(self.patterns)

        # The NVFP4 pattern is registered only when the module-level support
        # flag is set.
        if silu_and_mul_nvfp4_quant_supported:
            SiluMulNvfp4QuantPattern().register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: torch.fx.Graph) -> None:
        """Apply the registered patterns to `graph` and store the match count."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def uuid(self) -> str:
        """Return `hash_source` over this pass and its pattern classes."""
        hashed_sources = (
            self,
            ActivationQuantPattern,
            SiluMulFp8StaticQuantPattern,
            SiluMulNvfp4QuantPattern,
        )
        return VllmInductorPass.hash_source(*hashed_sources)
|
||||
862
vllm/compilation/passes/fusion/allreduce_rms_fusion.py
Normal file
862
vllm/compilation/passes/fusion/allreduce_rms_fusion.py
Normal file
@@ -0,0 +1,862 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
from importlib.util import find_spec
|
||||
from types import ModuleType
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
import torch.fx as fx
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.utils import Range
|
||||
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import (
|
||||
direct_register_custom_op,
|
||||
)
|
||||
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
flashinfer_comm: ModuleType | None = None
|
||||
if find_spec("flashinfer"):
|
||||
try:
|
||||
import flashinfer.comm as _flashinfer_comm
|
||||
|
||||
if hasattr(_flashinfer_comm, "allreduce_fusion") and hasattr(
|
||||
_flashinfer_comm, "create_allreduce_fusion_workspace"
|
||||
):
|
||||
flashinfer_comm = _flashinfer_comm
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if hasattr(torch.ops._C, "scaled_fp4_quant"):
|
||||
STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default
|
||||
|
||||
# Max size of the input tensor per world size per device capability
|
||||
# to use flashinfer fused allreduce
|
||||
FI_ALLREDUCE_FUSION_MAX_SIZE_MB: dict[int, dict[int, float]] = {
|
||||
90: {
|
||||
2: 64, # 64MB
|
||||
4: 2, # 2MB
|
||||
8: 0.5, # 0.5MB
|
||||
},
|
||||
100: {
|
||||
2: 64, # 64MB
|
||||
4: 32, # 32MB
|
||||
8: 1, # 1MB
|
||||
},
|
||||
}
|
||||
|
||||
# Max size of the input tensor per world size per device capability
|
||||
# to use flashinfer one shot fused allreduce
|
||||
# OneShot max size is at most 64MB / world size (FlashInfer restriction)
|
||||
_FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB: dict[int, dict[int, float]] = {
|
||||
90: {
|
||||
2: 32, # 32MB
|
||||
4: 2, # 2MB
|
||||
8: 0.5, # 0.5MB
|
||||
},
|
||||
100: {
|
||||
2: 32, # 32MB
|
||||
4: 4, # 4MB
|
||||
8: 1, # 1MB
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
if flashinfer_comm is not None:
    from vllm.distributed.device_communicators.flashinfer_all_reduce import (
        destroy_fi_ar_workspace,
        get_fi_ar_quant_workspace,
        get_fi_ar_workspace,
        initialize_fi_ar_quant_workspace,
        initialize_fi_ar_workspace,
    )

    ar_fusion_patterns = flashinfer_comm.AllReduceFusionPattern

    MiB = 1024 * 1024

    def call_trtllm_fused_allreduce_norm(
        allreduce_in: torch.Tensor,
        residual: torch.Tensor,
        rms_gamma: torch.Tensor,
        rms_eps: float,
        world_size: int,
        launch_with_pdl: bool,
        fp32_acc: bool,
        max_token_num: int,
        pattern_code: int,
        norm_out: torch.Tensor | None = None,
        quant_out: torch.Tensor | None = None,
        scale_out: torch.Tensor | None = None,
        scale_factor: torch.Tensor | None = None,
    ) -> None:
        """Forward a fused allreduce + norm call to `flashinfer_comm.allreduce_fusion`.

        Registered below as the implementation of the
        `flashinfer_trtllm_fused_allreduce_norm` custom op. Results are
        written into the tensor arguments; nothing is returned.

        `allreduce_in` is unpacked as a 2-D `(num_tokens, hidden_size)`
        tensor, and its byte size must not exceed what `max_token_num`
        tokens would occupy. When `norm_out` is None, `allreduce_in` is
        passed as the norm output and `residual` as the residual output;
        otherwise `allreduce_in` is passed as the residual output.
        """
        num_tokens, hidden_size = allreduce_in.shape
        element_size = allreduce_in.element_size()
        current_tensor_size = num_tokens * hidden_size * element_size
        max_tensor_size = max_token_num * hidden_size * element_size
        assert current_tensor_size <= max_tensor_size, (
            f"Current tensor size {current_tensor_size} is larger than "
            f"max token num {max_token_num} * hidden size {hidden_size} * "
            f"element size {element_size}"
        )

        capability = current_platform.get_device_capability()
        capability_code = capability.to_int() if capability is not None else None
        # Get one shot input size limit for the current world size
        # for the current device capability
        limits_for_device = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get(
            capability_code,  # type: ignore[arg-type, unused-ignore]
            {},
        )
        one_shot_limit_mb = limits_for_device.get(world_size, None)
        if one_shot_limit_mb is None:
            # Use one shot if no max size is specified
            use_oneshot = True
        else:
            use_oneshot = current_tensor_size <= one_shot_limit_mb * MiB

        # Select workspace based on pattern: quant patterns use the
        # trtllm quant workspace, non-quant patterns use the primary workspace.
        quant_pattern_codes = (
            ar_fusion_patterns.kARResidualRMSNormFP8Quant,
            ar_fusion_patterns.kARResidualRMSNormFP4Quant,
        )
        if pattern_code in quant_pattern_codes:
            workspace = get_fi_ar_quant_workspace()
        else:
            workspace = get_fi_ar_workspace()
        assert workspace is not None, (
            "Flashinfer workspace must be initialized when using flashinfer"
        )
        assert flashinfer_comm is not None

        # With a separate norm_out: return residual_out as allreduce_out with
        # zeroed residual_in, as flashinfer does not support rms_norm
        # and allreduce_out together.
        residual_out = residual if norm_out is None else allreduce_in
        if norm_out is None:
            norm_out = allreduce_in

        # layout_code only supported by trtllm backend;
        # in vllm we only support swizzled layout
        layout_code = (
            flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4
            if workspace.backend == "trtllm"
            else None
        )

        flashinfer_comm.allreduce_fusion(
            input=allreduce_in,
            workspace=workspace,
            pattern=pattern_code,
            launch_with_pdl=launch_with_pdl,
            output=None,
            residual_out=residual_out,
            norm_out=norm_out,
            quant_out=quant_out,
            scale_out=scale_out,
            residual_in=residual,
            rms_gamma=rms_gamma,
            rms_eps=rms_eps,
            scale_factor=scale_factor,
            layout_code=layout_code,
            use_oneshot=use_oneshot,
            fp32_acc=fp32_acc,
        )

    def call_trtllm_fused_allreduce_norm_fake(
        allreduce_in: torch.Tensor,
        residual: torch.Tensor,
        rms_gamma: torch.Tensor,
        rms_eps: float,
        world_size: int,
        launch_with_pdl: bool,
        fp32_acc: bool,
        max_token_num: int,
        pattern_code: int,
        norm_out: torch.Tensor | None = None,
        quant_out: torch.Tensor | None = None,
        scale_out: torch.Tensor | None = None,
        scale_factor: torch.Tensor | None = None,
    ) -> None:
        """Fake implementation of the custom op: same signature, does nothing."""
        return None

    direct_register_custom_op(
        op_name="flashinfer_trtllm_fused_allreduce_norm",
        op_func=call_trtllm_fused_allreduce_norm,
        mutates_args=[
            "allreduce_in",
            "residual",
            "norm_out",
            "quant_out",
            "scale_out",
        ],
        fake_impl=call_trtllm_fused_allreduce_norm_fake,
    )
    flashinfer_trtllm_fused_allreduce_norm = (
        torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default
    )
|
||||
|
||||
|
||||
class FlashInferFusedAllReduceParams:
|
||||
"""Parameters for FlashInfer fused allreduce operations."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
world_size: int,
|
||||
max_token_num: int = 1024,
|
||||
) -> None:
|
||||
self.world_size = world_size
|
||||
self.launch_with_pdl = True
|
||||
self.fp32_acc = True
|
||||
self.max_token_num = max_token_num
|
||||
|
||||
def get_trtllm_fused_allreduce_kwargs(self) -> dict[str, bool | int]:
|
||||
return {
|
||||
"world_size": self.world_size,
|
||||
"launch_with_pdl": self.launch_with_pdl,
|
||||
"fp32_acc": self.fp32_acc,
|
||||
"max_token_num": self.max_token_num,
|
||||
}
|
||||
|
||||
|
||||
# TODO(luka): unify
|
||||
class BasePattern:
    """Common state for the allreduce fusion patterns in this module.

    Stores the dtype and device passed in, plus the tensor-parallel group
    and tensor-parallel world size looked up at construction time.
    """

    def __init__(self, dtype: torch.dtype, device: str | None) -> None:
        self.dtype = dtype
        self.device = device
        # Resolved once here so subclasses can read them as attributes.
        self.tp = get_tp_group()
        self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
|
||||
class AllReduceRMSNormPattern(BasePattern):
    """
    This pattern replaces the allreduce + rms norm (without residual)
    with fused flashinfer implementation.
    Applies to allreduce + rmsnorm before attn in the first Transformer block.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
        allreduce_params: FlashInferFusedAllReduceParams,
    ) -> None:
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params
        self.rmsnorm_matcher = MatcherRMSNorm(epsilon)

    def get_inputs(self) -> list[torch.Tensor]:
        """Example inputs for tracing: the RMSNorm matcher inputs, with the
        activation cast to this pattern's dtype."""
        sample_input, sample_weight = self.rmsnorm_matcher.inputs()

        # input goes through allreduce first, always 16-bit
        return [sample_input.to(self.dtype), sample_weight]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the unfused pattern and its fused replacement on `pm_pass`."""

        def pattern(
            input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduced = tensor_model_parallel_all_reduce(input)
            normed = self.rmsnorm_matcher(reduced, weight)

            return normed, reduced

        def replacement(
            input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # No residual in this pattern: feed zeros and a fresh norm buffer.
            zero_residual = torch.zeros_like(input)
            norm_buffer = torch.empty_like(input)
            assert flashinfer_comm is not None, "FlashInfer must be enabled"
            fused = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=zero_residual,
                norm_out=norm_buffer,
                quant_out=None,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # rms_result, allreduce_in
            return fused[3], fused[1]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
|
||||
|
||||
|
||||
class AllReduceFusedAddRMSNormPattern(BasePattern):
    """
    This pattern replaces the allreduce + rms norm (with residual)
    with fused flashinfer implementation.
    Applies to o_proj + rmsnorm after attn and mlp + rmsnorm before attn.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
        allreduce_params: FlashInferFusedAllReduceParams,
    ) -> None:
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params
        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)

    def get_inputs(self) -> list[torch.Tensor]:
        """Example inputs for tracing, in the argument order of `pattern`:
        residual, input (cast to this pattern's dtype), weight."""
        sample_input, sample_residual, sample_weight = self.rmsnorm_matcher.inputs()

        # input goes through allreduce first, always 16-bit
        return [sample_residual, sample_input.to(self.dtype), sample_weight]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register two variants on `pm_pass`: one returning (norm, residual)
        and one returning only the norm output."""

        def pattern(
            residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduced = tensor_model_parallel_all_reduce(input)
            normed, residual_out = self.rmsnorm_matcher(reduced, weight, residual)
            return normed, residual_out

        def replacement(
            residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            assert flashinfer_comm is not None, "FlashInfer must be enabled"
            fused = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                norm_out=None,
                quant_out=None,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # allreduce_in, residual
            return fused[1], fused[2]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )

        # Same pattern, but only return the output and not residual
        # (helpful for end of graph where residual is not used again)
        first_return_only = lambda fn: lambda a, b, c: fn(a, b, c)[0]

        pm.register_replacement(
            first_return_only(pattern),  # type: ignore[no-untyped-call]
            first_return_only(replacement),  # type: ignore[no-untyped-call]
            self.get_inputs(),
            pm.fwd_only,
            pm_pass,
        )
|
||||
|
||||
|
||||
class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
    """
    Replaces allreduce + RMSNorm (without residual) + static FP8 quant
    with the fused FlashInfer implementation.

    Per the original description, this applies to the
    allreduce + rmsnorm + quant sequence before attention in the first
    Transformer block.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
        allreduce_params: FlashInferFusedAllReduceParams,
    ) -> None:
        """Store the norm epsilon and fusion parameters; build the matchers."""
        super().__init__(dtype, device)
        self.quant_dtype = torch.float8_e4m3fn
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params
        # Matchers that stand in for the unfused ops inside `pattern`.
        self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
        self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)

    def get_inputs(self) -> list[torch.Tensor]:
        """Return the example tensors ``[input, weight, scale]`` for tracing."""
        norm_examples = self.rmsnorm_matcher.inputs()
        quant_examples = self.quant_matcher.inputs()
        input, weight = norm_examples
        _, scale = quant_examples

        # The input passes through the all-reduce first, so cast it to
        # self.dtype (original note: "always 16-bit").
        return [input.to(self.dtype), weight, scale]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduced = tensor_model_parallel_all_reduce(input)
            normed = self.rmsnorm_matcher(reduced, weight)
            quantized, _ = self.quant_matcher(normed, scale)
            # Both the quantized output and the raw all-reduce output are
            # returned by the pattern.
            return quantized, reduced

        def replacement(
            input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # No residual exists in this pattern, so a zero tensor is
            # supplied to the fused op's `residual` argument.
            residual = torch.zeros_like(input)
            result_rms = torch.empty_like(input)
            result_quant = torch.empty_like(input, dtype=self.quant_dtype)
            assert flashinfer_comm is not None, "FlashInfer must be enabled"
            fusion_code = (
                flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant
            )
            fused = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                norm_out=result_rms,
                quant_out=result_quant,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                # norm_out is not used afterwards
                pattern_code=fusion_code,
                scale_factor=scale,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )

            # Index 4 is quant_out and index 1 is the allreduce output.
            return fused[4], fused[1]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
|
||||
|
||||
|
||||
class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
    """
    Replaces allreduce + RMSNorm (with residual) + static FP8 quant
    with the fused FlashInfer implementation.

    Per the original description, this applies to
    o_proj + rmsnorm after attn + quant, and to
    mlp + rmsnorm + quant before attn.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
        allreduce_params: FlashInferFusedAllReduceParams,
    ) -> None:
        """Store the norm epsilon and fusion parameters; build the matchers."""
        super().__init__(dtype, device)
        self.quant_dtype = torch.float8_e4m3fn
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params

        # Matchers that stand in for the unfused ops inside `pattern`.
        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
        self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)

    def get_inputs(self) -> list[torch.Tensor]:
        """Return the example tensors ``[residual, input, weight, scale]``."""
        norm_examples = self.rmsnorm_matcher.inputs()
        quant_examples = self.quant_matcher.inputs()
        input, residual, weight = norm_examples
        _, scale = quant_examples

        # The input passes through the all-reduce first, so cast it to
        # self.dtype (original note: "always 16-bit").
        return [residual, input.to(self.dtype), weight, scale]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            residual: torch.Tensor,
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduced = tensor_model_parallel_all_reduce(input)
            normed, new_residual = self.rmsnorm_matcher(reduced, weight, residual)
            quantized, _ = self.quant_matcher(normed, scale)

            return quantized, new_residual

        def replacement(
            residual: torch.Tensor,
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            result_quant = torch.empty_like(input, dtype=self.quant_dtype)
            assert flashinfer_comm is not None, "FlashInfer must be enabled"
            fusion_code = (
                flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant
            )
            fused = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                norm_out=None,
                quant_out=result_quant,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                # norm_out is not used afterwards
                pattern_code=fusion_code,
                scale_factor=scale,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # Index 4 is quant_out and index 2 is the rms-norm residual.
            return fused[4], fused[2]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
|
||||
|
||||
|
||||
class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
    """
    Replaces allreduce + RMSNorm (without residual) + static NVFP4 quant
    with the fused FlashInfer implementation.

    Per the original description, this applies to the
    allreduce + rmsnorm + quant sequence before attention in the first
    Transformer block.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
        allreduce_params: FlashInferFusedAllReduceParams,
    ) -> None:
        """Store the norm epsilon and fusion parameters; build the matcher."""
        super().__init__(dtype, device)
        self.allreduce_params = allreduce_params
        self.epsilon = epsilon
        self.rmsnorm_matcher = MatcherRMSNorm(epsilon)

    def get_inputs(self) -> list[torch.Tensor]:
        """Return example tensors for tracing.

        Order: ``[input, quant_result, weight, input_global_scale,
        output_scale]``, matching the parameters of ``pattern``.
        """
        dev = self.device
        input = torch.empty([1, 16, 16], device=dev, dtype=self.dtype)
        quant_result = torch.empty((16, 8), device=dev, dtype=torch.uint8)
        input_global_scale = torch.empty([1, 1], device=dev, dtype=torch.float32)
        weight = torch.empty([16], device=dev, dtype=self.dtype)
        output_scale = torch.empty([128, 4], device=dev, dtype=torch.int32)

        return [input, quant_result, weight, input_global_scale, output_scale]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            quant_result: torch.Tensor,
            weight: torch.Tensor,
            input_global_scale: torch.Tensor,
            output_scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            reduced = tensor_model_parallel_all_reduce(input)
            normed = self.rmsnorm_matcher(reduced, weight)
            quant_outputs = auto_functionalized(
                STATIC_FP4_QUANT_OP,
                output=quant_result,
                input=normed,
                output_scale=output_scale,
                input_scale=input_global_scale,
                is_sf_swizzled_layout=True,
            )

            # Returned in order: quant_out, allreduce_output, output_scale
            return quant_outputs[1], reduced, quant_outputs[2]

        def replacement(
            input: torch.Tensor,
            quant_result: torch.Tensor,
            weight: torch.Tensor,
            input_global_scale: torch.Tensor,
            output_scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            # No residual exists in this pattern, so a zero tensor is
            # supplied to the fused op's `residual` argument.
            residual = torch.zeros_like(input)
            result_rms = torch.empty_like(input)
            assert flashinfer_comm is not None, "FlashInfer must be enabled"
            fusion_code = (
                flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant
            )
            fused = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                norm_out=result_rms,
                quant_out=quant_result,
                scale_out=output_scale,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                # norm_out is not used afterwards
                pattern_code=fusion_code,
                scale_factor=input_global_scale,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )

            # Returned in order: quant_out, allreduce_output, output_scale
            return fused[4], fused[1], fused[5]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
|
||||
|
||||
|
||||
class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
    """
    Replaces allreduce + RMSNorm (with residual) + static NVFP4 quant
    with the fused FlashInfer implementation.

    Per the original description, this applies to
    o_proj + rmsnorm after attn + quant, and to
    mlp + rmsnorm + quant before attn.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
        allreduce_params: FlashInferFusedAllReduceParams,
    ) -> None:
        """Store the norm epsilon and fusion parameters; build the matcher."""
        super().__init__(dtype, device)
        self.allreduce_params = allreduce_params
        self.epsilon = epsilon
        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)

    def get_inputs(self) -> list[torch.Tensor]:
        """Return example tensors for tracing.

        Order: ``[quant_result, residual, input, output_scale, weight,
        input_global_scale]``, matching the parameters of ``pattern``.
        """
        dev = self.device
        input = torch.empty([16, 16], device=dev, dtype=self.dtype)

        residual = torch.empty([16, 16], device=dev, dtype=self.dtype)
        # NOTE(review): weight is 2-D ([16, 16]) here, while the no-residual
        # NVFP4 pattern uses a 1-D [16] weight -- confirm this is intended.
        weight = torch.empty([16, 16], device=dev, dtype=self.dtype)
        quant_result = torch.empty((16, 8), device=dev, dtype=torch.uint8)
        input_global_scale = torch.empty([1, 1], device=dev, dtype=torch.float32)
        output_scale = torch.empty([128, 4], device=dev, dtype=torch.int32)

        return [
            quant_result,
            residual,
            input,
            output_scale,
            weight,
            input_global_scale,
        ]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(
            quant_result: torch.Tensor,
            residual: torch.Tensor,
            input: torch.Tensor,
            output_scale: torch.Tensor,
            weight: torch.Tensor,
            input_global_scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            reduced = tensor_model_parallel_all_reduce(input)
            normed, new_residual = self.rmsnorm_matcher(reduced, weight, residual)
            quant_outputs = auto_functionalized(
                STATIC_FP4_QUANT_OP,
                output=quant_result,
                input=normed,
                output_scale=output_scale,
                input_scale=input_global_scale,
                is_sf_swizzled_layout=True,
            )

            # Returned in order: quant_out, rms-norm residual, output_scale
            return quant_outputs[1], new_residual, quant_outputs[2]

        def replacement(
            quant_result: torch.Tensor,
            residual: torch.Tensor,
            input: torch.Tensor,
            output_scale: torch.Tensor,
            weight: torch.Tensor,
            input_global_scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            assert flashinfer_comm is not None, "FlashInfer must be enabled"
            fusion_code = (
                flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant
            )
            fused = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                norm_out=None,
                quant_out=quant_result,
                scale_out=output_scale,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                # norm_out is not used afterwards
                pattern_code=fusion_code,
                scale_factor=input_global_scale,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # Returned in order: quant_out, rms_norm_residual, output_scale
            return fused[4], fused[2], fused[5]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
|
||||
|
||||
|
||||
class AllReduceFusionPass(VllmPatternMatcherPass):
    """
    Pattern-matcher pass that registers the FlashInfer allreduce + RMSNorm
    (+ optional static quantization) fusion patterns and applies them to an
    FX graph.

    The pass starts out disabled (``self.disabled = True``) and is enabled
    only at the end of ``register_patterns``; every early return in
    ``__init__`` therefore leaves it disabled.
    """

    def __init__(self, config: VllmConfig) -> None:
        """Set up workspaces and register patterns, or leave the pass disabled.

        The pass stays disabled when any of the following holds:
        tensor-parallel size <= 1, ``config.model_config`` is missing,
        ``flashinfer_comm`` is unavailable, no FlashInfer max size is
        configured for this world size, or a workspace initializer raises.

        Args:
            config: The vLLM configuration; the model, compilation and
                scheduler sub-configs are read here.
        """
        super().__init__(config)
        self.disabled = True
        self.tp_size = get_tensor_model_parallel_world_size()
        if self.tp_size <= 1:
            logger.warning_once("AllReduce fusion pass is disabled for tp_size <= 1.")
            return
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="all_reduce_fusion_pass"
        )
        if config.model_config is None:
            logger.warning_once(
                "AllReduce fusion pass is disabled for missing model_config."
            )
            return
        self.hidden_dim = config.model_config.get_hidden_size()
        self.group = get_tp_group().device_group
        rank = get_tensor_model_parallel_rank()
        if flashinfer_comm is None:
            logger.warning(
                "Flashinfer is not installed or comm module not found, "
                "skipping allreduce fusion pass"
            )
            return
        max_size = config.compilation_config.pass_config.flashinfer_max_size(
            self.tp_size
        )
        if max_size is None:
            # Flashinfer doesn't support current world size
            logger.warning(
                "Flashinfer allreduce fusion is not supported for world size %s"
                " or max size is not provided",
                self.tp_size,
            )
            return
        # Convert the byte budget into a token budget: one token occupies
        # hidden_dim elements of the model dtype.
        element_size = torch.tensor([], dtype=self.model_dtype).element_size()
        self.max_token_num = max_size // (self.hidden_dim * element_size)
        # take the min to save workspace size and we'll never use more
        # than max_num_batched_tokens anyways
        self.max_token_num = min(
            self.max_token_num, config.scheduler_config.max_num_batched_tokens
        )
        # A separating space was added after "MB," so the two halves of the
        # message no longer run together.
        logger.debug_once(
            f"Flashinfer max size: {max_size // (1024 * 1024)} MB, "
            "Maximal number of tokens used by "
            f"Flashinfer Allreduce Fusion: {self.max_token_num}",
            scope="global",
        )

        # Both workspaces are initialized with the same arguments; the first
        # failure disables the pass.
        for workspace_init_fn in [
            initialize_fi_ar_workspace,
            initialize_fi_ar_quant_workspace,
        ]:
            try:
                workspace_init_fn(
                    world_size=self.tp_size,
                    rank=rank,
                    max_token_num=self.max_token_num,
                    hidden_dim=self.hidden_dim,
                    dtype=self.model_dtype,
                    group=self.group,
                )
            except Exception as e:
                # Errors mentioning "multicast" get a more specific message;
                # in both cases the pass stays disabled.
                if "multicast" in str(e).lower():
                    logger.warning(
                        "AllReduce fusion pass is disabled: flashinfer workspace "
                        "creation failed: %s. This is expected on GPUs without "
                        "NVSwitch (e.g., NVLink bridge-only or PCIe topologies). "
                        "Falling back to non-fused allreduce.",
                        str(e),
                    )
                else:
                    logger.warning(
                        "Failed to initialize FlashInfer All Reduce workspace: %s. "
                        "AllReduce fusion pass will be disabled.",
                        e,
                    )
                return

        self.allreduce_params = FlashInferFusedAllReduceParams(
            world_size=self.tp_size,
            max_token_num=self.max_token_num,
        )

        self.register_patterns()
        self.dump_patterns(config, self.patterns)

    @enable_fake_mode
    def register_patterns(self) -> None:
        """Register every fusion pattern for each supported epsilon.

        Quantized patterns are registered only when the quant workspace is
        available; NVFP4 patterns additionally require
        ``current_platform.has_device_capability(100)``. Sets
        ``self.disabled = False`` once registration has finished.
        """
        supports_quantization = get_fi_ar_quant_workspace() is not None
        for epsilon in [1e-5, 1e-6]:
            if supports_quantization:
                AllReduceFusedRMSNormStaticQuantFP8Pattern(
                    epsilon,
                    self.model_dtype,
                    self.device,
                    self.allreduce_params,
                ).register(self.patterns)
                AllReduceFusedAddRMSNormStaticQuantFP8Pattern(
                    epsilon,
                    self.model_dtype,
                    self.device,
                    self.allreduce_params,
                ).register(self.patterns)
                if current_platform.has_device_capability(100):
                    AllReduceFusedRMSNormStaticQuantNVFP4Pattern(
                        epsilon,
                        self.model_dtype,
                        self.device,
                        self.allreduce_params,
                    ).register(self.patterns)
                    AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(
                        epsilon,
                        self.model_dtype,
                        self.device,
                        self.allreduce_params,
                    ).register(self.patterns)
            AllReduceRMSNormPattern(
                epsilon,
                self.model_dtype,
                self.device,
                self.allreduce_params,
            ).register(self.patterns)
            AllReduceFusedAddRMSNormPattern(
                epsilon,
                self.model_dtype,
                self.device,
                self.allreduce_params,
            ).register(self.patterns)

            # WARNING: This is a hack to clear the pattern matcher cache
            # and allow multiple values of epsilon.
            torch._inductor.pattern_matcher._seen_patterns.clear()

        self.disabled = False

    def is_applicable_for_range(self, compile_range: Range) -> bool:
        """Return True when the pass is enabled and the range fits the budget.

        The check on ``self.disabled`` comes first because
        ``self.max_token_num`` is not assigned on every early-return path of
        ``__init__``.
        """
        if self.disabled:
            logger.warning_once("AllReduce fusion pass is disabled.")
            return False
        return bool(compile_range.end <= self.max_token_num)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Apply the registered patterns to ``graph`` (no-op when disabled)."""
        if self.disabled:
            logger.debug("AllReduceFusionPass disabled")
            return

        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def __del__(self) -> None:
        """Destroy the FlashInfer workspace if the pass was ever enabled.

        ``getattr`` with a default of True covers an instance whose
        ``__init__`` did not get as far as setting ``disabled``. Any exception
        raised during teardown is suppressed.
        """
        if getattr(self, "disabled", True):
            return
        with contextlib.suppress(Exception):
            destroy_fi_ar_workspace()
|
||||
374
vllm/compilation/passes/fusion/attn_quant_fusion.py
Normal file
374
vllm/compilation/passes/fusion/attn_quant_fusion.py
Normal file
@@ -0,0 +1,374 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from typing import Any, ParamSpec
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
from torch import fx
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kNvfp4Dynamic,
|
||||
kStaticTensorScale,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
from ..fx_utils import is_func
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .matcher_utils import MatcherQuantFP8
|
||||
from .rms_quant_fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
|
||||
|
||||
# Module-level logger for this file.
logger = init_logger(__name__)
|
||||
# ParamSpec used to type the trace-function wrapper in AttentionQuantPattern.
P = ParamSpec("P")
|
||||
# FP8 dtype as reported by the current platform.
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
# NOTE(review): presumably FP4 values are carried in uint8 tensors -- confirm.
FP4_DTYPE = torch.uint8
|
||||
|
||||
# Attention op matched (and re-emitted) by the attention+quant patterns below.
ATTN_OP = torch.ops.vllm.unified_attention_with_output.default
|
||||
# Reshape op applied to the attention output in the patterns below.
RESHAPE_OP = torch.ops.aten.reshape.default
|
||||
|
||||
|
||||
class AttentionQuantPattern(ABC):
    """
    Common base for the Attention + Quant fusion patterns.

    Not intended to be instantiated directly; subclasses implement
    ``_register``.
    """

    def __init__(
        self,
        layer: Attention,
        quant_key: QuantKey,
        dtype: torch.dtype,
    ) -> None:
        """Cache layer attributes and look up the quant op for ``quant_key``.

        Asserts that ``quant_key`` is present in ``QUANT_OPS``.
        """
        self.layer = layer
        self.layer_name = layer.layer_name
        self.num_heads = layer.num_heads
        self.head_size = layer.head_size
        self.dtype = dtype
        self.quant_key = quant_key
        self.quant_dtype = quant_key.dtype

        assert self.quant_key in QUANT_OPS, (
            f"unsupported quantization scheme {self.quant_key}"
        )
        self.QUANT_OP = QUANT_OPS[self.quant_key]

    def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """``torch.empty`` defaulting to ``self.dtype`` on the "cuda" device.

        Caller-supplied keyword arguments override the defaults.
        """
        merged: dict[str, Any] = {"dtype": self.dtype, "device": "cuda"}
        merged.update(kwargs)
        return torch.empty(*args, **merged)

    def empty_quant(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """``torch.empty`` defaulting to ``self.quant_dtype`` on "cuda".

        Caller-supplied keyword arguments override the defaults.
        """
        merged: dict[str, Any] = {"dtype": self.quant_dtype, "device": "cuda"}
        merged.update(kwargs)
        return torch.empty(*args, **merged)

    @staticmethod
    def wrap_trace_fn(
        trace_fn: Callable[P, fx.GraphModule],
        *process_fx_fns: Callable[[fx.GraphModule], None],
    ) -> Callable[P, fx.GraphModule]:
        """Wrap ``trace_fn`` so post-processing runs on its traced module.

        The returned callable invokes ``trace_fn`` and then applies each
        function in ``process_fx_fns`` to the result, in order, before
        returning it.
        """

        def wrapped(*args: P.args, **kwargs: P.kwargs) -> fx.GraphModule:
            traced = trace_fn(*args, **kwargs)
            for post_process in process_fx_fns:
                post_process(traced)

            return traced

        return wrapped

    @staticmethod
    def fx_view_to_reshape(gm: torch.fx.GraphModule) -> None:
        """Run inductor's ``view_to_reshape`` post-grad pass on ``gm``."""
        from torch._inductor.fx_passes.post_grad import view_to_reshape

        view_to_reshape(gm)

    @staticmethod
    def remove_noop_permutes(gm: torch.fx.GraphModule) -> None:
        """Erase ``aten.permute`` nodes whose dims are the identity order."""
        for node in gm.graph.nodes:
            if not is_func(node, torch.ops.aten.permute.default):
                continue

            dims = node.args[1]
            is_identity = not any(dim != i for i, dim in enumerate(dims))
            if not is_identity:
                continue

            # The permute is an identity op: route its users to its input
            # and drop the node.
            node.replace_all_uses_with(node.args[0])
            gm.graph.erase_node(node)

    def register_if_supported(self, pm_pass: PatternMatcherPass) -> None:
        """Register only if the attention impl supports fused output quant."""
        if not self.layer.impl.fused_output_quant_supported(self.quant_key):
            return
        self._register(pm_pass)

    @abstractmethod
    def _register(self, pm_pass: PatternMatcherPass) -> None:
        """Register this pattern on ``pm_pass``; implemented by subclasses."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
    """
    Attention + static FP8 quantization fusion.

    Registration happens only when the attention implementation returns True
    from `fused_output_quant_supported()`. When the pattern matches, the
    separate FP8 static quant op is dropped from the graph and its scale is
    handed to the attention op through the `output_scale` argument.
    """

    def __init__(
        self,
        layer: Attention,
        dtype: torch.dtype,
        symmetric: bool = True,
    ) -> None:
        """Build the static per-tensor FP8 quant key and its matcher."""
        fp8_key = QuantKey(
            dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric
        )
        super().__init__(layer, fp8_key, dtype)
        self.quant_matcher = MatcherQuantFP8(fp8_key)

    def _register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the attention + FP8 quant pattern/replacement pair."""

        def pattern(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            output_attn: torch.Tensor,
            scale: torch.Tensor,
            kv_cache_dummy_dep: torch.Tensor,
        ) -> torch.Tensor:
            # Unfused: attention without an output scale, then reshape, then
            # a separate quant.
            attn_result = auto_functionalized(
                ATTN_OP,
                query=q,
                key=k,
                value=v,
                output=output_attn,
                layer_name=self.layer_name,
                output_scale=None,
                output_block_scale=None,
                kv_cache_dummy_dep=kv_cache_dummy_dep,
            )
            flat_attn = RESHAPE_OP(
                attn_result[1], [q.shape[0], self.num_heads * self.head_size]
            )

            return self.quant_matcher(flat_attn, scale)[0]

        def replacement(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            output_attn: torch.Tensor,
            scale: torch.Tensor,
            kv_cache_dummy_dep: torch.Tensor,
        ) -> torch.Tensor:
            # The attention output buffer is allocated in quant_dtype; the
            # incoming `output_attn` argument is not used.
            quant_output_attn = torch.ops.aten.full.default(
                [q.shape[0], self.num_heads, self.head_size],
                0.0,
                dtype=self.quant_dtype,
                device=q.device,
            )
            attn_result = auto_functionalized(
                ATTN_OP,
                query=q,
                key=k,
                value=v,
                output=quant_output_attn,
                layer_name=self.layer_name,
                output_scale=scale,
                output_block_scale=None,
                kv_cache_dummy_dep=kv_cache_dummy_dep,
            )
            return RESHAPE_OP(attn_result[1], [-1, self.num_heads * self.head_size])

        # Example inputs: q, k, v, attn_output (four tensors of one shape),
        # then scale and kv_cache_dummy_dep.
        qkv_and_output = [
            self.empty(5, self.num_heads, self.head_size) for _ in range(4)
        ]
        inputs = [
            *qkv_and_output,
            empty_fp32(1, 1),  # scale
            self.empty(0),  # kv_cache_dummy_dep
        ]

        # Trace with fwd_only, then normalize views to reshapes and drop
        # identity permutes.
        trace_fn = AttentionQuantPattern.wrap_trace_fn(
            pm.fwd_only,
            AttentionQuantPattern.fx_view_to_reshape,
            AttentionQuantPattern.remove_noop_permutes,
        )
        pm.register_replacement(pattern, replacement, inputs, trace_fn, pm_pass)
|
||||
|
||||
|
||||
class AttentionNvfp4QuantPattern(AttentionQuantPattern):
    """
    Attention + NVFP4 quantization fusion.

    Registration happens only when the attention implementation returns True
    from `fused_output_quant_supported()`. When the pattern matches, the
    separate NVFP4 quant op is dropped from the graph and its scale is handed
    to the attention op through the `output_scale` argument.
    """

    def __init__(self, layer: Attention, dtype: torch.dtype) -> None:
        """Initialize the base pattern with the ``kNvfp4Dynamic`` quant key."""
        super().__init__(layer, kNvfp4Dynamic, dtype)

    def _register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the attention + NVFP4 quant pattern/replacement pair."""

        def pattern(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            output_attn: torch.Tensor,
            output_quant: torch.Tensor,
            output_scale: torch.Tensor,
            input_scale: torch.Tensor,
            kv_cache_dummy_dep: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # Unfused: attention, then reshape, then a separate quant op.
            attn_result = auto_functionalized(
                ATTN_OP,
                query=q,
                key=k,
                value=v,
                output=output_attn,
                layer_name=self.layer_name,
                output_scale=None,
                output_block_scale=None,
                kv_cache_dummy_dep=kv_cache_dummy_dep,
            )
            flat_attn = RESHAPE_OP(
                attn_result[1], [q.shape[0], self.num_heads * self.head_size]
            )
            quant_result = auto_functionalized(
                self.QUANT_OP,
                output=output_quant,
                input=flat_attn,
                output_scale=output_scale,
                input_scale=input_scale,
                is_sf_swizzled_layout=True,
            )
            # The quant op's scale output is viewed as FP8_DTYPE.
            scale_as_fp8 = torch.ops.aten.view.dtype(quant_result[2], FP8_DTYPE)
            return quant_result[1], scale_as_fp8

        def replacement(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            output_attn: torch.Tensor,
            output_quant: torch.Tensor,
            output_scale: torch.Tensor,
            input_scale: torch.Tensor,
            kv_cache_dummy_dep: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # The attention output buffer is allocated in quant_dtype with
            # half the head size; the incoming `output_attn` is not used.
            quant_output_attn = torch.ops.aten.full.default(
                [q.shape[0], self.num_heads, self.head_size // 2],
                0.0,
                dtype=self.quant_dtype,
                device=q.device,
            )
            # The block scale handed to attention is the FP8 view of
            # output_scale.
            scale_as_fp8 = torch.ops.aten.view.dtype(output_scale, FP8_DTYPE)
            attn_result = auto_functionalized(
                ATTN_OP,
                query=q,
                key=k,
                value=v,
                output=quant_output_attn,
                layer_name=self.layer_name,
                output_scale=input_scale,
                output_block_scale=scale_as_fp8,
                kv_cache_dummy_dep=kv_cache_dummy_dep,
            )
            flat_output = RESHAPE_OP(
                attn_result[1], [-1, self.num_heads * self.head_size // 2]
            )
            return flat_output, attn_result[2]

        # Example inputs: q, k, v, output_attn (four bf16 tensors of one
        # shape), then the quant buffers and scales.
        qkv_and_output = [
            empty_bf16(5, self.num_heads, self.head_size) for _ in range(4)
        ]
        inputs = [
            *qkv_and_output,
            self.empty_quant(5, self.num_heads * self.head_size // 2),  # output_quant
            empty_i32(
                128, round_up(self.num_heads * self.head_size // 16, 4)
            ),  # output_scale
            empty_fp32(1, 1),  # input_scale
            self.empty(0),  # kv_cache_dummy_dep
        ]

        # Trace with fwd_only, then normalize views to reshapes and drop
        # identity permutes.
        trace_fn = AttentionQuantPattern.wrap_trace_fn(
            pm.fwd_only,
            AttentionQuantPattern.fx_view_to_reshape,
            AttentionQuantPattern.remove_noop_permutes,
        )
        pm.register_replacement(pattern, replacement, inputs, trace_fn, pm_pass)
|
||||
|
||||
|
||||
class AttnFusionPass(VllmPatternMatcherPass):
    """
    Fuses post-attention quantization onto the attention op where supported.

    Patterns are registered per attention layer because the layer name is a
    string inside the pattern and strings cannot be wildcarded. Registering
    per layer also means support is checked at registration time rather than
    during pattern matching.

    A static FP8 pattern is attempted for every layer. An NVFP4 pattern is
    attempted as well when running on CUDA and `torch.ops._C` exposes
    `scaled_fp4_quant`. Wider support mostly depends on attention kernels
    that can fuse output quantization.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        """Register attention+quant patterns for every attention layer."""
        super().__init__(config)

        self.patterns = PatternMatcherPass(pass_name="attn_fusion_pass")

        attn_layers = get_layers_from_vllm_config(config, Attention)
        for layer in attn_layers.values():
            AttentionFp8StaticQuantPattern(
                layer, config.model_config.dtype
            ).register_if_supported(self.patterns)

            if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
                AttentionNvfp4QuantPattern(
                    layer, config.model_config.dtype
                ).register_if_supported(self.patterns)

        if not len(attn_layers):
            logger.warning(
                "Attention + quant fusion is enabled, but no attention layers "
                "were found in CompilationConfig.static_forward_context "
                "so no fusion patterns were registered."
            )

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: torch.fx.graph.Graph) -> None:
        """Apply the registered patterns to ``graph`` and record the count."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Fused quant onto %s attention nodes", self.matched_count)

    def uuid(self) -> str:
        """Hash the source of this pass together with its pattern classes."""
        hashed_sources = (
            self,
            AttentionQuantPattern,
            AttentionFp8StaticQuantPattern,
            AttentionNvfp4QuantPattern,
        )
        return VllmInductorPass.hash_source(*hashed_sources)
|
||||
423
vllm/compilation/passes/fusion/collective_fusion.py
Normal file
423
vllm/compilation/passes/fusion/collective_fusion.py
Normal file
@@ -0,0 +1,423 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
import torch.fx as fx
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.utils import Range
|
||||
from vllm.distributed import get_tp_group
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
# FP8 dtype as reported by the current platform.
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
# Module-level logger for this file.
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class BasePattern:
    """Shared state for the collective fusion patterns in this module.

    Holds the dtype/device used for example inputs, plus the tensor-parallel
    group and world size that the patterns reference.
    """

    def __init__(self, dtype: torch.dtype, device: str | None) -> None:
        """Record dtype/device and look up the tensor-parallel group info."""
        self.dtype, self.device = dtype, device
        self.tp = get_tp_group()
        self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
|
||||
class GEMMReduceScatterPattern(BasePattern):
    """Replaces ``aten.mm`` + ``vllm.reduce_scatter`` with
    ``symm_mem.fused_matmul_reduce_scatter``."""

    def get_inputs(self) -> list[torch.Tensor]:
        """Return the example tensors ``[mul, mm_weight]`` for tracing."""
        common = {"device": self.device, "dtype": self.dtype}
        mul = torch.empty([16, 4], **common)
        mm_weight = torch.empty([4, 4], **common)
        return [mul, mm_weight]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern/replacement pair on ``pm_pass``."""

        def pattern(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor:
            # Unfused: matmul, then reduce-scatter along dim 0 over the TP
            # group.
            product = torch.ops.aten.mm.default(mul, mm_weight)
            return torch.ops.vllm.reduce_scatter.default(
                product,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )

        def replacement(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor:
            # NOTE(review): the reduce op passed here is "avg" -- confirm it
            # matches the reduction done by torch.ops.vllm.reduce_scatter in
            # the pattern above.
            return torch.ops.symm_mem.fused_matmul_reduce_scatter(
                mul,
                mm_weight,
                "avg",
                scatter_dim=0,
                group_name=self.tp.device_group.group_name,
            )

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
|
||||
|
||||
|
||||
class AllGatherGEMMPattern(BasePattern):
    """Rewrite ``vllm.all_gather`` followed by ``aten.mm`` into the
    ``symm_mem.fused_all_gather_matmul`` op."""

    def get_inputs(self) -> list[torch.Tensor]:
        """Build the two 4x4 sample operands used to trace the pattern."""
        tensor_kwargs = {"device": self.device, "dtype": self.dtype}
        activation = torch.empty([4, 4], **tensor_kwargs)
        matrix = torch.empty([4, 4], **tensor_kwargs)
        return [activation, matrix]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the search/replace function pair on ``pm_pass``."""

        def pattern(
            x: torch.Tensor,
            weight: torch.Tensor,
        ) -> torch.Tensor:
            # Unfused form: gather along dim 0, then matmul with the weight.
            gathered = torch.ops.vllm.all_gather.default(
                x,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )
            return torch.ops.aten.mm.default(gathered, weight)

        def replacement(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
            # Fused form: the op returns a pair; only the matmul outputs
            # (second element) are returned here.
            _gathered, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul(
                x,
                [weight],
                gather_dim=0,
                group_name=self.tp.device_group.group_name,
            )
            return mm_outputs

        sample_inputs = self.get_inputs()
        pm.register_replacement(
            pattern, replacement, sample_inputs, pm.fwd_only, pm_pass
        )
|
||||
|
||||
|
||||
class ScaledMMReduceScatterPattern(BasePattern):
    """Rewrite ``aten._scaled_mm`` followed by ``vllm.reduce_scatter`` into
    ``vllm.patched_fused_scaled_matmul_reduce_scatter``."""

    def get_inputs(self) -> list[torch.Tensor]:
        """Build FP8 operands and float32 scales used to trace the pattern."""
        fp8_kwargs = {"device": self.device, "dtype": FP8_DTYPE}
        scale_kwargs = {"device": self.device, "dtype": torch.float32}

        lhs = torch.empty([16, 16], **fp8_kwargs)
        # The second operand is supplied as a transposed view.
        rhs = torch.empty([16, 16], **fp8_kwargs).contiguous().transpose(0, 1)
        scale_a = torch.empty([16, 1], **scale_kwargs)
        scale_b = torch.empty([1, 16], **scale_kwargs)
        return [lhs, rhs, scale_a, scale_b]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the search/replace function pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            mat2: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
        ) -> torch.Tensor:
            # Unfused form: scaled matmul, then reduce-scatter along dim 0.
            product = torch.ops.aten._scaled_mm.default(
                input,
                mat2=mat2,
                scale_a=scale_a,
                scale_b=scale_b,
                bias=None,
                scale_result=None,
                out_dtype=self.dtype,
            )
            return torch.ops.vllm.reduce_scatter.default(
                product,
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )

        def replacement(
            input: torch.Tensor,
            mat2: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
        ) -> torch.Tensor:
            # Shape of input @ mat2: leading dims of input, columns of mat2.
            leading_dims = input.shape[:-1]
            output_shape = [*leading_dims, mat2.shape[1]]
            scatter_dim = 0
            return torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
                input,
                mat2,
                scale_a,
                scale_b,
                "avg",
                scatter_dim,  # orig_scatter_dim
                scatter_dim,  # scatter_dim_after_maybe_reshape
                self.tp.device_group.group_name,
                output_shape,
                None,  # bias
                None,  # result_scale
                self.dtype,  # out_dtype
                False,  # use_fast_accum
            )

        sample_inputs = self.get_inputs()
        pm.register_replacement(
            pattern, replacement, sample_inputs, pm.fwd_only, pm_pass
        )
|
||||
|
||||
|
||||
class AllGatherScaledMMPattern(BasePattern):
    """Rewrite ``vllm.all_gather`` followed by ``aten._scaled_mm`` into
    ``symm_mem.fused_all_gather_scaled_matmul``."""

    def get_inputs(self) -> list[torch.Tensor]:
        """Build FP8 operands and float32 scales used to trace the pattern."""
        fp8_kwargs = {"device": self.device, "dtype": FP8_DTYPE}
        scale_kwargs = {"device": self.device, "dtype": torch.float32}

        x = torch.empty([8, 16], **fp8_kwargs)
        # The weight is supplied as a transposed view.
        weight = torch.empty([16, 16], **fp8_kwargs).contiguous().transpose(0, 1)

        # scale_a has one row per gathered row: local rows times TP world size.
        gathered_rows = x.shape[0] * self.tp_size
        scale_a = torch.empty([gathered_rows, 1], **scale_kwargs)
        scale_b = torch.empty([1, 16], **scale_kwargs)

        return [x, weight, scale_a, scale_b]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the search/replace function pair on ``pm_pass``."""

        def pattern(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
        ) -> torch.Tensor:
            # Unfused form: gather along dim 0, then scaled matmul.
            gathered = torch.ops.vllm.all_gather.default(
                x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name
            )
            return torch.ops.aten._scaled_mm.default(
                gathered,
                mat2=weight,
                scale_a=scale_a,
                scale_b=scale_b,
                bias=None,
                scale_result=None,
                out_dtype=self.dtype,
            )

        def replacement(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
        ) -> torch.Tensor:
            # Fused form: the op returns a pair; only the matmul outputs
            # (second element) are returned here.
            fused_op = torch.ops.symm_mem.fused_all_gather_scaled_matmul
            _gathered, mm_outputs = fused_op(
                x,
                [weight],
                scale_a,
                [scale_b],
                gather_dim=0,
                biases=[None],
                result_scales=[None],
                out_dtypes=[self.dtype],
                use_fast_accum=[False],
                group_name=self.tp.device_group.group_name,
            )
            return mm_outputs

        sample_inputs = self.get_inputs()
        pm.register_replacement(
            pattern, replacement, sample_inputs, pm.fwd_only, pm_pass
        )
|
||||
|
||||
|
||||
class CutlassScaledMMReduceScatterPattern(BasePattern):
    """Rewrite an auto-functionalized ``_C.cutlass_scaled_mm`` followed by
    ``vllm.reduce_scatter`` into
    ``vllm.patched_fused_scaled_matmul_reduce_scatter``."""

    def get_inputs(self) -> list[torch.Tensor]:
        """Build FP8 operands, float32 scales and the output buffer used to
        trace the pattern."""
        fp8_kwargs = {"device": self.device, "dtype": FP8_DTYPE}
        scale_kwargs = {"device": self.device, "dtype": torch.float32}

        lhs = torch.empty([16, 16], **fp8_kwargs)
        # The second operand is supplied as a transposed view.
        rhs = torch.empty([16, 16], **fp8_kwargs).contiguous().transpose(0, 1)
        scale_a = torch.empty([16, 1], **scale_kwargs)
        scale_b = torch.empty([1, 16], **scale_kwargs)

        # Buffer passed as ``out`` to the cutlass op, in the pattern's dtype.
        out_buffer = torch.empty([16, 16], device=self.device, dtype=self.dtype)
        return [lhs, rhs, scale_a, scale_b, out_buffer]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the search/replace function pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
            cutlass_mm_output: torch.Tensor,
        ) -> torch.Tensor:
            functionalized = torch.ops.higher_order.auto_functionalized(
                torch.ops._C.cutlass_scaled_mm.default,
                out=cutlass_mm_output,
                a=input,
                b=weight,
                a_scales=scale_a,
                b_scales=scale_b,
                bias=None,
            )

            # Element 1 of the auto_functionalized result is reduce-scattered.
            return torch.ops.vllm.reduce_scatter.default(
                functionalized[1],
                dim=0,
                world_size=self.tp_size,
                group_name=self.tp.unique_name,
            )

        def replacement(
            input: torch.Tensor,
            mat2: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
            cutlass_mm_output: torch.Tensor,
        ) -> torch.Tensor:
            # Shape of input @ mat2: leading dims of input, columns of mat2.
            leading_dims = input.shape[:-1]
            output_shape = [*leading_dims, mat2.shape[1]]
            scatter_dim = 0
            return torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
                input,
                mat2,
                scale_a,
                scale_b,
                "avg",
                scatter_dim,  # orig_scatter_dim
                scatter_dim,  # scatter_dim_after_maybe_reshape
                self.tp.device_group.group_name,
                output_shape,
                None,  # bias
                None,  # result_scale
                self.dtype,  # out_dtype
                False,  # use_fast_accum
            )

        sample_inputs = self.get_inputs()
        pm.register_replacement(
            pattern, replacement, sample_inputs, pm.fwd_only, pm_pass
        )
|
||||
|
||||
|
||||
class AllGatherCutlassScaledMMPattern(BasePattern):
    """Rewrite ``vllm.all_gather`` followed by an auto-functionalized
    ``_C.cutlass_scaled_mm`` into ``symm_mem.fused_all_gather_scaled_matmul``."""

    def get_inputs(self) -> list[torch.Tensor]:
        """Build FP8 operands, float32 scales and the output buffer used to
        trace the pattern."""
        fp8_kwargs = {"device": self.device, "dtype": FP8_DTYPE}
        scale_kwargs = {"device": self.device, "dtype": torch.float32}

        x = torch.empty([8, 16], **fp8_kwargs)
        # The weight is supplied as a transposed view.
        weight = torch.empty([16, 16], **fp8_kwargs).contiguous().transpose(0, 1)

        # Row count after gathering: local rows times TP world size.
        gathered_rows = x.shape[0] * self.tp_size
        scale_a = torch.empty([gathered_rows, 1], **scale_kwargs)
        scale_b = torch.empty([1, 16], **scale_kwargs)

        # Output buffer sized (gathered rows, weight columns).
        out_cols = weight.shape[1]
        output = torch.empty(
            [gathered_rows, out_cols], device=self.device, dtype=self.dtype
        )

        return [x, weight, scale_a, scale_b, output]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the search/replace function pair on ``pm_pass``."""

        def pattern(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
            output: torch.Tensor,
        ) -> torch.Tensor:
            gathered = torch.ops.vllm.all_gather.default(
                x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name
            )

            functionalized = torch.ops.higher_order.auto_functionalized(
                torch.ops._C.cutlass_scaled_mm.default,
                out=output,
                a=gathered,
                b=weight,
                a_scales=scale_a,
                b_scales=scale_b,
                bias=None,
            )
            # Element 1 of the auto_functionalized result is the pattern output.
            return functionalized[1]

        def replacement(
            x: torch.Tensor,
            weight: torch.Tensor,
            scale_a: torch.Tensor,
            scale_b: torch.Tensor,
            output: torch.Tensor,
        ) -> torch.Tensor:
            # Fused form: the op returns a pair; only the matmul outputs
            # (second element) are returned here.
            fused_op = torch.ops.symm_mem.fused_all_gather_scaled_matmul
            _gathered, mm_outputs = fused_op(
                x,
                [weight],
                scale_a,
                [scale_b],
                gather_dim=0,
                biases=[None],
                result_scales=[None],
                out_dtypes=[self.dtype],
                use_fast_accum=[False],
                group_name=self.tp.device_group.group_name,
            )
            return mm_outputs

        sample_inputs = self.get_inputs()
        pm.register_replacement(
            pattern, replacement, sample_inputs, pm.fwd_only, pm_pass
        )
|
||||
|
||||
|
||||
class AsyncTPPass(VllmPatternMatcherPass):
    """Pattern-matcher pass that registers the collective-fusion patterns of
    this module and applies them to an FX graph."""

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        super().__init__(config)

        # Enable symmetric memory for the TP process group
        tp_group_name = get_tp_group().device_group.group_name
        enable_symm_mem_for_group(tp_group_name)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="async_tp_pass"
        )

        pattern_classes = [GEMMReduceScatterPattern, AllGatherGEMMPattern]

        # These fusions are enabled only for bfloat16 models because
        # `scaled_mm` or `cutlass_scaled_mm` with per-token (row-wise) scaling
        # only supports bfloat16 as the output dtype.
        if self.model_dtype == torch.bfloat16:
            pattern_classes += [
                ScaledMMReduceScatterPattern,
                AllGatherScaledMMPattern,
                CutlassScaledMMReduceScatterPattern,
                AllGatherCutlassScaledMMPattern,
            ]

        # Instantiate and register one pattern at a time, in list order.
        for pattern_cls in pattern_classes:
            pattern_cls(self.model_dtype, self.device).register(self.patterns)

        self.dump_patterns(config, self.patterns)

    def is_applicable_for_range(self, compile_range: Range) -> bool:
        """Tell whether this pass should run for ``compile_range``."""
        # This pass is applied on top of the sequence parallelism pass.
        # It inherits the same applicability condition as `SequenceParallelismPass`.
        # See `SequenceParallelismPass.is_applicable` for more details.
        no_splitting_ops = not self.compilation_config.splitting_ops
        if no_splitting_ops or self.compilation_config.use_inductor_graph_partition:
            return True

        tp_size = get_tensor_model_parallel_world_size()
        if not compile_range.is_single_size():
            return False
        # A single-size range qualifies only when it is divisible by TP size.
        return bool(compile_range.end % tp_size == 0)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Apply the registered patterns to ``graph`` and record the count."""
        num_replaced = self.patterns.apply(graph)
        self.matched_count = num_replaced
        logger.debug("Replaced %s patterns", self.matched_count)
|
||||
472
vllm/compilation/passes/fusion/matcher_utils.py
Normal file
472
vllm/compilation/passes/fusion/matcher_utils.py
Normal file
@@ -0,0 +1,472 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch._higher_order_ops import auto_functionalized
|
||||
from torch._ops import OpOverload
|
||||
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
QuantKey,
|
||||
_normalize_quant_group_shape,
|
||||
kFp8Dynamic64Sym,
|
||||
kFp8Dynamic128Sym,
|
||||
kFp8DynamicTensorSym,
|
||||
kFp8DynamicTokenSym,
|
||||
kFp8StaticTensorSym,
|
||||
kNvfp4Dynamic,
|
||||
)
|
||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# Default overloads of the custom ops targeted by the matcher classes.
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
ROTARY_OP = torch.ops._C.rotary_embedding.default
FLASHINFER_ROTARY_OP = torch.ops.vllm.flashinfer_rotary_embedding.default

# Quantization scheme -> custom op overload implementing it.
QUANT_OPS: dict[QuantKey, OpOverload] = {
    kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default,  # noqa: E501
    kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default,  # noqa: E501
    kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default,  # noqa: E501
}

# Entries registered only when the current platform reports CUDA.
if current_platform.is_cuda():
    # The FP4 op is added only when the extension module exposes it.
    if hasattr(torch.ops._C, "scaled_fp4_quant"):
        QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default  # noqa: E501

    # Both group-quant keys map to the same per-token-group op.
    QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default  # noqa: E501
    QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default  # noqa: E501

SILU_MUL_OP = torch.ops._C.silu_and_mul.default
|
||||
|
||||
|
||||
class MatcherCustomOp(ABC):
    """Base class for matcher objects with a custom and a native forward.

    The ``enabled`` flag chosen at construction selects which of
    ``forward_custom`` / ``forward_native`` is bound to ``self.forward`` and
    therefore invoked by ``__call__``.
    """

    def __init__(self, enabled: bool) -> None:
        vllm_config = get_current_vllm_config()

        # dtype/device fall back to None when the corresponding config is unset.
        model_cfg = vllm_config.model_config
        self.model_dtype = model_cfg.dtype if model_cfg else None
        device_cfg = vllm_config.device_config
        self.device = device_cfg.device if device_cfg else None

        self.enabled = enabled
        if enabled:
            self.forward = self.forward_custom
        else:
            self.forward = self.forward_native

    @abstractmethod
    def forward_custom(self, *args: Any, **kwargs: Any) -> Any:
        """Implementation used when the matcher is enabled."""
        pass

    @abstractmethod
    def forward_native(self, *args: Any, **kwargs: Any) -> Any:
        """Implementation used when the matcher is disabled."""
        pass

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Dispatch to the forward implementation selected in ``__init__``."""
        return self.forward(*args, **kwargs)

    def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """``torch.empty`` using the stored model dtype and device."""
        return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kwargs)

    def empty_int64(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """``torch.empty`` with int64 dtype on the stored device."""
        return torch.empty(*args, dtype=torch.int64, device=self.device, **kwargs)

    def empty_f32(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """``torch.empty`` with float32 dtype on the stored device."""
        return torch.empty(*args, dtype=torch.float32, device=self.device, **kwargs)

    def inputs(self) -> list[torch.Tensor]:
        """Utility for inputs to the pattern"""
        raise NotImplementedError
|
||||
|
||||
|
||||
class MatcherRotaryEmbedding(MatcherCustomOp):
    """Matcher for rotary embedding, as a custom op call or the native
    ``RotaryEmbedding.forward_static`` implementation."""

    def __init__(
        self,
        is_neox: bool,
        head_size: int,
        num_heads: int,
        num_kv_heads: int,
        use_flashinfer: bool = False,
        match_rocm_aiter: bool | None = None,
        enabled: bool | None = None,
    ) -> None:
        # Unset flags are resolved from the layer / aiter defaults.
        if enabled is None:
            enabled = RotaryEmbedding.enabled()
        if match_rocm_aiter is None:
            match_rocm_aiter = rocm_aiter_ops.is_triton_rotary_embed_enabled()

        super().__init__(enabled)

        self.is_neox = is_neox
        self.head_size = head_size
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        # Flattened widths of the query and key tensors.
        self.q_size = num_heads * head_size
        self.kv_size = num_kv_heads * head_size
        self.rotary_dim = head_size

        # Default op, overridden by flashinfer first, then by ROCm aiter.
        self.rotary_op = ROTARY_OP
        if use_flashinfer:
            self.rotary_op = FLASHINFER_ROTARY_OP
        elif match_rocm_aiter:
            self.rotary_op = rocm_aiter_ops.get_triton_rotary_embedding_op()

    def inputs(self) -> list[torch.Tensor]:
        """Sample positions, query, key and cos/sin cache tensors."""
        num_tokens = 5
        positions = self.empty_int64(num_tokens)
        query = self.empty(num_tokens, self.q_size)
        key = self.empty(num_tokens, self.kv_size)
        cos_sin_cache = self.empty(4096, self.rotary_dim)
        return [positions, query, key, cos_sin_cache]

    def forward_custom(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None,
        cos_sin_cache: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Call the selected rotary op through ``auto_functionalized``."""
        outputs = auto_functionalized(
            self.rotary_op,
            positions=positions,
            query=query,
            key=key,
            head_size=self.head_size,
            cos_sin_cache=cos_sin_cache,
            is_neox=self.is_neox,
        )
        query_out = outputs[1]
        # A key output exists only when the result has a third element.
        key_out = None
        if len(outputs) > 2:
            key_out = outputs[2]
        return query_out, key_out

    def forward_native(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None,
        cos_sin_cache: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Delegate to ``RotaryEmbedding.forward_static``."""
        return RotaryEmbedding.forward_static(
            positions,
            query,
            key,
            self.head_size,
            self.rotary_dim,
            cos_sin_cache,
            self.is_neox,
        )
|
||||
|
||||
|
||||
class MatcherRMSNorm(MatcherCustomOp):
    """Matcher for RMS norm, as a custom op call (default or ROCm aiter) or
    the native ``RMSNorm.forward_static`` implementation."""

    def __init__(
        self,
        epsilon: float,
        enabled: bool | None = None,
        match_rocm_aiter: bool = False,
    ) -> None:
        if enabled is None:
            enabled = RMSNorm.enabled()

        super().__init__(enabled)
        self.epsilon = epsilon
        self.match_rocm_aiter = match_rocm_aiter
        # The ROCm aiter op replaces the default op when requested.
        self._rmsnorm_op = (
            rocm_aiter_ops.get_rmsnorm_op() if match_rocm_aiter else RMS_OP
        )

    def inputs(self) -> list[torch.Tensor]:
        """Sample input and weight; the input is float32 when disabled."""
        make_input = self.empty if self.enabled else self.empty_f32
        sample_input = make_input(5, 16)
        weight = self.empty(16)
        return [sample_input, weight]

    def forward_rocm_aiter(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
    ) -> torch.Tensor:
        """Call the stored op directly with the aiter keyword names."""
        return self._rmsnorm_op(
            x=input,
            weight=weight,
            variance_epsilon=self.epsilon,
        )

    def forward_custom(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
    ) -> torch.Tensor:
        """Call the RMS norm op, via ``auto_functionalized`` unless aiter."""
        if self.match_rocm_aiter:
            return self.forward_rocm_aiter(input, weight)

        # Buffer handed to the op as its ``result`` argument.
        out_buffer = torch.empty_like(input)
        _, normed = auto_functionalized(
            self._rmsnorm_op,
            result=out_buffer,
            input=input,
            weight=weight,
            epsilon=self.epsilon,
        )
        return normed

    def forward_native(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
    ) -> torch.Tensor:
        """Delegate to ``RMSNorm.forward_static``."""
        last_dim = input.size(-1)
        return RMSNorm.forward_static(
            input, self.epsilon, last_dim, self.model_dtype, weight
        )
|
||||
|
||||
|
||||
class MatcherFusedAddRMSNorm(MatcherCustomOp):
    """Matcher for fused residual-add + RMS norm, as a custom op call
    (default or ROCm aiter) or the native ``RMSNorm.forward_static``."""

    def __init__(
        self,
        epsilon: float,
        enabled: bool | None = None,
        match_rocm_aiter: bool = False,
    ) -> None:
        if enabled is None:
            enabled = RMSNorm.enabled()

        super().__init__(enabled)
        self.epsilon = epsilon
        self.match_rocm_aiter = match_rocm_aiter
        # The ROCm aiter op replaces the default op when requested.
        self._rmsnorm_op = (
            rocm_aiter_ops.get_rmsnorm_fused_add_op()
            if match_rocm_aiter
            else RMS_ADD_OP
        )

    def inputs(self) -> list[torch.Tensor]:
        """Sample input, weight and residual; the input is float32 when
        disabled."""
        make_input = self.empty if self.enabled else self.empty_f32
        sample_input = make_input(5, 16)
        weight = self.empty(16)
        residual = self.empty(5, 16)
        return [sample_input, weight, residual]

    def forward_rocm_aiter(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        residual: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Call the stored op directly with the aiter keyword names."""
        return self._rmsnorm_op(  # type: ignore[no-any-return]
            x=input, residual=residual, weight=weight, variance_epsilon=self.epsilon
        )

    def forward_custom(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        residual: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Call the fused op, via ``auto_functionalized`` unless aiter."""
        if self.match_rocm_aiter:
            return self.forward_rocm_aiter(input, weight, residual)

        _, normed, updated_residual = auto_functionalized(
            self._rmsnorm_op,
            input=input,
            residual=residual,
            weight=weight,
            epsilon=self.epsilon,
        )
        return normed, updated_residual

    def forward_native(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        residual: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Delegate to ``RMSNorm.forward_static`` with the residual."""
        last_dim = input.size(-1)
        return RMSNorm.forward_static(
            input, self.epsilon, last_dim, self.model_dtype, weight, residual
        )
|
||||
|
||||
|
||||
class MatcherQuantFP8(MatcherCustomOp):
    """Matcher for FP8 input quantization.

    The custom path calls the op selected for ``quant_key`` (from
    ``QUANT_OPS`` or, with ``match_rocm_aiter``, from ``rocm_aiter_ops`` /
    the triton group-quant op); the native path calls a ``QuantFP8`` layer
    built from the same key.
    """

    def __init__(
        self,
        quant_key: QuantKey,
        enabled: bool | None = None,
        has_col_major_scales: bool = False,
        is_e8m0: bool = False,
        match_rocm_aiter: bool = False,
        is_tma_aligned: bool = False,
    ) -> None:
        """Select the quant op for ``quant_key`` and build the native layer.

        Args:
            quant_key: Quantization scheme to match.
            enabled: Use the custom op path; defaults to ``QuantFP8.enabled()``.
            has_col_major_scales: Forwarded as ``column_major_scales`` to
                ``QuantFP8`` and used when creating group scales.
            is_e8m0: Forwarded as ``use_ue8m0`` / ``scale_ue8m0``.
            match_rocm_aiter: Select the op from the ROCm aiter helpers.
            is_tma_aligned: Forwarded as ``tma_aligned_scales``; when set, the
                group-quant custom path expects the caller to pass the scale.

        Raises:
            AssertionError: If the scheme is not supported by the chosen path.
        """
        if enabled is None:
            enabled = QuantFP8.enabled()

        super().__init__(enabled)
        self.quant_key = quant_key
        self.has_col_major_scales = has_col_major_scales
        self.is_e8m0 = is_e8m0
        self.match_rocm_aiter = match_rocm_aiter
        self.is_tma_aligned = is_tma_aligned

        if match_rocm_aiter:
            assert not quant_key.scale.group_shape.is_per_tensor(), (
                "ROCm aiter fusion pass does not support per tensor quantization"
            )
            if quant_key.scale.group_shape.is_per_token():
                self.QUANT_OP = rocm_aiter_ops.get_per_token_quant_op()
            else:
                assert quant_key.scale.group_shape.col == 128, (
                    "ROCm aiter fusion pass currently supports "
                    "quantization operation with group_size 128"
                )
                if current_platform.is_fp8_fnuz():
                    self.QUANT_OP = rocm_aiter_ops.get_group_quant_op()
                else:
                    self.QUANT_OP = (
                        torch.ops.vllm.triton_per_token_group_quant_fp8.default
                    )

        else:
            assert quant_key in QUANT_OPS, (
                f"unsupported quantization scheme {quant_key}"
            )
            self.QUANT_OP = QUANT_OPS[quant_key]

            # The original message here was cut off mid-sentence; state the
            # actual requirement and report the offending dtype.
            assert quant_key.dtype == current_platform.fp8_dtype(), (
                "MatcherQuantFP8 only supports the current platform's FP8 "
                f"dtype, got {quant_key.dtype}"
            )
            assert quant_key.scale2 is None

        # Native implementation used by ``forward_native``.
        self.quant_fp8 = QuantFP8(
            quant_key.scale.static,
            quant_key.scale.group_shape,
            column_major_scales=has_col_major_scales,
            use_ue8m0=is_e8m0,
            tma_aligned_scales=self.is_tma_aligned,
            compile_native=False,
        )

    def forward_rocm_aiter(
        self,
        input: torch.Tensor,
        scale: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Call the aiter quant op directly (no ``auto_functionalized``)."""
        quant_key_group_shape = self.quant_key.scale.group_shape
        if quant_key_group_shape == GroupShape.PER_TOKEN:
            return self.QUANT_OP(  # type: ignore[no-any-return]
                x=input,
                quant_dtype=self.quant_key.dtype,
                scale=scale,
            )
        else:
            # Group quantization: the op takes the group width positionally.
            return self.QUANT_OP(input, quant_key_group_shape.col)  # type: ignore[no-any-return]

    def forward_custom(
        self,
        input: torch.Tensor,
        scale: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Quantize ``input`` with the selected custom op.

        Returns:
            ``(quantized, scale)`` as produced by the op (for static
            quantization the passed-in ``scale`` is returned unchanged).
        """
        if self.match_rocm_aiter:
            return self.forward_rocm_aiter(input, scale)

        # Output buffer in the quantized dtype, same shape as the input.
        result = torch.empty(
            input.shape, device=input.device, dtype=self.quant_key.dtype
        )

        if self.quant_key.scale.group_shape.is_per_group():
            # for tma_aligned, the scale must be passed to forward_custom
            # tma_aligned fusion then matches by custom op arguments
            if not self.is_tma_aligned:
                assert scale is None
                scale = self.make_scale(input, transposed=self.has_col_major_scales)

            finfo = torch.finfo(self.quant_key.dtype)
            fp8_min = finfo.min
            fp8_max = finfo.max

            _, result, scale = auto_functionalized(
                self.QUANT_OP,
                input=input,
                output_q=result,
                output_s=scale,
                group_size=self.quant_key.scale.group_shape[1],
                eps=1e-10,
                fp8_min=fp8_min,
                fp8_max=fp8_max,
                scale_ue8m0=self.is_e8m0,
                dummy_is_scale_transposed=self.has_col_major_scales,
                dummy_is_tma_aligned=self.is_tma_aligned,
            )
            return result, scale

        if self.quant_key.scale.static:
            # Static quantization: the caller must supply the scale.
            assert scale is not None
            _, result = auto_functionalized(
                self.QUANT_OP, result=result, input=input, scale=scale
            )
            return result, scale
        else:
            # Dynamic quantization: the scale buffer is created here.
            assert scale is None
            scale = self.make_scale(input)
            _, result, scale = auto_functionalized(
                self.QUANT_OP, result=result, input=input, scale=scale, scale_ub=None
            )
            return result, scale

    def forward_native(
        self,
        input: torch.Tensor,
        scale: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Delegate to the ``QuantFP8`` layer built in ``__init__``."""
        return self.quant_fp8(input, scale)  # type: ignore[no-any-return]

    def make_scale(self, input: torch.Tensor, transposed: bool = False) -> torch.Tensor:
        """Allocate an empty float32 scale tensor for ``input``.

        The shape is ``input``'s two dims divided by the normalized group
        shape. With ``transposed`` the buffer is allocated with the dims
        reversed and returned permuted back via ``permute(-1, -2)``.
        """
        normalized_group_shape = _normalize_quant_group_shape(
            input, self.quant_key.scale.group_shape
        )
        scale_shape = (
            input.shape[0] // normalized_group_shape[0],
            input.shape[1] // normalized_group_shape[1],
        )
        if transposed:
            scale_shape = tuple(reversed(scale_shape))
            return torch.empty(
                scale_shape, device=input.device, dtype=torch.float32
            ).permute(-1, -2)

        return torch.empty(scale_shape, device=input.device, dtype=torch.float32)

    def inputs(self) -> list[torch.Tensor]:
        """Sample inputs: the input tensor, plus a 1x1 float32 scale when the
        scheme is static."""
        input = self.empty(5, 16)
        if self.quant_key.scale.static:
            return [input, self.empty_f32(1, 1)]

        return [input]
|
||||
|
||||
|
||||
class MatcherSiluAndMul(MatcherCustomOp):
    """Matcher for SiLU-and-mul, as the ``SILU_MUL_OP`` custom op call or the
    native ``SiluAndMul.forward_native`` implementation."""

    def __init__(self, enabled: bool | None = None) -> None:
        if enabled is None:
            enabled = SiluAndMul.enabled()
        super().__init__(enabled)

    def inputs(self) -> list[torch.Tensor]:
        """Single 5x4 sample input."""
        return [self.empty(5, 4)]

    def forward_custom(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Call the custom op through ``auto_functionalized``."""
        # The output keeps the leading dims and halves the last one.
        half_width = x.shape[-1] // 2
        out_shape = (*x.shape[:-1], half_width)
        out_buffer = torch.empty(out_shape, dtype=x.dtype, device=x.device)
        functionalized = auto_functionalized(SILU_MUL_OP, result=out_buffer, input=x)
        return functionalized[1]

    def forward_native(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Delegate to ``SiluAndMul.forward_native``."""
        return SiluAndMul.forward_native(x)
|
||||
244
vllm/compilation/passes/fusion/qk_norm_rope_fusion.py
Normal file
244
vllm/compilation/passes/fusion/qk_norm_rope_fusion.py
Normal file
@@ -0,0 +1,244 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import ParamSpec
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
from torch import fx
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .matcher_utils import MatcherRMSNorm, MatcherRotaryEmbedding
|
||||
from .rms_quant_fusion import empty_bf16, empty_fp32, empty_i64
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
FUSED_QK_ROPE_OP = torch.ops._C.fused_qk_norm_rope.default
|
||||
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
class QkNormRopePattern:
|
||||
"""
|
||||
Match the unfused sequence in attention blocks and replace with the fused op.
|
||||
|
||||
Unfused (conceptually):
|
||||
q, k, v = split(qkv, [qsz, kvsz, kvsz], -1)
|
||||
qh = reshape(q, [-1, num_heads, head_dim])
|
||||
kh = reshape(k, [-1, num_kv_heads, head_dim])
|
||||
qn = rms_norm(qh, q_weight, eps)
|
||||
kn = rms_norm(kh, k_weight, eps)
|
||||
qf = reshape(qn, [-1, num_heads * head_dim])
|
||||
kf = reshape(kn, [-1, num_kv_heads * head_dim])
|
||||
qf, kf = rotary_embedding(positions, qf, kf, head_dim, cos_sin_cache, is_neox)
|
||||
return qf, kf, v
|
||||
|
||||
Fused replacement:
|
||||
fused_qk_norm_rope(qkv, num_heads, num_kv_heads, num_kv_heads, head_dim,
|
||||
eps, q_weight, k_weight, cos_sin_cache, is_neox,
|
||||
positions.view(-1))
|
||||
return split(qkv, [qsz, kvsz, kvsz], -1)
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    head_dim: int,
    num_heads: int,
    num_kv_heads: int,
    eps: float,
    is_neox: bool,
    rope_flashinfer: bool = False,
) -> None:
    """Store the attention geometry and build the RMS norm / RoPE matchers."""
    self.head_dim = head_dim
    self.num_heads = num_heads
    self.num_kv_heads = num_kv_heads
    # Flattened widths of the q and k/v slices of the qkv tensor.
    self.q_size = num_heads * head_dim
    self.kv_size = num_kv_heads * head_dim

    self.eps = eps
    self.is_neox = is_neox
    self.rope_flashinfer = rope_flashinfer

    self.rmsnorm_matcher = MatcherRMSNorm(eps)
    self.rope_matcher = MatcherRotaryEmbedding(
        is_neox=is_neox,
        head_size=head_dim,
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        use_flashinfer=rope_flashinfer,
    )
|
||||
|
||||
def get_inputs(self) -> list[torch.Tensor]:
    """Build sample tensors (qkv, positions, q/k weights, cos/sin cache) to
    help pattern tracing."""
    num_tokens = 5
    qkv = empty_bf16(num_tokens, self.q_size + 2 * self.kv_size)
    positions = empty_i64(num_tokens)
    q_weight = empty_bf16(1, self.head_dim)
    k_weight = empty_bf16(1, self.head_dim)

    # The flashinfer variant is traced with an fp32 cache, otherwise bf16.
    make_cache = empty_fp32 if self.rope_flashinfer else empty_bf16
    cos_sin_cache = make_cache(4096, self.head_dim)

    return [qkv, positions, q_weight, k_weight, cos_sin_cache]
|
||||
|
||||
@staticmethod
def wrap_trace_fn(
    trace_fn: Callable[P, fx.GraphModule],
    *process_fx_fns: Callable[[fx.GraphModule], None],
) -> Callable[P, fx.GraphModule]:
    """Return a tracer that runs ``trace_fn`` and then applies each function
    in ``process_fx_fns``, in order, to the resulting graph module."""

    def wrapped(*args: P.args, **kwargs: P.kwargs) -> fx.GraphModule:
        traced = trace_fn(*args, **kwargs)
        # The post-processing callbacks are invoked for their side effects.
        for post_process in process_fx_fns:
            post_process(traced)
        return traced

    return wrapped
|
||||
|
||||
@staticmethod
|
||||
def fx_view_to_reshape(gm: torch.fx.GraphModule) -> None:
|
||||
from torch._inductor.fx_passes.post_grad import view_to_reshape
|
||||
|
||||
view_to_reshape(gm)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
    """Register the unfused -> fused QK-norm + RoPE rewrite on ``pm_pass``.

    ``pattern`` spells out the unfused computation (split qkv, per-head
    RMSNorm on Q and K, rotary embedding on the flattened Q/K);
    ``replacement`` is a single call to the fused custom op followed by the
    same split. Both are passed to ``pm.register_replacement`` together with
    the sample inputs from ``get_inputs``.
    """
    # Split sizes shared by the pattern and the replacement.
    split_sizes = [self.q_size, self.kv_size, self.kv_size]

    def pattern(
        qkv: torch.Tensor,
        positions: torch.Tensor,
        q_weight: torch.Tensor,
        k_weight: torch.Tensor,
        cos_sin_cache: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        def norm_per_head(flat: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
            # view to (..., heads, head_dim) -> RMS -> view back to flat.shape
            by_head = flat.view(
                *flat.shape[:-1], flat.shape[-1] // self.head_dim, self.head_dim
            )
            normed_by_head = self.rmsnorm_matcher(by_head, weight)
            return normed_by_head.view(flat.shape)

        # split qkv -> q, k, v
        q, k, v = qkv.split(split_sizes, dim=-1)

        # Q path first, then K path (same op order as the unfused model code)
        q_flat = norm_per_head(q, q_weight)
        k_flat = norm_per_head(k, k_weight)

        # RoPE is applied to the flattened q/k; v passes through untouched
        q_rope, k_rope = self.rope_matcher(positions, q_flat, k_flat, cos_sin_cache)
        return q_rope, k_rope, v

    def replacement(
        qkv: torch.Tensor,
        positions: torch.Tensor,
        q_weight: torch.Tensor,
        k_weight: torch.Tensor,
        cos_sin_cache: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Single fused qk_norm_rope call, wrapped by auto_functionalized
        fused_out = auto_functionalized(
            FUSED_QK_ROPE_OP,
            qkv=qkv,
            num_heads_q=self.num_heads,
            num_heads_k=self.num_kv_heads,
            num_heads_v=self.num_kv_heads,
            head_dim=self.head_dim,
            eps=self.eps,
            q_weight=q_weight,
            k_weight=k_weight,
            cos_sin_cache=cos_sin_cache,
            is_neox=self.is_neox,
            position_ids=positions.view(-1),
        )
        # Element 1 of the result is used as the updated qkv tensor
        fused_qkv = fused_out[1]

        # Split back to q, k, v and return
        return fused_qkv.split(split_sizes, dim=-1)  # type: ignore[no-any-return]

    # NOTE: fx_view_to_reshape is applied to the traced graph to unify
    # view/reshape, which simplifies the pattern and increases matching
    # opportunities.
    trace_fn = QkNormRopePattern.wrap_trace_fn(
        pm.fwd_only,
        QkNormRopePattern.fx_view_to_reshape,
    )
    pm.register_replacement(
        pattern,
        replacement,
        self.get_inputs(),
        trace_fn,
        pm_pass,
    )
|
||||
|
||||
|
||||
class QKNormRoPEFusionPass(VllmPatternMatcherPass):
    """Fuse Q/K RMSNorm + RoPE into fused_qk_norm_rope when the custom op exists."""

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        """Register one ``QkNormRopePattern`` per supported option combination.

        Registration is skipped (with a one-time warning) when the model
        dtype is neither bfloat16 nor float16, or when the config yields no
        ``Attention`` layers.
        """
        super().__init__(config)
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="qk_norm_rope_fusion_pass"
        )

        model_dtype = config.model_config.dtype
        if model_dtype not in (torch.bfloat16, torch.float16):
            logger.warning_once(
                "QK Norm+RoPE fusion not enabled: unsupported dtype %s", model_dtype
            )
            return

        # One attention layer supplies the metadata (head_dim, head counts)
        # used for every registered pattern.
        attn_layers: dict[str, Attention] = get_layers_from_vllm_config(
            config, Attention
        )
        if not attn_layers:
            logger.warning_once(
                "QK Norm+RoPE fusion enabled, but no Attention layers were discovered."
            )
            return
        first_layer = next(iter(attn_layers.values()))
        geometry = {
            "head_dim": first_layer.head_size,
            "num_heads": first_layer.num_heads,
            "num_kv_heads": first_layer.num_kv_heads,
        }

        for epsilon in (1e-5, 1e-6):
            for neox in (True, False):
                if RotaryEmbedding.enabled():
                    # Register both rope_flashinfer variants.
                    extra_options = [
                        {"rope_flashinfer": False},
                        {"rope_flashinfer": True},
                    ]
                else:
                    # Single registration relying on the constructor default.
                    extra_options = [{}]
                for extra in extra_options:
                    QkNormRopePattern(
                        eps=epsilon,
                        is_neox=neox,
                        **geometry,
                        **extra,
                    ).register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Apply the registered patterns to ``graph`` and record the match count."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Fused QK Norm+RoPE on %s sites", self.matched_count)

    def uuid(self) -> str:
        """Return the source hash of this pass together with ``QkNormRopePattern``."""
        return VllmInductorPass.hash_source(self, QkNormRopePattern)
|
||||
643
vllm/compilation/passes/fusion/rms_quant_fusion.py
Normal file
643
vllm/compilation/passes/fusion/rms_quant_fusion.py
Normal file
@@ -0,0 +1,643 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
from torch import fx
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
from torch._ops import OpOverload
|
||||
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
QuantKey,
|
||||
ScaleDesc,
|
||||
kFp8Dynamic64Sym,
|
||||
kFp8Dynamic128Sym,
|
||||
kFp8DynamicTensorSym,
|
||||
kFp8DynamicTokenSym,
|
||||
kFp8StaticTensorSym,
|
||||
kNvfp4Dynamic,
|
||||
kStaticTensorScale,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .matcher_utils import (
|
||||
MatcherFusedAddRMSNorm,
|
||||
MatcherQuantFP8,
|
||||
MatcherRMSNorm,
|
||||
)
|
||||
|
||||
# Module-level logger.
logger = init_logger(__name__)

# FP8 dtype as reported by the current platform.
FP8_DTYPE = current_platform.fp8_dtype()
# FP4 dtype constant, represented here as uint8.
FP4_DTYPE = torch.uint8
|
||||
|
||||
|
||||
def empty_bf16(*args: Any, **kwargs: Any) -> torch.Tensor:
|
||||
return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")
|
||||
|
||||
|
||||
def empty_fp32(*args: Any, **kwargs: Any) -> torch.Tensor:
|
||||
return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")
|
||||
|
||||
|
||||
def empty_i32(*args: Any, **kwargs: Any) -> torch.Tensor:
|
||||
return torch.empty(*args, **kwargs, dtype=torch.int32, device="cuda")
|
||||
|
||||
|
||||
def empty_i64(*args: Any, **kwargs: Any) -> torch.Tensor:
|
||||
return torch.empty(*args, **kwargs, dtype=torch.int64, device="cuda")
|
||||
|
||||
|
||||
# Unfused RMSNorm custom ops (without / with the residual add).
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default

# Stand-alone quantization op for each supported quantization scheme.
QUANT_OPS: dict[QuantKey, OpOverload] = {
    kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default,
    kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default,
    kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default,  # noqa: E501
}
if current_platform.is_cuda():
    # The FP4 op is only added when the compiled extension exposes it.
    if hasattr(torch.ops._C, "scaled_fp4_quant"):
        QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default
    # Both group sizes (128 and 64) map to the same per-token-group op.
    QUANT_OPS.update(
        {
            kFp8Dynamic128Sym: torch.ops._C.per_token_group_fp8_quant.default,
            kFp8Dynamic64Sym: torch.ops._C.per_token_group_fp8_quant.default,
        }
    )
|
||||
|
||||
|
||||
class FusedRMSQuantKey(NamedTuple):
    """
    Key identifying one flavour of RMSNorm + quant fusion.

    quant: the quantization scheme
    fused_add: whether the op also performs the residual add
    """

    quant: QuantKey
    fused_add: bool

    def __str__(self) -> str:
        # Renders "... with residual)" or "... without residual)".
        residual_infix = "" if self.fused_add else "out"
        return f"FusedQuantKey({self.quant}, with{residual_infix} residual)"
|
||||
|
||||
|
||||
# Fused RMSNorm + quant op for each (quant scheme, residual add) combination.
FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
    # Static per-tensor FP8: separate ops without / with the residual add.
    FusedRMSQuantKey(
        quant=kFp8StaticTensorSym, fused_add=False
    ): torch.ops._C.rms_norm_static_fp8_quant.default,
    FusedRMSQuantKey(
        quant=kFp8StaticTensorSym, fused_add=True
    ): torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,
    # Dynamic per-token FP8: the same op is used for both residual variants.
    FusedRMSQuantKey(
        quant=kFp8DynamicTokenSym, fused_add=False
    ): torch.ops._C.rms_norm_dynamic_per_token_quant.default,
    FusedRMSQuantKey(
        quant=kFp8DynamicTokenSym, fused_add=True
    ): torch.ops._C.rms_norm_dynamic_per_token_quant.default,
    # Group FP8 (group size 128): the same op for both residual variants.
    FusedRMSQuantKey(
        quant=kFp8Dynamic128Sym, fused_add=False
    ): torch.ops._C.rms_norm_per_block_quant.default,
    FusedRMSQuantKey(
        quant=kFp8Dynamic128Sym, fused_add=True
    ): torch.ops._C.rms_norm_per_block_quant.default,
    # Group FP8 (group size 64): the same op for both residual variants.
    FusedRMSQuantKey(
        quant=kFp8Dynamic64Sym, fused_add=False
    ): torch.ops._C.rms_norm_per_block_quant.default,
    FusedRMSQuantKey(
        quant=kFp8Dynamic64Sym, fused_add=True
    ): torch.ops._C.rms_norm_per_block_quant.default,
}
|
||||
|
||||
|
||||
class RMSNormQuantPattern:
    """Shared state for the RMSNorm + quant fusion patterns.

    Resolves the fused op for ``key`` from ``FUSED_OPS`` and builds the
    RMSNorm and FP8-quant matchers that subclasses use in their patterns.
    """

    def __init__(
        self,
        epsilon: float,
        key: FusedRMSQuantKey,
        has_col_major_scales: bool = False,
        is_e8m0: bool = False,
        is_tma_aligned: bool = False,
    ) -> None:
        """
        Args:
            epsilon: RMSNorm epsilon the pattern is specialised for.
            key: Selects the fused op and the RMSNorm matcher variant.
            has_col_major_scales: Forwarded to ``MatcherQuantFP8``.
            is_e8m0: Forwarded to ``MatcherQuantFP8``.
            is_tma_aligned: Forwarded to ``MatcherQuantFP8``.

        Raises:
            AssertionError: If ``key`` has no entry in ``FUSED_OPS``.
        """
        self.epsilon = epsilon
        self.quant_dtype = key.quant.dtype

        # The model dtype is None when the current config has no model_config.
        vllm_config = get_current_vllm_config()
        if vllm_config.model_config:
            self.model_dtype = vllm_config.model_config.dtype
        else:
            self.model_dtype = None

        assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
        self.FUSED_OP = FUSED_OPS[key]

        # The residual-add variant needs the fused-add RMSNorm matcher.
        rmsnorm_matcher_cls = (
            MatcherFusedAddRMSNorm if key.fused_add else MatcherRMSNorm
        )
        self.rmsnorm_matcher = rmsnorm_matcher_cls(epsilon)
        self.quant_matcher = MatcherQuantFP8(
            key.quant,
            has_col_major_scales=has_col_major_scales,
            is_e8m0=is_e8m0,
            is_tma_aligned=is_tma_aligned,
        )
|
||||
|
||||
|
||||
class RMSNormStaticQuantPattern(RMSNormQuantPattern):
    """rms_norm followed by static per-tensor quant -> one fused op."""

    def __init__(
        self, epsilon: float, quant_dtype: torch.dtype, symmetric: bool = True
    ) -> None:
        quant_key = QuantKey(
            dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
        )
        super().__init__(epsilon, FusedRMSQuantKey(quant=quant_key, fused_add=False))

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern / replacement pair on ``pm_pass``."""

        # Plain closures are used instead of methods, as a ``self`` argument
        # would affect tracing.
        def pattern(
            input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
        ) -> torch.Tensor:
            normed = self.rmsnorm_matcher(input, weight)
            quant_out = self.quant_matcher(normed, scale)
            return quant_out[0]

        def replacement(
            input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
        ) -> torch.Tensor:
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            # Output buffer written by the fused op.
            out_buffer = torch.empty(
                input.shape, device=input.device, dtype=self.quant_dtype
            )
            fused_out = auto_functionalized(
                self.FUSED_OP,
                result=out_buffer,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
            )

            # Element 1 is the result.
            return fused_out[1]

        example_inputs = [
            # input, weight
            *self.rmsnorm_matcher.inputs(),
            self.quant_matcher.inputs()[1],  # scale
        ]
        # The pattern is run once on the example inputs before registration.
        pattern(*example_inputs)

        pm.register_replacement(
            pattern, replacement, example_inputs, pm.fwd_only, pm_pass
        )
|
||||
|
||||
|
||||
class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
    """fused_add_rms_norm followed by static per-tensor quant -> one fused op."""

    def __init__(
        self, epsilon: float, quant_dtype: torch.dtype, symmetric: bool = True
    ) -> None:
        quant_key = QuantKey(
            dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
        )
        super().__init__(epsilon, FusedRMSQuantKey(quant=quant_key, fused_add=True))

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern / replacement pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            normed, residual_out = self.rmsnorm_matcher(input, weight, residual)
            quantized, _ = self.quant_matcher(normed, scale)
            return quantized, residual_out

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            out_buffer = torch.empty_like(input, dtype=self.quant_dtype)
            fused_out = auto_functionalized(
                self.FUSED_OP,
                result=out_buffer,
                input=input,
                residual=residual,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
            )

            # Elements 1 and 2 are result and residual.
            return fused_out[1], fused_out[2]

        example_inputs = [
            # input, weight, residual
            *self.rmsnorm_matcher.inputs(),
            self.quant_matcher.inputs()[1],  # scale
        ]

        pm.register_replacement(
            pattern, replacement, example_inputs, pm.fwd_only, pm_pass
        )
|
||||
|
||||
|
||||
class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
    """fused_add_rms_norm followed by FP8 group quant -> one fused op."""

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape,
        symmetric: bool = True,
        is_e8m0: bool = False,
        has_col_major_scales: bool = True,
        is_tma_aligned: bool = True,
    ) -> None:
        # Options read back by the pattern / replacement closures.
        self.group_shape = group_shape
        self.is_e8m0 = is_e8m0
        self.has_col_major_scales = has_col_major_scales
        self.is_tma_aligned = is_tma_aligned

        scale_desc = ScaleDesc(torch.float32, False, group_shape)
        quant_key = QuantKey(dtype=quant_dtype, scale=scale_desc, symmetric=symmetric)
        super().__init__(
            epsilon,
            FusedRMSQuantKey(quant=quant_key, fused_add=True),
            has_col_major_scales=has_col_major_scales,
            is_e8m0=is_e8m0,
            is_tma_aligned=is_tma_aligned,
        )

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern / replacement pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            matcher_key = self.quant_matcher.quant_key

            normed, residual_out = self.rmsnorm_matcher(input, weight, residual)
            out_buffer = torch.empty(
                normed.shape,
                device=normed.device,
                dtype=matcher_key.dtype,
            )
            assert scale is not None
            # Representable range of the quantized dtype.
            dtype_info = torch.finfo(matcher_key.dtype)

            _, quantized, scale_out = auto_functionalized(
                self.quant_matcher.QUANT_OP,
                input=normed,
                output_q=out_buffer,
                output_s=scale,
                group_size=matcher_key.scale.group_shape[1],
                eps=1e-10,
                fp8_min=dtype_info.min,
                fp8_max=dtype_info.max,
                scale_ue8m0=self.quant_matcher.is_e8m0,
                dummy_is_scale_transposed=self.has_col_major_scales,
                dummy_is_tma_aligned=self.is_tma_aligned,
            )

            return quantized, residual_out, scale_out

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            out_buffer = torch.empty_like(input, dtype=self.quant_dtype)

            fused_out = auto_functionalized(
                self.FUSED_OP,
                result=out_buffer,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=residual,
                group_size=self.group_shape[1],
                is_scale_transposed=self.has_col_major_scales,
            )

            # Returned as (result, residual, scale): elements 1, 3 and 2.
            return fused_out[1], fused_out[3], fused_out[2]

        sample_scale = self.quant_matcher.empty_f32(1, 1)
        example_inputs = self.rmsnorm_matcher.inputs() + [sample_scale]

        pm.register_replacement(
            pattern, replacement, example_inputs, pm.fwd_only, pm_pass
        )
|
||||
|
||||
|
||||
class RMSNormGroupQuantPattern(RMSNormQuantPattern):
    """rms_norm followed by FP8 group quant -> one fused op."""

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape,
        symmetric: bool = True,
        is_e8m0: bool = False,
        has_col_major_scales: bool = True,
        is_tma_aligned: bool = True,
    ) -> None:
        # Options read back by the pattern / replacement closures.
        self.group_shape = group_shape
        self.has_col_major_scales = has_col_major_scales
        self.is_tma_aligned = is_tma_aligned

        scale_desc = ScaleDesc(torch.float32, False, group_shape)
        quant_key = QuantKey(dtype=quant_dtype, scale=scale_desc, symmetric=symmetric)
        super().__init__(
            epsilon,
            FusedRMSQuantKey(quant=quant_key, fused_add=False),
            has_col_major_scales=self.has_col_major_scales,
            is_e8m0=is_e8m0,
            is_tma_aligned=is_tma_aligned,
        )

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern / replacement pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            matcher_key = self.quant_matcher.quant_key

            normed = self.rmsnorm_matcher(input, weight)
            out_buffer = torch.empty(
                normed.shape,
                device=normed.device,
                dtype=matcher_key.dtype,
            )
            assert scale is not None
            # Representable range of the quantized dtype.
            dtype_info = torch.finfo(matcher_key.dtype)

            _, quantized, scale_out = auto_functionalized(
                self.quant_matcher.QUANT_OP,
                input=normed,
                output_q=out_buffer,
                output_s=scale,
                group_size=matcher_key.scale.group_shape[1],
                eps=1e-10,
                fp8_min=dtype_info.min,
                fp8_max=dtype_info.max,
                scale_ue8m0=self.quant_matcher.is_e8m0,
                dummy_is_scale_transposed=self.has_col_major_scales,
                dummy_is_tma_aligned=self.is_tma_aligned,
            )

            return quantized, scale_out

        def replacement(
            input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            out_buffer = torch.empty_like(input, dtype=self.quant_dtype)
            fused_out = auto_functionalized(
                self.FUSED_OP,
                result=out_buffer,
                input=input,
                weight=weight,
                scale=scale,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=None,
                group_size=self.group_shape[1],
                is_scale_transposed=self.has_col_major_scales,
            )

            # Elements 1 and 2 are result and scale.
            return fused_out[1], fused_out[2]

        sample_scale = self.quant_matcher.empty_f32(1, 1)
        example_inputs = self.rmsnorm_matcher.inputs() + [sample_scale]

        pm.register_replacement(
            pattern, replacement, example_inputs, pm.fwd_only, pm_pass
        )
|
||||
|
||||
|
||||
class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
    """rms_norm followed by dynamic quant (per-token by default) -> one fused op."""

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape = GroupShape.PER_TOKEN,
        symmetric: bool = True,
    ) -> None:
        scale_desc = ScaleDesc(torch.float32, False, group_shape)
        quant_key = QuantKey(dtype=quant_dtype, scale=scale_desc, symmetric=symmetric)
        super().__init__(epsilon, FusedRMSQuantKey(quant=quant_key, fused_add=False))

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern / replacement pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            normed = self.rmsnorm_matcher(input, weight)
            # The quant matcher yields (result, scale).
            return self.quant_matcher(normed)  # type: ignore[no-any-return]

        def replacement(
            input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            out_buffer = torch.empty_like(input, dtype=self.quant_dtype)
            scale_buffer = self.quant_matcher.make_scale(input)
            fused_out = auto_functionalized(
                self.FUSED_OP,
                result=out_buffer,
                input=input,
                weight=weight,
                scale=scale_buffer,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=None,
            )

            # Elements 1 and 2 are result and scale.
            return fused_out[1], fused_out[2]

        example_inputs = self.rmsnorm_matcher.inputs()
        pm.register_replacement(
            pattern, replacement, example_inputs, pm.fwd_only, pm_pass
        )
|
||||
|
||||
|
||||
class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
    """fused_add_rms_norm followed by dynamic quant (per-token by default) -> one fused op."""

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape = GroupShape.PER_TOKEN,
        symmetric: bool = True,
    ) -> None:
        scale_desc = ScaleDesc(torch.float32, False, group_shape)
        quant_key = QuantKey(dtype=quant_dtype, scale=scale_desc, symmetric=symmetric)
        super().__init__(epsilon, FusedRMSQuantKey(quant=quant_key, fused_add=True))

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern / replacement pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            normed, residual_out = self.rmsnorm_matcher(input, weight, residual)
            quantized, scale_out = self.quant_matcher(normed)
            return quantized, residual_out, scale_out

        def replacement(
            input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            # In case we're matching native rms-norm, conversions might be
            # optimized out. We convert here just to be safe.
            input = input.to(dtype=self.model_dtype)

            out_buffer = torch.empty_like(input, dtype=self.quant_dtype)
            scale_buffer = self.quant_matcher.make_scale(input)
            fused_out = auto_functionalized(
                self.FUSED_OP,
                result=out_buffer,
                input=input,
                weight=weight,
                scale=scale_buffer,
                epsilon=self.epsilon,
                scale_ub=None,
                residual=residual,
            )

            # Returned as (result, residual, scale): elements 1, 3 and 2.
            return fused_out[1], fused_out[3], fused_out[2]

        example_inputs = self.rmsnorm_matcher.inputs()
        pm.register_replacement(
            pattern, replacement, example_inputs, pm.fwd_only, pm_pass
        )
|
||||
|
||||
|
||||
class RMSNormQuantFusionPass(VllmPatternMatcherPass):
    """
    This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op.
    It also supports fused_add_rms_norm.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        """Register every RMSNorm + quant pattern for each supported epsilon."""
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rmsnorm_quant_fusion_pass"
        )

        # All group-quant option combinations, in nested-loop order:
        # group_shape (outermost), has_col_major_scales, is_e8m0,
        # is_tma_aligned (innermost).
        group_quant_variants = [
            (group_shape, col_major, e8m0, tma_aligned)
            for group_shape in (GroupShape(1, 128), GroupShape(1, 64))
            for col_major in (True, False)
            for e8m0 in (True, False)
            for tma_aligned in (False, True)
        ]

        # Make sure fused add patterns are before simple rms norm,
        # as the latter is a subset of the former in torch ops
        for epsilon in (1e-5, 1e-6):
            # fused_add_rms_norm + static fp8 quant, then rms_norm + static
            # fp8 quant
            FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
                self.patterns
            )
            RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)

            # fused_add_rms_norm + dynamic per-token fp8 quant, then
            # rms_norm + dynamic per-token fp8 quant
            FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
                self.patterns
            )
            RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)

            # Only register group quant patterns on CUDA where the C++ op exists
            if not current_platform.is_cuda():
                continue
            for group_shape, col_major, e8m0, tma_aligned in group_quant_variants:
                # fused-add variant first, then the plain rms_norm variant
                for pattern_cls in (
                    FusedAddRMSNormGroupQuantPattern,
                    RMSNormGroupQuantPattern,
                ):
                    pattern_cls(
                        epsilon,
                        FP8_DTYPE,
                        group_shape=group_shape,
                        is_e8m0=e8m0,
                        has_col_major_scales=col_major,
                        is_tma_aligned=tma_aligned,
                    ).register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Apply the registered patterns to ``graph`` and record the match count."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def uuid(self) -> str:
        """Return the source hash of this pass together with all its pattern classes."""
        hashed_sources = (
            self,
            RMSNormGroupQuantPattern,
            RMSNormQuantPattern,
            RMSNormStaticQuantPattern,
            RMSNormDynamicQuantPattern,
            FusedAddRMSNormStaticQuantPattern,
            FusedAddRMSNormDynamicQuantPattern,
            FusedAddRMSNormGroupQuantPattern,
        )
        return self.hash_source(*hashed_sources)
|
||||
504
vllm/compilation/passes/fusion/rocm_aiter_fusion.py
Normal file
504
vllm/compilation/passes/fusion/rocm_aiter_fusion.py
Normal file
@@ -0,0 +1,504 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
from torch import fx
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
from torch._ops import OpOverload
|
||||
|
||||
import vllm.model_executor.layers.quantization.utils.fp8_utils # noqa: F401
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
QuantKey,
|
||||
ScaleDesc,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .act_quant_fusion import ActivationQuantPattern
|
||||
from .matcher_utils import (
|
||||
MatcherFusedAddRMSNorm,
|
||||
MatcherQuantFP8,
|
||||
MatcherRMSNorm,
|
||||
MatcherSiluAndMul,
|
||||
)
|
||||
from .rms_quant_fusion import (
|
||||
FusedRMSQuantKey,
|
||||
)
|
||||
|
||||
# Module-level logger.
logger = init_logger(__name__)

# FP8 dtype as reported by the current platform.
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
|
||||
class AiterRMSNormQuantPattern:
    """Shared state for the AITER RMSNorm + quant fusion patterns.

    Builds the RMSNorm matcher (with ``match_rocm_aiter=True``) and the
    FP8-quant matcher that subclasses use in their patterns.
    """

    def __init__(
        self, epsilon: float, key: FusedRMSQuantKey, match_aiter_quant: bool = True
    ):
        """
        Args:
            epsilon: RMSNorm epsilon the pattern is specialised for.
            key: Selects the quant scheme and the RMSNorm matcher variant.
            match_aiter_quant: Forwarded to ``MatcherQuantFP8`` as
                ``match_rocm_aiter``.
        """
        self.epsilon = epsilon
        self.quant_dtype = key.quant.dtype

        # The residual-add variant needs the fused-add RMSNorm matcher.
        rmsnorm_matcher_cls = (
            MatcherFusedAddRMSNorm if key.fused_add else MatcherRMSNorm
        )
        self.rmsnorm_matcher = rmsnorm_matcher_cls(epsilon, match_rocm_aiter=True)
        self.quant_matcher = MatcherQuantFP8(
            key.quant,
            match_rocm_aiter=match_aiter_quant,
        )
|
||||
|
||||
|
||||
class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
    """AITER RMSNorm + Dynamic Quantization pattern."""

    # Fused op used by the replacement; resolved once at class creation.
    FUSED_OP = rocm_aiter_ops.get_rmsnorm_fused_dynamic_quant_op()

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        match_aiter_quant: bool = True,
        group_shape: GroupShape = GroupShape.PER_TOKEN,
        symmetric: bool = True,
    ) -> None:
        scale_desc = ScaleDesc(torch.float32, False, group_shape)
        quant_key = QuantKey(dtype=quant_dtype, scale=scale_desc, symmetric=symmetric)
        super().__init__(
            epsilon,
            FusedRMSQuantKey(quant=quant_key, fused_add=False),
            match_aiter_quant,
        )

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern / replacement pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            normed = self.rmsnorm_matcher(input, weight)
            quantized, scale_out = self.quant_matcher(normed)
            return quantized, scale_out

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            fused_out = self.FUSED_OP(
                x=input,
                weight=weight,
                epsilon=self.epsilon,
                quant_dtype=self.quant_dtype,
            )

            # Same (result, scale) layout as the pattern's return value.
            return fused_out[0], fused_out[1]

        example_inputs = self.rmsnorm_matcher.inputs()
        pm.register_replacement(
            pattern, replacement, example_inputs, pm.fwd_only, pm_pass
        )
|
||||
|
||||
|
||||
class AiterFusedAddRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
    """AITER RMSNorm Fused Add + Dynamic Quantization pattern."""

    # Fused op used by the replacement; resolved once at class creation.
    FUSED_OP = rocm_aiter_ops.get_rmsnorm_fused_add_dynamic_quant_op()

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        match_aiter_quant: bool = True,
        group_shape: GroupShape = GroupShape.PER_TOKEN,
        symmetric: bool = True,
    ) -> None:
        scale_desc = ScaleDesc(torch.float32, False, group_shape)
        quant_key = QuantKey(dtype=quant_dtype, scale=scale_desc, symmetric=symmetric)
        super().__init__(
            epsilon,
            FusedRMSQuantKey(quant=quant_key, fused_add=True),
            match_aiter_quant,
        )

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern / replacement pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            normed, residual_out = self.rmsnorm_matcher(input, weight, residual)
            quantized, scale_out = self.quant_matcher(normed)
            return quantized, residual_out, scale_out

        def replacement(
            input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            fused_out = self.FUSED_OP(
                x=input,
                residual=residual,
                weight=weight,
                epsilon=self.epsilon,
                quant_dtype=self.quant_dtype,
            )

            # Same (result, residual, scale) layout as the pattern's return.
            return fused_out[0], fused_out[1], fused_out[2]

        example_inputs = self.rmsnorm_matcher.inputs()
        pm.register_replacement(
            pattern, replacement, example_inputs, pm.fwd_only, pm_pass
        )
|
||||
|
||||
|
||||
class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
    """
    This pattern fuses aiter rms_norm & group fp8 quant custom
    ops into an aiter rms_norm_group_fp8_quant op.
    """

    # Fused op used by the replacement; resolved once at class creation.
    FUSED_OP = rocm_aiter_ops.get_rmsnorm_group_fused_quant_op()

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape,
        match_aiter_quant: bool = True,
        symmetric: bool = True,
    ) -> None:
        scale_desc = ScaleDesc(torch.float32, False, group_shape)
        quant_key = QuantKey(dtype=quant_dtype, scale=scale_desc, symmetric=symmetric)
        super().__init__(
            epsilon,
            FusedRMSQuantKey(quant=quant_key, fused_add=False),
            match_aiter_quant,
        )

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the pattern / replacement pair on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            normed = self.rmsnorm_matcher(input, weight)
            quantized, scale_out = self.quant_matcher(normed)
            return quantized, scale_out

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # NOTE(review): group_size is fixed at 128 here regardless of the
            # group_shape given to the constructor -- confirm that this
            # pattern is only ever registered with a group size of 128.
            fused_out = self.FUSED_OP(
                x=input,
                weight=weight,
                variance_epsilon=self.epsilon,
                group_size=128,
            )

            # Same (result, scale) layout as the pattern's return value.
            return fused_out[0], fused_out[1]

        example_inputs = self.rmsnorm_matcher.inputs()
        pm.register_replacement(
            pattern, replacement, example_inputs, pm.fwd_only, pm_pass
        )
|
||||
|
||||
|
||||
class AiterFusedAddRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
    """
    Fuses the aiter rms_norm_with_add custom op followed by a group fp8
    quant custom op into a single aiter rms_norm_with_add_group_fp8_quant op.
    """

    FUSED_OP = rocm_aiter_ops.get_rmsnorm_group_add_fused_quant_op()

    def __init__(
        self,
        epsilon: float,
        quant_dtype: torch.dtype,
        group_shape: GroupShape,
        match_aiter_quant: bool = True,
        symmetric: bool = True,
    ) -> None:
        """
        Build the fusion key (fused add enabled, float32 scales described
        by ``group_shape``) and hand it to the base pattern class.
        """
        scale_desc = ScaleDesc(torch.float32, False, group_shape)
        quant_key = QuantKey(dtype=quant_dtype, scale=scale_desc, symmetric=symmetric)
        fusion_key = FusedRMSQuantKey(fused_add=True, quant=quant_key)

        super().__init__(epsilon, fusion_key, match_aiter_quant)

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the fused-add rms_norm -> group quant rewrite on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            normed, residual_out = self.rmsnorm_matcher(input, weight, residual)
            quantized, scale = self.quant_matcher(normed)

            return quantized, residual_out, scale

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            # NOTE(review): group_size is fixed at 128 here regardless of the
            # group_shape given to __init__ -- confirm callers only use 128.
            fused = self.FUSED_OP(
                x=input,
                residual=residual,
                weight=weight,
                variance_epsilon=self.epsilon,
                group_size=128,
            )

            # Outputs are forwarded positionally and must line up with the
            # pattern's (result, residual_out, scale).
            # NOTE(review): the previous comment here read
            # "result, scale, residual" -- confirm FUSED_OP's output order.
            return fused[0], fused[1], fused[2]

        example_inputs = self.rmsnorm_matcher.inputs()
        pm.register_replacement(
            pattern, replacement, example_inputs, pm.fwd_only, pm_pass
        )
|
||||
|
||||
|
||||
class RocmAiterRMSNormQuantFusionPass(VllmPatternMatcherPass):
    """
    Fuses aiter rms_norm followed by a vllm/aiter quant custom op
    into a fused rms_norm_quant op.
    The fused_add_rms_norm variants are handled as well.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        """Create the pattern-matcher pass and register every pattern variant."""
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rocm_aiter_rms_norm_quant_fusion_pass"
        )

        # Make sure fused add patterns are before simple rms norm,
        # as the latter is a subset of the former in torch ops
        # NOTE(review): within each pair below the simple rms_norm pattern is
        # registered first, which does not match the sentence above -- confirm
        # which ordering is intended before changing anything.
        for eps in (1e-5, 1e-6):
            # aiter rms_norm + aiter dynamic group fp8 quant
            group_pattern = AiterRMSFp8GroupQuantPattern(
                eps, FP8_DTYPE, GroupShape(1, 128)
            )
            group_pattern.register(self.patterns)

            # aiter fused_add_rms_norm + aiter dynamic group fp8 quant
            group_add_pattern = AiterFusedAddRMSFp8GroupQuantPattern(
                eps, FP8_DTYPE, GroupShape(1, 128)
            )
            group_add_pattern.register(self.patterns)

            for use_aiter_quant in (True, False):
                # aiter rms_norm + (aiter / vllm built-in)
                # dynamic per-token fp8 quant
                dynamic_pattern = AiterRMSNormDynamicQuantPattern(
                    eps, FP8_DTYPE, match_aiter_quant=use_aiter_quant
                )
                dynamic_pattern.register(self.patterns)

                # aiter fused_add_rms_norm + (aiter / vllm built-in)
                # dynamic per-token fp8 quant
                dynamic_add_pattern = AiterFusedAddRMSNormDynamicQuantPattern(
                    eps, FP8_DTYPE, match_aiter_quant=use_aiter_quant
                )
                dynamic_add_pattern.register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Apply all registered patterns to ``graph`` and record the match count."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def uuid(self) -> str:
        """Hash this pass together with the pattern classes it registers."""
        pattern_classes = [
            AiterRMSNormDynamicQuantPattern,
            AiterFusedAddRMSNormDynamicQuantPattern,
            AiterRMSFp8GroupQuantPattern,
            AiterFusedAddRMSFp8GroupQuantPattern,
        ]
        return self.hash_source(self, *pattern_classes)
|
||||
|
||||
|
||||
class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
    """
    Fuses the aiter silu_and_mul custom op followed by a group fp8 quant
    custom op into a single aiter silu_and_mul_group_fp8_quant op.
    """

    FUSED_SILU_MUL_QUANT_OP = rocm_aiter_ops.get_act_mul_fused_fp8_group_quant_op()

    def __init__(self, quant_op: OpOverload) -> None:
        """Remember the quant op to match and create the silu_and_mul matcher."""
        self.quant_op = quant_op
        self.silu_and_mul_matcher = MatcherSiluAndMul()

    def get_inputs(self) -> list[torch.Tensor]:
        """Return the single example input used to trace the pattern."""
        matcher_inputs = self.silu_and_mul_matcher.inputs()
        return [matcher_inputs[0]]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the silu_and_mul -> group quant rewrite on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            activated = self.silu_and_mul_matcher(input)
            # Second positional argument of the quant op is fixed at 128.
            quant_out = self.quant_op(activated, 128)
            return quant_out[0], quant_out[1]

        def replacement(
            input: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            fused = self.FUSED_SILU_MUL_QUANT_OP(x=input, group_size=128)
            return fused[0], fused[1]

        example_inputs = self.get_inputs()
        pm.register_replacement(
            pattern, replacement, example_inputs, pm.fwd_only, pm_pass
        )
|
||||
|
||||
|
||||
class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
    """
    Fuses a pre-defined set of custom ops into fused ops, using the torch
    pattern matcher to locate the patterns and swap in the replacements.

    Because patterns can only be registered once, the pass is a singleton.
    This will be addressed in a future version of PyTorch:
    https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
    """

    AITER_GROUP_FP8_QUANT_OP = rocm_aiter_ops.get_group_quant_op()
    TRITON_GROUP_FP8_QUANT_OP = torch.ops.vllm.triton_per_token_group_quant_fp8.default

    # Every quant op for which a silu_and_mul + quant pattern is registered.
    QUANT_OPS = [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        """Create the pattern-matcher pass and register one pattern per quant op."""
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rocm_aiter_silu_mul_fp8_group_quant_fusion_pass"
        )

        for op in self.QUANT_OPS:
            silu_quant_pattern = AiterSiluMulFp8GroupQuantPattern(op)
            silu_quant_pattern.register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: torch.fx.Graph) -> None:
        """Apply all registered patterns to ``graph`` and record the match count."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def uuid(self) -> str:
        """Hash this pass together with the pattern classes it depends on."""
        return VllmInductorPass.hash_source(
            self,
            ActivationQuantPattern,
            AiterSiluMulFp8GroupQuantPattern,
        )
|
||||
|
||||
|
||||
class AddAiterRMSNormPadPattern:
    """
    This pattern replaces an aiter_rmsnorm_with_add & a pad op
    with a custom triton_add_rmsnorm_pad op from AITER.

    The matched subgraph also contains a rocm_unquantized_gemm that
    consumes the un-padded normalized output; the replacement feeds that
    gemm a slice of the padded output instead.
    """

    AITER_TRITON_ADD_RMSNORM_PAD_OP = rocm_aiter_ops.get_triton_add_rmsnorm_pad_op()

    def __init__(
        self,
        epsilon: float,
        hidden_size: int,
        x_pad_to_multiple: int,
    ):
        """
        Args:
            epsilon: RMSNorm epsilon, passed to the matcher and the fused op.
            hidden_size: width of the normalized output before padding.
            x_pad_to_multiple: the output is padded up to a multiple of this.
        """
        self.epsilon = epsilon
        self.hidden_size = hidden_size
        self.x_pad_to_multiple = x_pad_to_multiple
        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon, match_rocm_aiter=True)

    def get_inputs(self) -> list[torch.Tensor]:
        """
        Return example tensors for tracing: the matcher's own inputs plus
        placeholder router weight/bias tensors with the weight's dtype/device.
        """
        input, weight, residual = self.rmsnorm_matcher.inputs()
        router_weight = torch.empty([8, 16], dtype=weight.dtype, device=weight.device)
        router_bias = torch.empty([8], dtype=weight.dtype, device=weight.device)
        return [input, weight, residual, router_weight, router_bias]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the add-rmsnorm + pad rewrite on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            router_weight: torch.Tensor,
            router_bias: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            # Trailing zeros needed to reach the next multiple of
            # x_pad_to_multiple. Using (-n) % m instead of m - (n % m) gives
            # the same value for unaligned sizes and 0 (not a whole extra
            # multiple) when hidden_size is already aligned.
            pad_size = (-self.hidden_size) % self.x_pad_to_multiple
            result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual)
            router_logits = torch.ops.vllm.rocm_unquantized_gemm(
                result_rms, router_weight, router_bias
            )
            result = torch.nn.functional.pad(
                result_rms, (0, pad_size), mode="constant", value=0.0
            )
            return result, residual_out, router_logits

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            residual: torch.Tensor,
            router_weight: torch.Tensor,
            router_bias: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            at = self.AITER_TRITON_ADD_RMSNORM_PAD_OP(
                x=input,
                weight=weight,
                variance_epsilon=self.epsilon,
                residual=residual,
                x_pad_to_multiple=self.x_pad_to_multiple,
            )
            result_padded = at[0]
            # The router gemm must still see only the first hidden_size
            # columns, so slice the padding back off for it.
            router_logits = torch.ops.vllm.rocm_unquantized_gemm(
                result_padded[:, : self.hidden_size], router_weight, router_bias
            )
            residual_out = at[1]
            return result_padded, residual_out, router_logits

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
|
||||
|
||||
|
||||
class RocmAiterTritonAddRMSNormPadFusionPass(VllmPatternMatcherPass):
    """
    Replaces an AITER CK RMSNorm + residual add followed by a pad op
    with a triton_add_rmsnorm_pad op from AITER.
    """

    def __init__(self, config: VllmConfig):
        """Create the pattern-matcher pass and register all epsilon/padding variants."""
        super().__init__(config)
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rocm_aiter_triton_add_rmsnorm_pad_fusion_pass"
        )

        # gpt-oss has hidden size 2880
        # padded to a multiple of 128 on gfx942 and 256 on gfx950 respectively
        hidden_size = 2880
        for eps in (1e-5, 1e-6):
            for pad_multiple in (128, 256):
                pad_pattern = AddAiterRMSNormPadPattern(eps, hidden_size, pad_multiple)
                pad_pattern.register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: torch.fx.Graph) -> None:
        """Apply all registered patterns to ``graph`` and record the match count."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def uuid(self) -> str:
        """Hash this pass together with the pattern class it registers."""
        return VllmInductorPass.hash_source(self, AddAiterRMSNormPadPattern)
|
||||
230
vllm/compilation/passes/fusion/rope_kvcache_fusion.py
Normal file
230
vllm/compilation/passes/fusion/rope_kvcache_fusion.py
Normal file
@@ -0,0 +1,230 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
from torch import fx
|
||||
from torch._higher_order_ops import auto_functionalized
|
||||
from torch._inductor.fx_passes.post_grad import view_to_reshape
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.config.utils import Range
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention.attention import (
|
||||
Attention,
|
||||
get_attention_context,
|
||||
)
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .matcher_utils import (
|
||||
MatcherRotaryEmbedding,
|
||||
)
|
||||
from .rms_quant_fusion import (
|
||||
empty_bf16,
|
||||
empty_i64,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def fused_rope_and_unified_kv_cache_update_impl(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
    layer_name: str = "",
) -> torch.Tensor:
    """
    Look up the attention layer, KV cache and slot mapping for ``layer_name``
    via ``get_attention_context`` and, when a slot mapping is present, call
    the layer impl's `AttentionImpl.do_rope_and_kv_cache_update` method.

    Like `Attention.unified_kv_cache_update`, an empty dummy tensor is
    returned; it is passed to unified_attention to signal a side effect and
    a data dependency so that torch.compile preserves ordering.
    """
    _, attn_layer, kv_cache, layer_slot_mapping = get_attention_context(layer_name)

    # Without a slot mapping there is nothing to write; only the dummy
    # tensor is produced.
    if layer_slot_mapping is not None:
        rope_and_cache_update = attn_layer.impl.do_rope_and_kv_cache_update
        rope_and_cache_update(
            attn_layer,
            query,
            key,
            value,
            positions,
            cos_sin_cache,
            is_neox,
            kv_cache,
            layer_slot_mapping,
        )

    dummy = torch.empty(0, device=kv_cache.device, dtype=kv_cache.dtype)
    return dummy
|
||||
|
||||
|
||||
def fused_rope_and_unified_kv_cache_update_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
    layer_name: str = "",
) -> torch.Tensor:
    """
    Fake implementation of the fused rope + KV cache update op.

    Only ``query`` is consulted: the result is an empty 1-D tensor with
    ``query``'s device and dtype. No cache is touched.
    """
    dummy = torch.empty(0, dtype=query.dtype, device=query.device)
    return dummy
|
||||
|
||||
|
||||
# Register the fused rope + KV-cache update as a custom op. ``query`` and
# ``key`` are declared as mutated arguments; the fake implementation is
# supplied alongside the real one.
direct_register_custom_op(
    op_name="fused_rope_and_unified_kv_cache_update",
    op_func=fused_rope_and_unified_kv_cache_update_impl,
    mutates_args=["query", "key"],
    fake_impl=fused_rope_and_unified_kv_cache_update_fake,
)
|
||||
|
||||
|
||||
class RopeReshapeKVCachePattern:
    """
    Matches the following unfused inplace ops:
      q, k = rotary_embedding(positions, q, k, head_size, cos_sin_cache, is_neox)
      kv_cache_dummy = unified_kv_cache_update(k, v, layer_name)

    and rewrites them to the fused inplace op:
      kv_cache_dummy = fused_rope_and_unified_kv_cache_update(
          q, k, v, positions, cos_sin_cache, is_neox, layer_name
      )
    """

    FUSED_OP = torch.ops.vllm.fused_rope_and_unified_kv_cache_update.default

    def __init__(
        self,
        layer: Attention,
        is_neox: bool,
    ) -> None:
        """
        Capture the layer's name and head geometry, derive the flat q/k/v
        widths used to split the packed qkv tensor, and build the rotary
        embedding matcher for this layer.
        """
        self.layer_name = layer.layer_name
        self.is_neox = is_neox

        # Head geometry copied from the attention layer.
        self.num_heads = layer.num_heads
        self.num_kv_heads = layer.num_kv_heads
        self.head_size = layer.head_size
        self.head_size_v = layer.head_size_v

        # Flat widths of the q / k / v slices inside the packed qkv tensor.
        self.q_size = self.num_heads * self.head_size
        self.k_size = self.num_kv_heads * self.head_size
        self.v_size = self.num_kv_heads * self.head_size_v

        self.rope_matcher = MatcherRotaryEmbedding(
            is_neox=self.is_neox,
            head_size=self.head_size,
            num_heads=self.num_heads,
            num_kv_heads=self.num_kv_heads,
        )

    def get_inputs(self) -> list[torch.Tensor]:
        """Return sample (qkv, positions, cos_sin_cache) tensors for pattern tracing."""
        num_tokens = 5
        cache_rows = 4096
        packed_width = self.q_size + self.k_size + self.v_size

        qkv = empty_bf16(num_tokens, packed_width)
        positions = empty_i64(num_tokens)
        cos_sin_cache = empty_bf16(cache_rows, self.head_size)
        return [qkv, positions, cos_sin_cache]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the rope + KV-cache-update rewrite for this layer on ``pm_pass``."""
        split_sizes = [self.q_size, self.k_size, self.v_size]

        def pattern(
            qkv: torch.Tensor,
            positions: torch.Tensor,
            cos_sin_cache: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            q, k, v = qkv.split(split_sizes, dim=-1)
            q, k = self.rope_matcher(positions, q, k, cos_sin_cache)
            q = q.view(-1, self.num_heads, self.head_size)
            k = k.view(-1, self.num_kv_heads, self.head_size)
            v = v.view(-1, self.num_kv_heads, self.head_size_v)
            dummy = torch.ops.vllm.unified_kv_cache_update(k, v, self.layer_name)
            return dummy, q, k, v

        def replacement(
            qkv: torch.Tensor,
            positions: torch.Tensor,
            cos_sin_cache: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
            q, k, v = qkv.split(split_sizes, dim=-1)
            q = q.view(-1, self.num_heads, self.head_size)
            k = k.view(-1, self.num_kv_heads, self.head_size)
            v = v.view(-1, self.num_kv_heads, self.head_size_v)
            fused = auto_functionalized(
                self.FUSED_OP,
                query=q,
                key=k,
                value=v,
                positions=positions,
                cos_sin_cache=cos_sin_cache,
                is_neox=self.is_neox,
                layer_name=self.layer_name,
            )
            # The first three entries stand in for the pattern's
            # (dummy, q, k); v is passed through untouched.
            return fused[0], fused[1], fused[2], v

        # NOTE: use view_to_reshape to unify view/reshape to simplify
        # pattern and increase matching opportunities
        def fwd_and_view_to_reshape(*args, **kwargs) -> fx.GraphModule:
            traced = pm.fwd_only(*args, **kwargs)
            view_to_reshape(traced)
            return traced

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), fwd_and_view_to_reshape, pm_pass
        )
|
||||
|
||||
|
||||
class RopeKVCacheFusionPass(VllmPatternMatcherPass):
    """
    Fuses the rotary embedding and KV cache update operations into a
    single fused kernel if available.

    The pattern matcher is used with one pattern per layer, because strings
    cannot be wildcarded. This also lets us check support on attention
    layers upon registration instead of during pattern matching.

    The fusion eliminates the need for separate kernel launches and
    intermediate memory operations between the RoPE and cache update steps.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        """
        Read the token-count limit from the pass config and register both
        is_neox variants for every attention layer whose impl reports
        fused rope + KV cache support.
        """
        super().__init__(config)

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="rope_kv_cache_fusion_pass"
        )

        pass_config = config.compilation_config.pass_config
        self.max_token_num = pass_config.rope_kvcache_fusion_max_token_num

        attn_layers = get_layers_from_vllm_config(config, Attention)
        for layer in attn_layers.values():
            if not layer.impl.fused_rope_kvcache_supported():
                continue
            for neox_style in (True, False):
                layer_pattern = RopeReshapeKVCachePattern(
                    layer=layer,
                    is_neox=neox_style,
                )
                layer_pattern.register(self.patterns)

        self.dump_patterns(config, self.patterns)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Apply all registered patterns to ``graph`` and record the match count."""
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)

    def is_applicable_for_range(self, compile_range: Range) -> bool:
        """Return True when the compile range ends at or below ``max_token_num``."""
        # This pass works best for the small-batch decode setting.
        # For large-batch e.g. prefill, it is better to use two separate kernels
        # since they are compute bound and the fused kernels require further tuning.
        upper_bound = compile_range.end
        return upper_bound <= self.max_token_num

    def uuid(self) -> str:
        """Hash this pass together with the pattern class it registers."""
        return VllmInductorPass.hash_source(self, RopeReshapeKVCachePattern)
|
||||
452
vllm/compilation/passes/fusion/sequence_parallelism.py
Normal file
452
vllm/compilation/passes/fusion/sequence_parallelism.py
Normal file
@@ -0,0 +1,452 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import functools
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
import torch.fx as fx
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.utils import Range
|
||||
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..utility.noop_elimination import NoOpEliminationPass
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Minimum model hidden size, keyed by integer device capability, at which
# sequence parallelism is applied. Capabilities without an entry get no
# threshold at all.
SP_MIN_HIDDEN_SIZE: dict[int, int] = {
    # H100: only for models with hidden_size >= 8192
    90: 8192,
}

# Minimum size per GPU (in MiB), keyed by integer device capability.
# Total min size = min_per_gpu_size * tp_size, so the threshold scales
# with the tensor-parallel degree.
SP_MIN_PER_GPU_SIZE_MB: dict[int, float] = {
    # 8MB per GPU for H100
    90: 8,
}
|
||||
|
||||
|
||||
def get_sequence_parallelism_threshold(
|
||||
hidden_size: int,
|
||||
tp_size: int,
|
||||
element_size: int,
|
||||
) -> int | None:
|
||||
"""
|
||||
Calculate the minimum token threshold for applying sequence parallelism.
|
||||
|
||||
Returns None if sequence parallelism should not be applied based on model size.
|
||||
|
||||
Branching logic based on device capability:
|
||||
- Check if hidden_size >= SP_MIN_HIDDEN_SIZE[device_capability]
|
||||
- If not, returns None (SP disabled for small models on this device)
|
||||
- If yes, calculates threshold based on per-GPU size
|
||||
|
||||
Formula: min_token_num = (min_per_gpu_size_mb * tp_size * MiB) //
|
||||
(hidden_size * element_size)
|
||||
"""
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if not current_platform.is_cuda():
|
||||
return None
|
||||
|
||||
capability = current_platform.get_device_capability()
|
||||
if capability is None:
|
||||
return None
|
||||
device_capability = capability.to_int()
|
||||
|
||||
# Check if device has configured thresholds
|
||||
min_hidden_size = SP_MIN_HIDDEN_SIZE.get(device_capability)
|
||||
min_per_gpu_size_mb = SP_MIN_PER_GPU_SIZE_MB.get(device_capability)
|
||||
|
||||
if min_hidden_size is None or min_per_gpu_size_mb is None:
|
||||
return None
|
||||
|
||||
# Only apply sequence parallelism for models meeting the size threshold
|
||||
if hidden_size < min_hidden_size:
|
||||
return None
|
||||
|
||||
MiB = 1024 * 1024
|
||||
min_size = min_per_gpu_size_mb * MiB * tp_size
|
||||
return int(min_size // (hidden_size * element_size))
|
||||
|
||||
|
||||
def get_first_out_wrapper(
    fn: Callable[..., Sequence[torch.Tensor]],
) -> Callable[..., torch.Tensor]:
    """
    Wrap ``fn`` so that only element 0 of its returned sequence is returned.

    The wrapper forwards positional arguments only and keeps ``fn``'s
    metadata via ``functools.wraps``.
    """

    @functools.wraps(fn)
    def wrapper(*args: Any) -> torch.Tensor:
        outputs = fn(*args)
        return outputs[0]

    return wrapper
|
||||
|
||||
|
||||
class _SequenceParallelPatternHelper:
    """
    Shared state and collective-op shortcuts for the sequence parallelism
    patterns: stores epsilon/dtype/device and the tensor-parallel group
    and world size looked up at construction time.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
    ) -> None:
        self.epsilon = epsilon
        self.dtype = dtype
        self.device = device
        self.tp_group = get_tp_group()
        self.tp_size = get_tensor_model_parallel_world_size()

    def _all_reduce(self, x: torch.Tensor) -> torch.Tensor:
        """All-reduce ``x`` across the tensor-parallel group."""
        return tensor_model_parallel_all_reduce(x)

    def _reduce_scatter(self, x: torch.Tensor) -> torch.Tensor:
        """Reduce-scatter ``x`` along dim 0 over the tensor-parallel group."""
        group_name = self.tp_group.unique_name
        return torch.ops.vllm.reduce_scatter.default(
            x, dim=0, world_size=self.tp_size, group_name=group_name
        )

    def _all_gather(self, x: torch.Tensor) -> torch.Tensor:
        """All-gather ``x`` along dim 0 over the tensor-parallel group."""
        group_name = self.tp_group.unique_name
        return torch.ops.vllm.all_gather.default(
            x, dim=0, world_size=self.tp_size, group_name=group_name
        )
|
||||
|
||||
|
||||
class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
    """
    Rewrites all_reduce -> rms_norm (no residual) into
    reduce_scatter -> rms_norm -> all_gather.
    """

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
        super().__init__(epsilon, dtype, device)
        self.rmsnorm_matcher = MatcherRMSNorm(epsilon)

    def get_inputs(self) -> list[torch.Tensor]:
        """Return small example (input, weight) tensors for pattern tracing."""
        input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
        arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype)
        return [input, arg3_1]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the rewrite on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            arg3_1: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduced = self._all_reduce(input)
            normed = self.rmsnorm_matcher(reduced, arg3_1)

            return normed, reduced

        def replacement(
            input: torch.Tensor,
            arg3_1: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            scattered = self._reduce_scatter(input)

            normed = self.rmsnorm_matcher(scattered, arg3_1)
            gathered = self._all_gather(normed)
            # The second output is the scattered tensor, standing in for the
            # pattern's all-reduced tensor.
            return gathered, scattered

        example_inputs = self.get_inputs()
        pm.register_replacement(
            pattern, replacement, example_inputs, pm.fwd_only, pm_pass
        )
|
||||
|
||||
|
||||
class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
    """
    Rewrites all_reduce -> fused-add rms_norm into
    reduce_scatter -> fused-add rms_norm -> all_gather.

    Two variants are registered: one returning both the normalized output
    and the residual, and one returning only the normalized output.
    """

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
        super().__init__(epsilon, dtype, device)
        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)

    def get_inputs(self) -> list[torch.Tensor]:
        """Return example (residual, mm_1, rms_norm_weights) tensors for tracing."""
        shape = [4, 4]
        mm_1 = torch.empty(shape, device=self.device, dtype=self.dtype)

        residual = torch.empty(shape, device=self.device, dtype=self.dtype)
        rms_norm_weights = torch.empty(shape, device=self.device, dtype=self.dtype)

        return [residual, mm_1, rms_norm_weights]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register both output-arity variants of the rewrite on ``pm_pass``."""

        def pattern(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduced = self._all_reduce(mm_1)
            norm_out = self.rmsnorm_matcher(reduced, rms_norm_weights, residual)
            return norm_out[0], norm_out[1]

        def replacement(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # pattern matcher replaces from top-to-bottom,
            # so residual is still the full size here.
            # add a temporary slice which will become a noop
            # once the seqpar pattern with the previous rmsnorm is replaced
            scattered = self._reduce_scatter(mm_1)
            residual = residual[0 : scattered.size(0), ...]
            norm_out = self.rmsnorm_matcher(scattered, rms_norm_weights, residual)
            gathered = self._all_gather(norm_out[0])
            # shape of residual changes but that's fine,
            # next node is already slicing it, now becomes a noop
            return gathered, norm_out[1]

        # Variant 1: both outputs (normalized tensor and residual) are used.
        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )
        # Variant 2: only the first output is used.
        first_only_pattern = get_first_out_wrapper(pattern)
        first_only_replacement = get_first_out_wrapper(replacement)
        pm.register_replacement(
            first_only_pattern,
            first_only_replacement,
            self.get_inputs(),
            pm.fwd_only,
            pm_pass,
        )
|
||||
|
||||
|
||||
# FP8 dtype reported by the current platform, resolved once at import time.
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
|
||||
class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
    """
    Rewrites all_reduce -> rms_norm -> static FP8 quant into
    reduce_scatter -> rms_norm -> static FP8 quant -> all_gather.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
    ) -> None:
        super().__init__(epsilon, dtype, device)
        self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
        self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)

    def get_inputs(self) -> list[torch.Tensor]:
        """Return example (input, weight, scale) tensors for pattern tracing."""
        input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype)
        weight = torch.empty([4], device=self.device, dtype=self.dtype)
        scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
        return [input, weight, scale]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the rewrite on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduced = self._all_reduce(input)
            normed = self.rmsnorm_matcher(reduced, weight)
            quantized, _ = self.quant_matcher(normed, scale)
            return quantized, reduced

        def replacement(
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            scattered = self._reduce_scatter(input)
            normed = self.rmsnorm_matcher(scattered, weight)
            quantized, _ = self.quant_matcher(normed, scale)
            gathered = self._all_gather(quantized)

            # The second output is the scattered tensor, standing in for the
            # pattern's all-reduced tensor.
            return gathered, scattered

        example_inputs = self.get_inputs()
        pm.register_replacement(
            pattern, replacement, example_inputs, pm.fwd_only, pm_pass
        )
|
||||
|
||||
|
||||
class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
    """
    Rewrites all_reduce -> fused-add rms_norm -> static FP8 quant into
    reduce_scatter -> fused-add rms_norm -> static FP8 quant -> all_gather.

    Two variants are registered: one returning both the quantized output
    and the residual, and one returning only the quantized output.
    """

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
        super().__init__(epsilon, dtype, device)
        self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
        self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)

    def get_inputs(self) -> list[torch.Tensor]:
        """Return example (residual, mm_1, rms_norm_weights, scale) tensors for tracing."""
        shape = [4, 4]
        mm_1 = torch.empty(shape, device=self.device, dtype=self.dtype)
        residual = torch.empty(shape, device=self.device, dtype=self.dtype)
        rms_norm_weights = torch.empty(shape, device=self.device, dtype=self.dtype)
        scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)

        return [residual, mm_1, rms_norm_weights, scale]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register both output-arity variants of the rewrite on ``pm_pass``."""

        def pattern(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduced = self._all_reduce(mm_1)
            normed, residual_out = self.rmsnorm_matcher(
                reduced, rms_norm_weights, residual
            )
            quantized, _ = self.quant_matcher(normed, scale)
            return quantized, residual_out

        def replacement(
            residual: torch.Tensor,
            mm_1: torch.Tensor,
            rms_norm_weights: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # pattern matcher replaces from top-to-bottom,
            # so residual is still the full size here.
            # add a temporary slice which will become a noop
            # once the seqpar pattern with the previous rmsnorm is replaced
            scattered = self._reduce_scatter(mm_1)
            residual = residual[0 : scattered.size(0), ...]
            normed, residual_out = self.rmsnorm_matcher(
                scattered, rms_norm_weights, residual
            )
            quantized, _ = self.quant_matcher(normed, scale)
            gathered = self._all_gather(quantized)
            # shape of residual changes but that's fine,
            # next node is already slicing it, now becomes a noop
            return gathered, residual_out

        # Variant 1: both outputs (quantized tensor and residual) are used.
        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )

        # Variant 2: only the first output is used.
        first_only_pattern = get_first_out_wrapper(pattern)
        first_only_replacement = get_first_out_wrapper(replacement)
        pm.register_replacement(
            first_only_pattern,
            first_only_replacement,
            self.get_inputs(),
            pm.fwd_only,
            pm_pass,
        )
|
||||
|
||||
|
||||
class SequenceParallelismPass(VllmPatternMatcherPass):
    """
    This pass enables sequence parallelism for models.
    It identifies patterns where an AllReduce operation is followed by
    an RMSNorm (or RMSNorm and then Quantization) operation.
    These patterns are replaced with a ReduceScatter operation, followed by
    a local RMSNorm/Quantization, and then an AllGather operation.

    The general transformation is:
    Input -> AllReduce -> RMSNorm -> Output
    becomes
    Input -> ReduceScatter -> RMSNorm -> AllGather -> Output

    While this pass itself does not directly yield performance improvements,
    it lays the groundwork for subsequent fusion passes, such as
    GEMM + ReduceScatter and AllGather + GEMM fusions. These fusions can
    significantly reduce communication overhead and improve overall model
    performance.


    This pass splits up the residual tensor across TP ranks and hence divides its size.
    Because the pattern matcher starts at the end of the graph, the replacement
    contains a slice that temporarily conforms the input residual to the correct size.
    After all patterns have been matched, we use a NoOpEliminationPass to clean up
    what have now become no-op slices.

    Note that an older version of the pass did not need this as it operated only on
    custom rms_norm and fused_rms_norm_add custom ops which did not complain about
    mismatched shapes during replacement. So this approach has the same assumption that
    correctness is only maintained if all rms_norm operations are split across ranks.

    Correctness-wise, this approach is strictly better than before - before,
    the graph was incorrect semantically and shape-wise during the pass.
    With this approach there's only semantic incorrectness during the pass.
    Both approaches restore a correct graph once all patterns are matched.
    """

    @enable_fake_mode
    def __init__(self, config: VllmConfig) -> None:
        """
        Read the minimum-token threshold from the config, create the nested
        no-op cleanup pass and register all sequence-parallel patterns.

        The constructor is wrapped in enable_fake_mode, so it runs inside a
        FakeTensorMode context.

        :param config: The vLLM config this pass is built for.
        """
        super().__init__(config)

        # Get min_token_num threshold
        # Read min_token_num from config (calculated during config init)
        self.min_token_num = None
        if config.model_config is not None:
            pass_config = config.compilation_config.pass_config
            self.min_token_num = pass_config.sp_min_token_num

            if self.min_token_num is not None:
                # Take the min to avoid exceeding max_num_batched_tokens
                max_batched = config.scheduler_config.max_num_batched_tokens
                if max_batched is not None:
                    self.min_token_num = min(self.min_token_num, max_batched)
                logger.debug_once(
                    f"Sequence parallelism min token threshold: {self.min_token_num}",
                    scope="global",
                )

        # Used to clean up redundant views created temporarily
        # to circumvent residual shape change issues
        self.noop_cleanup = NoOpEliminationPass(config)
        # Prefix the nested pass name with this pass's name.
        self.noop_cleanup.pass_name = f"{self.pass_name}.{self.noop_cleanup.pass_name}"

        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="sequence_parallelism_pass"
        )

        # Each pattern is registered once per epsilon value listed here.
        for epsilon in [1e-5, 1e-6]:
            # RMSNorm + Static FP8 quantization patterns
            FirstAllReduceRMSNormStaticFP8Pattern(
                epsilon, self.model_dtype, self.device
            ).register(self.patterns)
            MiddleAllReduceRMSNormStaticFP8Pattern(
                epsilon, self.model_dtype, self.device
            ).register(self.patterns)

            # Normal RMSNorm patterns
            FirstAllReduceRMSNormPattern(
                epsilon, self.model_dtype, self.device
            ).register(self.patterns)

            MiddleAllReduceRMSNormPattern(
                epsilon, self.model_dtype, self.device
            ).register(self.patterns)

        self.dump_patterns(config, self.patterns)

    def is_applicable_for_range(self, compile_range: Range) -> bool:
        """
        Determines if sequence parallelism should be applied for the given
        compile range.

        SP is only beneficial for larger batch sizes where the communication
        overhead is amortized. For small batches, the overhead of splitting
        and gathering tensors across TP ranks outweighs the benefits.

        Returns False (SP disabled) when:
        - Using piecewise compilation with non-concrete or TP-indivisible sizes
        - min_token_num is None (SP disabled for this device/config)
        - The compile range starts below the minimum token threshold

        :param compile_range: The compile range the pass would run for.
        :return: True if the pass should run for this range.
        """
        # For piecewise compilation (not using inductor graph partition),
        # we need concrete sizes that are divisible by TP for correct splitting
        if (
            not self.compilation_config.use_inductor_graph_partition
            and self.compilation_config.splitting_ops
        ):
            tp_size = get_tensor_model_parallel_world_size()
            if not compile_range.is_single_size() or compile_range.end % tp_size != 0:
                return False

        # min_token_num is None when SP is disabled for this device/config
        # (e.g., non-CUDA platform, unsupported GPU, or small hidden_size)
        if self.min_token_num is None:
            return False

        # Only apply SP when batch size meets the minimum threshold
        return compile_range.start >= self.min_token_num

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """
        Apply all registered patterns to ``graph``, then run the nested
        no-op elimination pass on the same graph.

        The number of replacements is stored in ``self.matched_count``.
        """
        self.matched_count = self.patterns.apply(graph)
        logger.debug("Replaced %s patterns", self.matched_count)
        # Clean up reshape nodes
        self.noop_cleanup(graph)
|
||||
77
vllm/compilation/passes/fx_utils.py
Normal file
77
vllm/compilation/passes/fx_utils.py
Normal file
@@ -0,0 +1,77 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import operator
|
||||
from collections.abc import Iterable, Iterator
|
||||
|
||||
from torch import fx
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._ops import OpOverload, OpOverloadPacket
|
||||
from torch.fx.node import Target
|
||||
|
||||
|
||||
def is_func(node: fx.Node, target: Target) -> bool:
    """
    Check whether ``node`` is a ``call_function`` node calling ``target``.

    :param node: The FX node to inspect.
    :param target: The callable/op the node is expected to call.
    :return: True only if the node's op is "call_function" and its target
        compares equal to ``target``.
    """
    if node.op != "call_function":
        return False
    return bool(node.target == target)
|
||||
|
||||
|
||||
def is_auto_func(node: fx.Node, op: OpOverload) -> bool:
    """
    Check whether ``node`` is an ``auto_functionalized`` call wrapping ``op``.

    :param node: The FX node to inspect.
    :param op: The op expected as the first positional argument of the
        auto_functionalized call.
    :return: False if the node is not an auto_functionalized call, otherwise
        the result of comparing its first argument with ``op``.
    """
    if not is_func(node, auto_functionalized):
        return False
    return node.args[0] == op
|
||||
|
||||
|
||||
def find_auto_fn_maybe(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node | None:
    """
    Return the first auto_functionalized node wrapping ``op``, if any.

    :param nodes: The nodes to scan, in iteration order.
    :param op: The op that must be the first argument of the
        auto_functionalized call.
    :return: The first matching node, or None when no node matches.
    """
    matches = (
        candidate
        for candidate in nodes
        if is_func(candidate, auto_functionalized) and candidate.args[0] == op  # noqa
    )
    # The generator is lazy, so scanning stops at the first match.
    return next(matches, None)
|
||||
|
||||
|
||||
def find_auto_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node:
    """
    Return the first auto_functionalized node wrapping ``op``.

    Unlike find_auto_fn_maybe, a match is required.

    :param nodes: The nodes to scan, in iteration order.
    :param op: The op wrapped by the auto_functionalized call.
    :return: The first matching node.
    :raises AssertionError: If no node matches.
    """
    found = find_auto_fn_maybe(nodes, op)
    assert found is not None, f"Could not find {op} in nodes {nodes}"
    return found
|
||||
|
||||
|
||||
def find_getitem_maybe(node: fx.Node, idx: int) -> fx.Node | None:
    """
    Return the getitem user of ``node`` that extracts element ``idx``, if any.

    :param node: The node whose users are scanned.
    :param idx: The index the getitem call must use as its second argument.
    :return: The first matching ``operator.getitem`` user, or None.
    """
    getitems = (
        consumer
        for consumer in node.users
        if is_func(consumer, operator.getitem) and consumer.args[1] == idx
    )
    return next(getitems, None)
|
||||
|
||||
|
||||
def find_getitem(node: fx.Node, idx: int) -> fx.Node:
    """
    Return the getitem user of ``node`` that extracts element ``idx``.

    Unlike find_getitem_maybe, a match is required.

    :param node: The node whose users are scanned.
    :param idx: The index extracted by the getitem call.
    :return: The matching getitem node.
    :raises AssertionError: If ``node`` has no such getitem user.
    """
    getitem_node = find_getitem_maybe(node, idx)
    assert getitem_node is not None, f"Could not find getitem {idx} in node {node}"
    return getitem_node
|
||||
|
||||
|
||||
# An auto-functionalization-aware utility for finding nodes with a specific op
|
||||
# Also handles op overload packets and finds all overloads
|
||||
def find_op_nodes(
|
||||
op: OpOverload | OpOverloadPacket, graph: fx.Graph
|
||||
) -> Iterator[fx.Node]:
|
||||
if isinstance(op, OpOverloadPacket):
|
||||
for overload in op.overloads():
|
||||
overload_op = getattr(op, overload)
|
||||
yield from find_op_nodes(overload_op, graph)
|
||||
return
|
||||
|
||||
assert isinstance(op, OpOverload)
|
||||
|
||||
yield from graph.find_nodes(op="call_function", target=op)
|
||||
|
||||
for n in graph.find_nodes(op="call_function", target=auto_functionalized):
|
||||
if n.args[0] == op:
|
||||
yield n
|
||||
|
||||
|
||||
def get_only_user(node: fx.Node) -> fx.Node:
    """
    Assert that ``node`` has exactly one user and return it.

    Even if a node has only 1 user, it might share storage with another node,
    which might need to be taken into account.

    :param node: The node whose single user is requested.
    :return: The only user of ``node``.
    :raises AssertionError: If ``node`` has zero or more than one user; the
        message lists the users found, matching the diagnostics of
        find_auto_fn and find_getitem.
    """
    users = list(node.users)
    assert len(users) == 1, (
        f"Expected exactly one user of {node}, found {len(users)}: {users}"
    )
    return users[0]
|
||||
134
vllm/compilation/passes/inductor_pass.py
Normal file
134
vllm/compilation/passes/inductor_pass.py
Normal file
@@ -0,0 +1,134 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import hashlib
|
||||
import inspect
|
||||
import json
|
||||
import types
|
||||
from collections.abc import Callable, Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar
|
||||
|
||||
import torch
|
||||
from torch import fx
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config.utils import Range
|
||||
|
||||
from torch._inductor.custom_graph_pass import CustomGraphPass
|
||||
|
||||
_pass_context = None
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class PassContext:
    """Holds the compile range that is current while passes are running."""

    def __init__(self, compile_range: Range):
        # Stored unchanged and exposed as a plain attribute.
        self.compile_range: Range = compile_range
|
||||
|
||||
|
||||
def get_pass_context() -> PassContext:
    """
    Get the current pass context.

    :return: The PassContext stored in the module-level ``_pass_context``.
    :raises AssertionError: If no pass context is currently set.
    """
    assert _pass_context is not None, (
        "No pass context is set; get_pass_context() must be called inside "
        "a `with pass_context(...)` block"
    )
    return _pass_context
|
||||
|
||||
|
||||
@contextmanager
def pass_context(compile_range: Range) -> Generator[None, None, None]:
    """
    Context manager that installs a PassContext for ``compile_range`` as the
    current pass context and restores the previous one on exit.

    The previous context is restored even if the body raises, so nested uses
    unwind correctly.

    :param compile_range: The compile range to expose via get_pass_context().
    """
    global _pass_context
    saved, _pass_context = _pass_context, PassContext(compile_range)
    try:
        yield
    finally:
        _pass_context = saved
|
||||
|
||||
|
||||
class InductorPass(CustomGraphPass): # type: ignore[misc]
|
||||
"""
|
||||
A custom graph pass that uses a hash of its source as the UUID.
|
||||
This is defined as a convenience and should work in most cases.
|
||||
"""
|
||||
|
||||
def uuid(self) -> str:
|
||||
"""
|
||||
Provide a unique identifier for the pass, used in Inductor code cache.
|
||||
This should depend on the pass implementation, so that changes to the
|
||||
pass result in recompilation.
|
||||
By default, the object source is hashed.
|
||||
"""
|
||||
return InductorPass.hash_source(self)
|
||||
|
||||
@staticmethod
|
||||
def hash_source(*srcs: str | Any) -> str:
|
||||
"""
|
||||
Utility method to hash the sources of functions or objects.
|
||||
:param srcs: strings or objects to add to the hash.
|
||||
Objects and functions have their source inspected.
|
||||
:return:
|
||||
"""
|
||||
hasher = hashlib.sha256()
|
||||
for src in srcs:
|
||||
if isinstance(src, str):
|
||||
src_str = src
|
||||
elif isinstance(src, (types.FunctionType, type)):
|
||||
src_str = inspect.getsource(src)
|
||||
else:
|
||||
# object instance
|
||||
src_str = inspect.getsource(src.__class__)
|
||||
hasher.update(src_str.encode("utf-8"))
|
||||
return hasher.hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def hash_dict(dict_: dict[Any, Any]) -> str:
|
||||
"""
|
||||
Utility method to hash a dictionary, can alternatively be used for uuid.
|
||||
:return: A sha256 hash of the json rep of the dictionary.
|
||||
"""
|
||||
encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
|
||||
return hashlib.sha256(encoded).hexdigest()
|
||||
|
||||
def is_applicable_for_range(self, compile_range: Range) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class CallableInductorPass(InductorPass):
    """
    Wraps a plain callable as an InductorPass and supplies the UUID for it
    automatically.
    """

    def __init__(
        self, callable: Callable[[fx.Graph], None], uuid: Any | None = None
    ) -> None:
        """
        :param callable: The function invoked on the graph when the pass runs.
        :param uuid: Optional explicit identifier; when None, the hash of the
            callable's source is used.
        """
        self.callable = callable
        if uuid is None:
            # No explicit id supplied: derive one from the callable's source.
            uuid = self.hash_source(callable)
        self._uuid = uuid

    def __call__(self, graph: torch.fx.Graph) -> None:
        """Run the wrapped callable on ``graph``."""
        self.callable(graph)

    def uuid(self) -> Any:
        """Return the identifier chosen at construction time."""
        return self._uuid
|
||||
|
||||
|
||||
def enable_fake_mode(fn: Callable[P, R]) -> Callable[P, R]:
|
||||
"""
|
||||
Applies a FakeTensorMode context. This is useful when you don't want to
|
||||
create or run things with real tensors.
|
||||
"""
|
||||
|
||||
@functools.wraps(fn)
|
||||
def fn_new(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
with torch._guards.tracing(None), unset_fake_temporarily(), FakeTensorMode():
|
||||
result = fn(*args, **kwargs)
|
||||
|
||||
return result
|
||||
|
||||
return fn_new
|
||||
178
vllm/compilation/passes/pass_manager.py
Normal file
178
vllm/compilation/passes/pass_manager.py
Normal file
@@ -0,0 +1,178 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
from collections.abc import Callable
|
||||
from typing import Any, ParamSpec, TypeVar
|
||||
|
||||
from torch import fx as fx
|
||||
|
||||
from vllm import envs
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.system_utils import set_env_var
|
||||
|
||||
from .vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
if rocm_aiter_ops.is_enabled():
|
||||
from .fusion.rocm_aiter_fusion import (
|
||||
RocmAiterRMSNormQuantFusionPass,
|
||||
RocmAiterSiluMulFp8GroupQuantFusionPass,
|
||||
RocmAiterTritonAddRMSNormPadFusionPass,
|
||||
)
|
||||
|
||||
if current_platform.is_cuda_alike():
|
||||
from .fusion.act_quant_fusion import ActivationQuantFusionPass
|
||||
from .fusion.attn_quant_fusion import AttnFusionPass
|
||||
from .fusion.qk_norm_rope_fusion import QKNormRoPEFusionPass
|
||||
from .fusion.rms_quant_fusion import RMSNormQuantFusionPass
|
||||
from .fusion.rope_kvcache_fusion import RopeKVCacheFusionPass
|
||||
from .fusion.sequence_parallelism import SequenceParallelismPass
|
||||
from .utility.scatter_split_replace import ScatterSplitReplacementPass
|
||||
from .utility.split_coalescing import SplitCoalescingPass
|
||||
|
||||
if current_platform.is_cuda():
|
||||
from .fusion.allreduce_rms_fusion import AllReduceFusionPass
|
||||
from .fusion.collective_fusion import AsyncTPPass
|
||||
|
||||
from .inductor_pass import (
|
||||
CustomGraphPass,
|
||||
InductorPass,
|
||||
get_pass_context,
|
||||
)
|
||||
from .utility.fix_functionalization import FixFunctionalizationPass
|
||||
from .utility.noop_elimination import NoOpEliminationPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def with_pattern_match_debug(fn: Callable[P, R]) -> Callable[P, R]:
    """
    Decorator that enables Inductor pattern-match debugging only while the
    wrapped function runs.
    Scoping it to the call avoids logging builtin Inductor pattern matching.

    When envs.VLLM_PATTERN_MATCH_DEBUG is None the function is called
    unchanged; otherwise its value is exported as
    TORCHINDUCTOR_PATTERN_MATCH_DEBUG for the duration of the call.
    """

    @functools.wraps(fn)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        debug_val = envs.VLLM_PATTERN_MATCH_DEBUG
        if debug_val is None:
            return fn(*args, **kwargs)

        # optionally check rank here
        with set_env_var("TORCHINDUCTOR_PATTERN_MATCH_DEBUG", debug_val):
            return fn(*args, **kwargs)

    return wrapper
|
||||
|
||||
|
||||
class PostGradPassManager(CustomGraphPass):  # type: ignore[misc]
    """
    The pass manager for post-grad passes.
    It handles configuration, adding custom passes, and running passes.
    It supports uuid for the Inductor code cache. That includes torch<2.6
    support using pickling (in .inductor_pass.CustomGraphPass).

    The order of the post-grad post-passes is:
    1. passes (constructor parameter)
    2. default passes (NoopEliminationPass, FusionPass)
    3. config["post_grad_custom_post_pass"] (if it exists)
    4. fix_functionalization
    This way, all passes operate on a functionalized graph.

    NOTE(review): the list above does not match the code in this class one
    to one: the constructor takes no parameters, and __call__ runs
    self.passes in list order (as built by configure() and add()), then
    post_cleanup, then fix_functionalization. Confirm and update the list.
    """

    def __init__(self) -> None:
        """Create a manager with an empty pass list; see configure() and add()."""
        self.passes: list[InductorPass] = []

    @with_pattern_match_debug
    def __call__(self, graph: fx.Graph) -> None:
        """
        Run the passes applicable to the current compile range on ``graph``,
        then post_cleanup, then fix_functionalization.

        VllmInductorPass.dump_prefix is reset to 0 at the start, incremented
        after each pass that runs, and set to None at the end.
        """
        VllmInductorPass.dump_prefix = 0  # reset dump index

        # The compile range comes from the active pass context.
        compile_range = get_pass_context().compile_range
        for pass_ in self.passes:
            if pass_.is_applicable_for_range(compile_range):
                pass_(graph)
                VllmInductorPass.dump_prefix += 1
            else:
                logger.debug("Skipping %s with compile range %s", pass_, compile_range)

        # post-cleanup goes before fix_functionalization
        # because it requires a functional graph
        self.post_cleanup(graph)
        VllmInductorPass.dump_prefix += 1

        # always run fix_functionalization last
        self.fix_functionalization(graph)
        VllmInductorPass.dump_prefix = None  # Cleanup index

    def configure(self, config: VllmConfig) -> None:
        """
        Build the pass list from ``config.compilation_config.pass_config``.

        Passes are appended to self.passes in the order of the checks below,
        and the post_cleanup and fix_functionalization passes are created.

        :param config: The vLLM config the passes are constructed with.
        """
        self.pass_config = config.compilation_config.pass_config

        # Set the current vllm config to allow tracing CustomOp instances
        with set_current_vllm_config(config, check_compile=False):
            if self.pass_config.eliminate_noops:
                self.passes += [NoOpEliminationPass(config)]

            if self.pass_config.enable_sp:
                self.passes += [SequenceParallelismPass(config)]
                if self.pass_config.fuse_gemm_comms:
                    self.passes += [AsyncTPPass(config)]

            if self.pass_config.fuse_allreduce_rms:
                self.passes += [AllReduceFusionPass(config)]

            if self.pass_config.fuse_norm_quant:
                self.passes += [RMSNormQuantFusionPass(config)]
                if rocm_aiter_ops.is_enabled():
                    self.passes += [
                        RocmAiterRMSNormQuantFusionPass(config),
                    ]
            if self.pass_config.fuse_act_quant:
                self.passes += [ActivationQuantFusionPass(config)]
                if rocm_aiter_ops.is_enabled():
                    self.passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)]

            if self.pass_config.fuse_act_padding and rocm_aiter_ops.is_enabled():
                self.passes += [RocmAiterTritonAddRMSNormPadFusionPass(config)]

            if self.pass_config.fuse_rope_kvcache:
                self.passes += [SplitCoalescingPass(config)]
                self.passes += [ScatterSplitReplacementPass(config)]
                self.passes += [RopeKVCacheFusionPass(config)]

            if self.pass_config.fuse_attn_quant:
                self.passes += [AttnFusionPass(config)]

            if self.pass_config.enable_qk_norm_rope_fusion:
                self.passes += [SplitCoalescingPass(config)]
                self.passes += [QKNormRoPEFusionPass(config)]

            # needs a functional graph
            self.post_cleanup = PostCleanupPass(config)
            self.fix_functionalization = FixFunctionalizationPass(config)

    def add(self, pass_: InductorPass) -> None:
        """
        Append a pass to the end of the pass list.

        :param pass_: The pass to add; must be an InductorPass instance.
        """
        assert isinstance(pass_, InductorPass)
        self.passes.append(pass_)

    def uuid(self) -> str:
        """
        The PostGradPassManager is set as a custom pass in the Inductor and
        affects compilation caching. Its uuid depends on the UUIDs of all
        dependent passes and the pass config. See InductorPass for more info.

        :return: The hash of a dict holding the pass config hash, the uuids
            of self.passes followed by fix_functionalization, and the string
            form of the current compile range.
        """
        passes = []

        state: dict[str, Any] = {"pass_config": self.pass_config.compute_hash()}
        for pass_ in self.passes:
            passes.append(pass_.uuid())
        passes.append(self.fix_functionalization.uuid())

        # Include the compile range in the uuid to ensure that inductor
        # recompiles the graph for the new dynamic compile range.
        state["compile_range"] = str(get_pass_context().compile_range)
        state["passes"] = passes
        return InductorPass.hash_dict(state)
|
||||
0
vllm/compilation/passes/utility/__init__.py
Normal file
0
vllm/compilation/passes/utility/__init__.py
Normal file
301
vllm/compilation/passes/utility/fix_functionalization.py
Normal file
301
vllm/compilation/passes/utility/fix_functionalization.py
Normal file
@@ -0,0 +1,301 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import operator
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..fx_utils import is_func
|
||||
from ..vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FixFunctionalizationPass(VllmInductorPass):
    """
    This pass defunctionalizes certain nodes to avoid redundant tensor copies.
    After this pass, DCE (dead-code elimination) should never be run,
    as de-functionalized nodes may appear as dead code.

    To add new nodes to defunctionalize, add to the if-elif chain in __call__.
    """

    @VllmInductorPass.time_and_log
    def __call__(self, graph: torch.fx.Graph) -> None:
        """
        Replace each recognised auto_functionalized node in ``graph`` with a
        direct call to the op it wraps.

        Nodes that become obsolete are staged with _remove during the walk
        and erased together after it. Returns early on the XPU platform.
        """
        # XPU does not support auto-functionalization yet.
        # Will enable this when switch to vllm-xpu-kernels.
        if current_platform.is_xpu():
            logger.debug(
                "XPU platform does not support fix functionalizationpass currently."
            )
            return

        self.nodes_to_remove: list[torch.fx.Node] = []
        count = 0
        for node in graph.nodes:
            if not is_func(node, auto_functionalized):
                continue  # Avoid deep if-elif nesting

            kwargs = node.kwargs
            # The wrapped op is the first positional arg of auto_functionalized.
            at_target = node.args[0]

            if at_target == torch.ops._C.rotary_embedding.default:
                query = kwargs["query"]
                key = kwargs["key"]
                getitem_nodes = self.getitem_users(node)

                if (
                    is_func(query, operator.getitem)
                    and is_func(key, operator.getitem)
                    and query.args[0] == key.args[0]
                    and is_func(query.args[0], torch.ops.aten.split_with_sizes.default)
                    and all(
                        is_func(user, torch.ops.aten.slice_scatter.default)
                        for getitem_node in getitem_nodes.values()
                        for user in getitem_node.users
                    )
                ):
                    # Pattern where query and key are slices of an mm_node.
                    # While functionalized, results at [1] and [2] are scattered
                    # back into mm_node. So after de-functionalization, we can
                    # just use mm_node directly.

                    # The input of the split_with_sizes that produced query/key.
                    mm_node = query.args[0].args[0]
                    for user in getitem_nodes.values():
                        for user_of_getitem in user.users:
                            if is_func(
                                user_of_getitem, torch.ops.aten.slice_scatter.default
                            ):
                                user_of_getitem.replace_all_uses_with(mm_node)
                                self._remove(user_of_getitem)
                        self._remove(user)

                    self.insert_defunctionalized(graph, node)
                    self._remove(node)

                else:
                    # Directly replace the auto_functionalize(rotary_embedding)
                    # with the inplace rotary_embedding. In theory, we shouldn't
                    # do this blindly, but in practice in vLLM it's ok. The best
                    # solution is to use auto_functionalization_v2 and then use
                    # inductor's builtin defunctionalization (reinplacing) pass.
                    mutated_args = {1: "query", 2: "key"}
                    self.defunctionalize(graph, node, mutated_args)

            # rms_norm replacements avoid the most copies for LLaMa.
            elif at_target == torch.ops._C.fused_add_rms_norm.default:
                mutated_args = {1: "input", 2: "residual"}
                self.defunctionalize(graph, node, mutated_args)
            elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default:  # noqa: E501
                mutated_args = {1: "result", 2: "residual"}
                self.defunctionalize(graph, node, mutated_args)
            elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default:  # noqa: E501
                mutated_args = {1: "result", 2: "scale", 3: "residual"}
                self.defunctionalize(graph, node, mutated_args)
            elif at_target in [
                torch.ops._C.rms_norm.default,
                torch.ops._C.rms_norm_static_fp8_quant.default,
            ]:
                mutated_args = {1: "result"}
                self.defunctionalize(graph, node, mutated_args)
            elif (
                # hasattr guard: the op is only looked up if it is registered.
                hasattr(torch.ops.vllm, "flashinfer_trtllm_fused_allreduce_norm")
                and at_target
                == torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default
            ):
                mutated_args = {
                    1: "allreduce_in",
                    2: "residual",
                    3: "norm_out",
                    4: "quant_out",
                    5: "scale_out",
                }
                self.defunctionalize(graph, node, mutated_args)
            # For some reason we need to specify the args for both
            # silu_and_mul and silu_and_mul_quant. The kwargs
            # pathway gets the wrong answer.
            elif at_target == torch.ops._C.silu_and_mul.default:
                mutated_args = {1: "result"}
                self.defunctionalize(
                    graph, node, mutated_args, args=("result", "input")
                )
            elif at_target == torch.ops._C.silu_and_mul_quant.default:
                mutated_args = {1: "result"}
                self.defunctionalize(
                    graph, node, mutated_args, args=("result", "input", "scale")
                )
            elif (
                hasattr(torch.ops._C, "silu_and_mul_nvfp4_quant")
                and at_target == torch.ops._C.silu_and_mul_nvfp4_quant.default
            ):
                mutated_args = {1: "result", 2: "result_block_scale"}
                self.defunctionalize(
                    graph,
                    node,
                    mutated_args,
                    args=(
                        "result",
                        "result_block_scale",
                        "input",
                        "input_global_scale",
                    ),
                )
            # Defunctionalize fused_qk_norm_rope to remove higher-order wrapper.
            elif at_target == torch.ops._C.fused_qk_norm_rope.default:
                mutated_args = {1: "qkv"}
                args = (
                    "qkv",
                    "num_heads_q",
                    "num_heads_k",
                    "num_heads_v",
                    "head_dim",
                    "eps",
                    "q_weight",
                    "k_weight",
                    "cos_sin_cache",
                    "is_neox",
                    "position_ids",
                )
                self.defunctionalize(graph, node, mutated_args=mutated_args, args=args)
            elif (
                hasattr(torch.ops.vllm, "fused_rope_and_unified_kv_cache_update")
                and at_target
                == torch.ops.vllm.fused_rope_and_unified_kv_cache_update.default
            ):
                mutated_args = {
                    1: "query",
                    2: "key",
                }
                self.defunctionalize(graph, node, mutated_args=mutated_args)
            # only used for test_functionalization::TestFunctionWithMutatedArgsAndReturn
            elif (
                hasattr(torch.ops.vllm, "function_with_mutated_args_and_return")
                and at_target
                == torch.ops.vllm.function_with_mutated_args_and_return.default
            ):
                mutated_args = {1: "x"}
                self.defunctionalize(graph, node, mutated_args=mutated_args)
            else:
                continue  # skip the count

            count += 1

        self.dump_graph(graph, "before_cleanup")

        # Remove the nodes all at once
        count_removed = len(self.nodes_to_remove)
        for node in self.nodes_to_remove:
            graph.erase_node(node)

        logger.debug(
            "De-functionalized %s nodes, removed %s nodes", count, count_removed
        )
        self.nodes_to_remove.clear()

    def _remove(self, node_or_nodes: torch.fx.Node | Iterable[torch.fx.Node]) -> None:
        """
        Stage a node (or nodes) for removal at the end of the pass.

        :param node_or_nodes: A single node, or an iterable of nodes, to
            append to self.nodes_to_remove.
        """
        if isinstance(node_or_nodes, torch.fx.Node):
            self.nodes_to_remove.append(node_or_nodes)
        else:
            self.nodes_to_remove.extend(node_or_nodes)

    def defunctionalize(
        self,
        graph: torch.fx.Graph,
        node: torch.fx.Node,
        mutated_args: dict[int, torch.fx.Node | str],
        args: tuple[torch.fx.Node | str, ...] | None = None,
    ) -> None:
        """
        De-functionalize a node by replacing it with a call to the original.
        It also replaces the getitem users with the mutated arguments.
        See replace_users_with_mutated_args and insert_defunctionalized.

        :param graph: Graph containing the node
        :param node: The auto-functionalized node to replace
        :param mutated_args: The mutated arguments, indexed by getitem index
        :param args: Optional positional args for the new call; forwarded to
            insert_defunctionalized.
        """
        self.replace_users_with_mutated_args(node, mutated_args)
        self.insert_defunctionalized(graph, node, args=args)
        self._remove(node)

    def replace_users_with_mutated_args(
        self, node: torch.fx.Node, mutated_args: dict[int, torch.fx.Node | str]
    ) -> None:
        """
        Replace mutated getitem users of the auto-functionalized node with the
        mutated arguments.
        :param node: The auto-functionalized node
        :param mutated_args: The mutated arguments, indexed by getitem index.
        If the value of an arg is a string, `node.kwargs[arg]` is used.
        """
        for idx, user in self.getitem_users(node).items():
            # Some functionalized nodes may return both a result at getitem[0]
            # as well as mutated args at getitem[1:...]
            if idx == 0:
                assert idx not in mutated_args, (
                    f"result at getitem[0] should not be in mutated_args for {node}"
                )
                continue
            arg = mutated_args[idx]
            # A string entry names a kwarg of the functionalized node.
            arg = node.kwargs[arg] if isinstance(arg, str) else arg
            user.replace_all_uses_with(arg)
            self._remove(user)

    def getitem_users(self, node: torch.fx.Node) -> dict[int, torch.fx.Node]:
        """
        Returns the operator.getitem users of the auto-functionalized node,
        indexed by the index they are getting.
        """
        users = {}
        for user in node.users:
            if is_func(user, operator.getitem):
                idx = user.args[1]
                # A later getitem user with the same index overwrites an earlier one.
                users[idx] = user
        return users

    def insert_defunctionalized(
        self,
        graph: torch.fx.Graph,
        node: torch.fx.Node,
        args: tuple[torch.fx.Node | str, ...] | None = None,
    ) -> None:
        """
        Insert a new defunctionalized node into the graph before node.
        If one of the kwargs is 'out', provide args directly,
        as node.kwargs cannot be used.
        See https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351

        :param graph: Graph to insert the defunctionalized node into
        :param node: The auto-functionalized node to defunctionalize
        :param args: If we cannot use kwargs, specify args directly.
        If an arg is a string, `node.kwargs[arg]` is used.
        """  # noqa: E501
        assert is_func(node, auto_functionalized), (
            f"node must be auto-functionalized, is {node} instead"
        )

        # Create a new call to the original function
        with graph.inserting_before(node):
            function = node.args[0]
            if args is None:
                fn_node = graph.call_function(function, kwargs=node.kwargs)
            else:
                # Args passed as strings refer to items in node.kwargs
                args = tuple(
                    node.kwargs[arg] if isinstance(arg, str) else arg for arg in args
                )
                fn_node = graph.call_function(function, args=args)

        # If the function returns a value as well as mutating args inplace,
        # the functionalized node will have a getitem[0] user that holds this value
        # Replace getitem[0] user of the auto-functionalized node
        # with the new defunctionalized node directly if it exists
        users = self.getitem_users(node)
        if 0 in users:
            user = users[0]
            user.replace_all_uses_with(fn_node)
            self._remove(user)
|
||||
130
vllm/compilation/passes/utility/noop_elimination.py
Normal file
130
vllm/compilation/passes/utility/noop_elimination.py
Normal file
@@ -0,0 +1,130 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch.fx
|
||||
from torch import SymInt
|
||||
from torch.fx.experimental.symbolic_shapes import statically_known_true
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from ..fx_utils import is_func
|
||||
from ..vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class NoOpEliminationPass(VllmInductorPass):
    """
    This is an inductor pass that removes redundant reshape/slice operations.
    It is required for RMSNorm-quant fusion to work properly.
    That's because apply_fp8_linear adds a reshape, which is redundant
    in the 2D-case. Additionally, torch internal no-op elimination pass does
    not handle certain slice variants.

    Cases handled:
      1. A chain of reshapes is equivalent to the last reshape called on the
         base tensor (input of the first reshape).
      2. A reshape that produces the shape of the input is redundant
      3. A slice that produces the shape of the input is redundant

    Example graph 1:
    mul_1: "f16[s0, 4096]" = ...
    view_1: "f16[s0, 128, 32]" = torch.reshape(mul_1, [-1, 128, 32])
    view_2: "f16[s0, 4096]" = torch.reshape(view_1, [-1, 4096])
    view_3: "f16[s0, 128, 32]" = torch.reshape(view_2, [-1, 128, 32])

    Can be replaced with:
    mul_1: "f16[s0, 4096]" = ...
    view_3: "f16[s0, 128, 32]" = ...

    Example graph 2:
    getitem_1: "f16[s0, 4096]" = ...
    view_1: "f16[s0, 4096]" = torch.reshape(getitem_1, [-1, 4096])
    at = auto_functionalized(static_scaled_fp8_quant, input = view_1, ...)
    out: "f8e4m3fn[s0, 4096]" = at[1]

    Can be replaced with:
    getitem_1: "f16[s0, 4096]" = ...
    at = auto_functionalized(static_scaled_fp8_quant, input = getitem_1, ...)
    out: "f8e4m3fn[s0, 4096]" = at[1]

    Example graph 3:
    arg0: "s0" = SymInt(s0)
    scaled_mm: "f16[s0, 4096]" = ...
    slice_1: "f16[s0, 4096]" = torch.slice(scaled_mm, -1, 0, arg0)
    at = auto_functionalized(fused_add_rms_norm, input = slice_1, ...)
    out: "f16[s0, 4096]" = torch.slice_scatter(scaled_mm, at[1], 0, 0, arg0)

    Can be replaced with:
    arg0: "s0" = SymInt(s0)
    scaled_mm: "f16[s0, 4096]" = ...
    at = auto_functionalized(fused_add_rms_norm, input = scaled_mm, ...)
    out: "f16[s0, 4096]" = at[1]
    """

    @VllmInductorPass.time_and_log
    def __call__(self, graph: torch.fx.Graph) -> None:
        """Remove no-op reshape, slice and slice_scatter nodes from `graph`."""
        count = 0
        # Remove no-op reshapes/views:
        for node in graph.nodes:
            if is_func(node, torch.ops.aten.reshape.default):
                # Case 1: rewrite reshape chains to reshapes on the base tensor
                source = node.args[0]
                # If the input is a reshape, rebind to that node
                if is_func(source, torch.ops.aten.reshape.default):
                    # The new input is guaranteed not to be a reshape,
                    # because we process nodes in order
                    node.update_arg(0, source.args[0])
                    if len(source.users) == 0:
                        graph.erase_node(source)
                        count += 1

            # remove reshape/slice if it produces the original shape
            if is_func(node, torch.ops.aten.reshape.default) or is_func(
                node, torch.ops.aten.slice.Tensor
            ):
                # Re-read the input: Case 1 may have just rebound it.
                source = node.args[0]
                input_shape = source.meta["val"].shape
                output_shape = node.meta["val"].shape
                if self.all_dims_equivalent(input_shape, output_shape):
                    node.replace_all_uses_with(source)
                    graph.erase_node(node)
                    count += 1
            elif is_func(node, torch.ops.aten.slice_scatter.default):
                # Only the base and the scattered view are needed here.
                # dim/start/end can be absent from args when they are left
                # at their defaults, so they must not be unpacked.
                base, view = node.args[:2]
                base_shape = base.meta["val"].shape
                view_shape = view.meta["val"].shape

                if self.all_dims_equivalent(base_shape, view_shape):
                    node.replace_all_uses_with(view)
                    graph.erase_node(node)
                    count += 1

        logger.debug("Removed %s no-op reshapes and slices", count)

    # ---------------------- Shape comparison helpers ----------------------
    def dims_equivalent(self, dim: int | SymInt, i_dim: int | SymInt) -> bool:
        """
        This function checks if two dimensions are equivalent.
        :param dim: The dimension arg to reshape/slice
        :param i_dim: The corresponding dimension in the input tensor
        :return: Are the dimensions equivalent?

        There are two cases in which the dimensions are equivalent:
        1. The dimensions are equal (both integers)
        2. The dimensions both correspond to the same SymInt

        Both cases are decided by a single `statically_known_true` query
        on `dim == i_dim`.
        """
        return statically_known_true(dim == i_dim)  # type: ignore[no-any-return]

    def all_dims_equivalent(
        self, dims: Iterable[int | SymInt], i_dims: Iterable[int | SymInt]
    ) -> bool:
        """
        Check that two shapes have the same rank and pairwise-equivalent dims.

        :param dims: Dimensions of the first shape
        :param i_dims: Dimensions of the second shape
        :return: True only if the ranks match and every pair of dimensions
                 satisfies `dims_equivalent`
        """
        # Materialize once so one-shot iterators are not consumed by the
        # rank check before the element-wise comparison.
        dims_ = list(dims)
        i_dims_ = list(i_dims)
        if len(dims_) != len(i_dims_):
            # Different ranks can't be equivalent
            return False
        return all(self.dims_equivalent(s, i_s) for s, i_s in zip(dims_, i_dims_))
|
||||
21
vllm/compilation/passes/utility/post_cleanup.py
Normal file
21
vllm/compilation/passes/utility/post_cleanup.py
Normal file
@@ -0,0 +1,21 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from torch import fx
|
||||
|
||||
from ..vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
|
||||
class PostCleanupPass(VllmInductorPass):
    """
    Cleanup step that runs after the custom passes.

    The pattern matcher does not guarantee that the graph it leaves behind
    is topologically sorted, and it may leave unused nodes around. This pass
    re-sorts the graph and then removes the dead nodes.
    """

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Topologically sort `graph` in place, then eliminate dead code."""
        # Imported lazily, at call time, as in the rest of this module.
        from torch._inductor import pattern_matcher

        pattern_matcher.stable_topological_sort(graph)
        graph.eliminate_dead_code()
|
||||
138
vllm/compilation/passes/utility/scatter_split_replace.py
Normal file
138
vllm/compilation/passes/utility/scatter_split_replace.py
Normal file
@@ -0,0 +1,138 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Replace ``slice_scatter`` and ``split_with_sizes`` nodes with a single
|
||||
assignment if there are no users for the inplace tensor written to by
|
||||
the slice_scatter call.
|
||||
|
||||
The inplace rotary_embedding custom op takes in mutable query and key inputs
|
||||
that are split+getitem outputs of a single qkv tensor.
|
||||
When functionalized, we fetch the rotated query and key from the functionalized op
|
||||
using `getitem` calls. However, we also write to the qkv tensor inplace using a
|
||||
`slice_scatter`, then split the inplace tensor to get the output tensors again.
|
||||
Instead, if the inplace tensor has no subsequent users, we can just replace the
|
||||
`slice_scatter` and `split_with_sizes` nodes with the `getitem` calls.
|
||||
|
||||
This is already done in fix_functionalization::FixFunctionalizationPass, but
|
||||
writing a custom pass for it before defunctionalization allows matching against the
|
||||
qkv split+rotary_embedding subpattern as part of e.g. the RoPE+KVCache fusion pass.
|
||||
"""
|
||||
|
||||
import operator
|
||||
|
||||
import torch
|
||||
from torch import fx
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from ..fx_utils import is_func
|
||||
from ..vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class ScatterSplitReplacementPass(VllmInductorPass):
    """Replace getitem+slice_scatter+split nodes with a single getitem when
    the inplace subtensor written to by the slice_scatter has no other users.

    Here's an example graph with q_size = 512, kv_size = 64:
    split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(qkv, (512, 64, 64), -1)
    at = auto_functionalized(torch.ops._C.rotary_embedding.default(positions, q, k))
    q = operator.getitem(at, 1)
    k = operator.getitem(at, 2)
    torch.ops.aten.slice_scatter.default(qkv, q, [0, 512], -1)
    torch.ops.aten.slice_scatter.default(qkv, k, [512, 512 + 64], -1)
    split_with_sizes_2 = torch.ops.aten.split_with_sizes.default(qkv, (512, 64, 64), -1)
    q = operator.getitem(split_with_sizes_2, 0)
    k = operator.getitem(split_with_sizes_2, 1)
    v = operator.getitem(split_with_sizes_2, 2)

    After this pass, this sequence of nodes is replaced with:
    split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(qkv, (512, 64, 64), -1)
    at = auto_functionalized(torch.ops._C.rotary_embedding.default(positions, q, k))
    q = operator.getitem(at, 1)
    k = operator.getitem(at, 2)
    v = operator.getitem(split_with_sizes_1, 2)

    A candidate is only rewritten when the whole pattern is present; any
    partial match is left untouched so the graph is never half rewritten.
    """

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Rewrite every matching rotary-embedding pattern found in `graph`."""
        count = 0

        target_ops = [torch.ops._C.rotary_embedding.default]
        if hasattr(torch.ops.vllm, "rocm_aiter_triton_rotary_embedding"):
            target_ops.append(torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default)

        for node in graph.nodes:
            if not is_func(node, auto_functionalized):
                continue
            if node.args[0] not in target_ops:
                continue
            if self._replace_scatter_split(graph, node):
                count += 1

        logger.debug("Eliminated %d slice_scatter+split nodes", count)

    def _replace_scatter_split(self, graph: fx.Graph, node: fx.Node) -> bool:
        """Rewrite the slice_scatter+split chain hanging off one functionalized op.

        :param graph: Graph that owns `node`
        :param node: auto_functionalized node wrapping one of the target ops
        :return: True if the chain was replaced, False if the pattern did not
                 match (in which case the graph is left unmodified)
        """
        query = node.kwargs["query"]
        key = node.kwargs["key"]

        # getitem users of the functionalized op, keyed by result index.
        getitem_nodes = {
            user.args[1]: user
            for user in node.users
            if is_func(user, operator.getitem)
        }

        # Pattern where query and key are slices of a qkv tensor.
        # While functionalized, results at [1] and [2] are scattered
        # back into qkv, then split again to get query and key.
        if not (
            is_func(query, operator.getitem)
            and is_func(key, operator.getitem)
            and query.args[0] == key.args[0]
            and is_func(query.args[0], torch.ops.aten.split_with_sizes.default)
            and all(
                is_func(user, torch.ops.aten.slice_scatter.default)
                for getitem_node in getitem_nodes.values()
                for user in getitem_node.users
            )
        ):
            return False

        # Results [1] and [2] hold the rotated query and key.
        rotated_query = getitem_nodes.get(1)
        rotated_key = getitem_nodes.get(2)
        if rotated_query is None or rotated_key is None:
            return False

        # The guard above already established that every user of these
        # getitem nodes is a slice_scatter; there must be at least one each.
        query_scatters = list(rotated_query.users)
        key_scatters = list(rotated_key.users)
        if not query_scatters or not key_scatters:
            return False
        slice_scatter_1_node = query_scatters[-1]
        slice_scatter_2_node = key_scatters[-1]

        # The second scatter must feed exactly one node, the re-split,
        # otherwise it could not be erased below.
        scatter_2_users = list(slice_scatter_2_node.users)
        if len(scatter_2_users) != 1:
            return False
        split_node = scatter_2_users[0]
        if not is_func(split_node, torch.ops.aten.split_with_sizes.default):
            return False

        # The first scatter may only be consumed by the second one.
        if not all(
            user is slice_scatter_2_node for user in slice_scatter_1_node.users
        ):
            return False

        # The re-split must be consumed exactly by getitem 0, 1 and 2.
        split_getitem_users = {
            user.args[1]: user
            for user in split_node.users
            if is_func(user, operator.getitem)
        }
        if len(split_getitem_users) != len(split_node.users) or set(
            split_getitem_users
        ) != {0, 1, 2}:
            return False

        # All checks passed: from here on the graph is mutated.
        # Replace query node
        split_getitem_users[0].replace_all_uses_with(rotated_query)
        graph.erase_node(split_getitem_users[0])
        # Replace key node
        split_getitem_users[1].replace_all_uses_with(rotated_key)
        graph.erase_node(split_getitem_users[1])
        # Redirect value node to original qkv tensor
        split_getitem_users[2].replace_input_with(split_node, query.args[0])

        # Erase unused nodes, consumers before producers
        graph.erase_node(split_node)
        graph.erase_node(slice_scatter_2_node)
        graph.erase_node(slice_scatter_1_node)
        return True
|
||||
70
vllm/compilation/passes/utility/split_coalescing.py
Normal file
70
vllm/compilation/passes/utility/split_coalescing.py
Normal file
@@ -0,0 +1,70 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Coalesce duplicate ``split_with_sizes`` nodes that operate on the same
|
||||
input tensor with the same split sizes.
|
||||
|
||||
On certain hardware/dtype combinations (e.g. B200 + FP8) the Inductor
|
||||
graph may contain multiple ``split_with_sizes`` calls on the same tensor
|
||||
that CSE fails to merge. This pass detects and replaces the duplicates
|
||||
so that downstream pattern-matching passes (e.g. QK-Norm+RoPE fusion)
|
||||
see a single split node with all users attached.
|
||||
|
||||
See also:
|
||||
- vLLM #33295 (original issue)
|
||||
- PyTorch #174472 (upstream CSE gap)
|
||||
"""
|
||||
|
||||
import operator
|
||||
|
||||
import torch
|
||||
from torch import fx
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from ..fx_utils import is_func
|
||||
from ..vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class SplitCoalescingPass(VllmInductorPass):
    """Replace duplicate ``split_with_sizes`` nodes with a single canonical
    node when they share the same input tensor, split sizes and split dim."""

    @staticmethod
    def _split_dim(node: fx.Node) -> object:
        """Return the ``dim`` argument of a ``split_with_sizes`` node.

        The dim is read from the third positional argument if present,
        otherwise from the ``dim`` kwarg, falling back to the op default 0.
        Dims are compared as written (e.g. -1 and 1 are treated as different),
        which can only miss a merge, never cause a wrong one.
        """
        if len(node.args) > 2:
            return node.args[2]
        return node.kwargs.get("dim", 0)

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        """Merge duplicate ``split_with_sizes`` nodes found in `graph`."""
        count = 0

        # Map from input tensor node -> list of split nodes seen so far.
        split_nodes: dict[fx.Node, list[fx.Node]] = {}

        for node in graph.nodes:
            if not is_func(node, torch.ops.aten.split_with_sizes.default):
                continue
            # Only splits consumed purely through getitem are considered.
            if not all(is_func(user, operator.getitem) for user in node.users):
                continue

            arg_node, split_sizes = node.args[:2]
            split_dim = self._split_dim(node)

            if arg_node not in split_nodes:
                split_nodes[arg_node] = [node]
                continue

            # Find existing node with same split_sizes along the same dim.
            # Two splits with equal sizes along different dims produce
            # different tensors and must not be merged.
            canonical = next(
                (
                    n
                    for n in split_nodes[arg_node]
                    if list(n.args[1]) == list(split_sizes)
                    and self._split_dim(n) == split_dim
                ),
                None,
            )
            if canonical is not None:
                node.replace_all_uses_with(canonical)
                graph.erase_node(node)
                count += 1
            else:
                split_nodes[arg_node].append(node)

        logger.debug("Coalesced %d duplicate split_with_sizes nodes", count)
|
||||
180
vllm/compilation/passes/vllm_inductor_pass.py
Normal file
180
vllm/compilation/passes/vllm_inductor_pass.py
Normal file
@@ -0,0 +1,180 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
import operator
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
from torch._dynamo.utils import lazy_format_graph_code
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass, PatternPrettyPrinter
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .inductor_pass import InductorPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class InductorCompilationConfig:
    """Compilation settings that VllmInductorPass copies out of the full config."""

    # Filled from `config.compilation_config.splitting_ops`.
    splitting_ops: list[str] | None = None
    # Filled from `config.compilation_config.use_inductor_graph_partition`.
    use_inductor_graph_partition: bool = False
|
||||
|
||||
|
||||
class VllmInductorPass(InductorPass):
    """
    Base class for inductor passes that need the vLLM PassConfig.
    On top of that it offers helpers for timing, logging and graph dumping.
    """

    dump_prefix: ClassVar[int | None] = None
    """Keep track of pass index for debug dump ordering."""

    def __init__(self, config: VllmConfig):
        compilation_config = config.compilation_config
        # Get only the necessary CompilationConfig for the inductor pass, since
        # full `CompilationConfig` contains pointer to model which is unsafe.
        self.compilation_config = InductorCompilationConfig(
            splitting_ops=compilation_config.splitting_ops,
            use_inductor_graph_partition=compilation_config.use_inductor_graph_partition,
        )
        self.pass_config = compilation_config.pass_config

        model_config = config.model_config
        self.model_dtype = model_config.dtype if model_config else None

        device_config = config.device_config
        self.device: str | None = device_config.device if device_config else None

        self.pass_name = self.__class__.__name__

    @staticmethod
    def time_and_log(
        call_fn: Callable[["VllmInductorPass", torch.fx.Graph], None],
    ) -> Callable[["VllmInductorPass", torch.fx.Graph], None]:
        """Decorate a pass `__call__` so it is timed and the graph is dumped
        before and after the pass body runs."""

        @functools.wraps(call_fn)
        def timed_call(self: VllmInductorPass, graph: torch.fx.Graph) -> None:
            self.begin()
            self.dump_graph(graph, "before")
            call_fn(self, graph)
            self.dump_graph(graph, "after")
            self.end_and_log()

        return timed_call

    def dump_graph(self, graph: torch.fx.Graph, stage: str) -> None:
        """Hand the graph's owning module to `lazy_format_graph_code` under a
        name built from the dump prefix, the pass name and `stage`."""
        prefix = VllmInductorPass.dump_prefix
        prefix_part = f".{prefix}" if prefix is not None else ""
        dump_name = f"post_grad{prefix_part}.{self.pass_name}.{stage}"
        lazy_format_graph_code(dump_name, graph.owning_module)

    def begin(self) -> None:
        """Record the start timestamp of the pass."""
        self._start_time = time.perf_counter_ns()

    def end_and_log(self) -> None:
        """Record the end timestamp and log the elapsed time in milliseconds."""
        self._end_time = time.perf_counter_ns()
        elapsed_ns = self._end_time - self._start_time
        duration_ms = float(elapsed_ns) / 1.0e6
        logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms)
|
||||
|
||||
|
||||
class VllmPatternMatcherPass(VllmInductorPass):
    """
    VllmInductorPass variant for passes built on the Inductor pattern matcher.
    Its main purpose is the dump_patterns utility, which writes the registered
    pattern-matcher patterns to a file to make debugging easier.

    TODO(luka) move more utilities to this pass.
    """

    matched_count: int = 0
    """The number of matched patterns in the pass."""

    _OP_OVERLOAD_PATTERN: ClassVar[re.Pattern] = re.compile(
        r"<OpOverload\(op='([^']*)', overload='([^']*)'\)>"
    )

    def _replace_op_overloads(self, string: str) -> str:
        """Rewrite each <OpOverload(op=..., overload=...)> repr found in
        `string` as torch.ops.<op>.<overload>."""

        def _as_torch_ops(match: re.Match) -> str:
            return f"torch.ops.{match.group(1)}.{match.group(2)}"

        return str(self._OP_OVERLOAD_PATTERN.sub(_as_torch_ops, string))

    def dump_patterns(self, config: VllmConfig, pm_pass: PatternMatcherPass) -> None:
        """
        Write the patterns registered on `pm_pass` to a new file under the
        debug dump path, if `config.compile_debug_dump_path()` returns one.
        Nothing is written otherwise.

        The output is a best-effort rendering that looks like Python code, for
        easier debugging and potentially navigation. If any errors appear in
        the output, please add to this method.

        TODO(luka): use pattern object to manually produce pattern graph
        """
        debug_dump_path = config.compile_debug_dump_path()
        if not debug_dump_path:
            return

        debug_dump_path.mkdir(parents=True, exist_ok=True)

        from vllm.utils.system_utils import unique_filepath

        def _candidate(i):
            return debug_dump_path / f"patterns.{self.pass_name}.{i}.py"

        file_path = unique_filepath(_candidate)

        header = (
            "# This file was produced by VllmPatternMatcherPass."
            f"dump_patterns for {self.pass_name}.\n"
            "# It does its best to produce valid-Python-looking code but"
            " please add to dump_patterns if there are any errors.\n\n"
            "from torch._higher_order_ops.auto_functionalize import "
            "auto_functionalized as auto_functionalized\n"
            "from torch._inductor.pattern_matcher import *\n"
            "vllm = torch.ops.vllm"
        )

        with file_path.open("w") as f:
            print(header, file=f)

            for node, patterns in pm_pass.patterns.items():
                # fix the operator.getitem repr
                if node[1] == operator.getitem:
                    node_repr = f"({repr(node[0])}, operator.getitem)"
                else:
                    node_repr = repr(node)
                node_repr = self._replace_op_overloads(node_repr)

                print(f"\n\n# Patterns for op: {node_repr}", file=f)
                for i, pattern in enumerate(patterns):
                    # reserve auto_functionalized ahead of time
                    pp = PatternPrettyPrinter()
                    pp.namespace.create_name("auto_functionalized", None)

                    # Assemble pattern: a def line, one assignment per
                    # memoized object, then the return of the output node.
                    out_node = pp.pretty_print(pattern.pattern)
                    body_lines = [f"def pattern_{i}():"]
                    for key in pp.memoized_objs_names:
                        body_lines.append(
                            f"{pp.memoized_objs_names[key]} = "
                            f"{pp.memoized_objs_pp[key]}"
                        )
                    body_lines.append(f"return {out_node}")
                    pattern_repr = "\n".join(body_lines).replace("\n", "\n ")

                    pattern_repr = self._replace_op_overloads(pattern_repr)
                    print(f"{pattern_repr}\n", file=f)
|
||||
|
||||
|
||||
class PrinterInductorPass(VllmInductorPass):
    """Pass that only hands the graph to `dump_graph`, using `name` as the
    stage label."""

    def __init__(self, name: str, config: VllmConfig) -> None:
        super().__init__(config)
        # Used as the `stage` argument of dump_graph on every call.
        self.name = name

    def __call__(self, graph: torch.fx.Graph) -> None:
        # The graph is not modified here; it is only passed to dump_graph.
        self.dump_graph(graph, self.name)
|
||||
343
vllm/compilation/piecewise_backend.py
Normal file
343
vllm/compilation/piecewise_backend.py
Normal file
@@ -0,0 +1,343 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import dataclasses
|
||||
import io
|
||||
import json
|
||||
import pickle
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from pickle import Pickler
|
||||
from typing import Any
|
||||
|
||||
import torch._functorch.config
|
||||
import torch.fx as fx
|
||||
from torch._inductor.runtime.triton_heuristics import CachingAutotuner
|
||||
from torch._logging._internal import trace_structured
|
||||
|
||||
from vllm.compilation.backends import VllmBackend
|
||||
from vllm.compilation.monitor import end_monitoring_torch_compile
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.utils import Range
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
class RangeEntry:
    """Per-range compilation state kept by PiecewiseBackend."""

    # The range this entry was created for.
    compile_range: Range
    # Flipped to True once `runnable` has been assigned for this range.
    compiled: bool = False
    # Callable used for this range; None until it is compiled or loaded.
    runnable: Callable[..., Any] = None  # type: ignore
|
||||
|
||||
|
||||
class PiecewiseBackend:
|
||||
def __init__(
    self,
    graph: fx.GraphModule | None,
    vllm_config: VllmConfig,
    piecewise_compile_index: int,
    total_piecewise_compiles: int,
    sym_shape_indices: list[int],
    vllm_backend: VllmBackend,
    returns_tuple: bool,
    compiled_runnables: dict[str, Callable[..., Any]] | None = None,
    submod_name: str = "",
):
    """
    The backend for piecewise compilation.
    It mainly handles the compilation of static shapes and
    dispatching based on runtime shape.

    We will compile `self.graph` once for the general shape,
    and then compile for different shapes specified in
    `compilation_config.compile_sizes`.

    This class supports two mutually exclusive modes:
    1. Compilation (graph is set, compiled_runnables is None):
       Used during initial compilation when we have the FX graph
       and need to compile it for each shape range.
    2. Precompilation (graph is None, compiled_runnables is set):
       Used when loading from cache/AOT artifacts where we already
       have pre-compiled callables and don't need the original graph.

    Exactly one of graph or compiled_runnables must be provided.

    :param graph: FX graph module to compile, or None in precompilation mode
    :param vllm_config: Config providing the compilation and scheduler settings
    :param piecewise_compile_index: Index of this piece among all pieces
    :param total_piecewise_compiles: Total number of pieces
    :param sym_shape_indices: Stored as-is on the instance
    :param vllm_backend: Backend that owns the compiler manager
    :param returns_tuple: Whether compiled outputs are returned without
        unpacking (see `get_compiled_graph_wrapper`)
    :param compiled_runnables: Pre-compiled callables keyed by `str(range)`,
        or None in compilation mode
    :param submod_name: Name used when logging this piece
    :raises NotImplementedError: if `compile_sizes` still contains a string
    """
    assert bool(graph is not None) ^ bool(compiled_runnables is not None), (
        "exactly one of graph and compiled_runnables should be set."
    )

    self.graph = graph
    self.vllm_config = vllm_config
    self.compilation_config = vllm_config.compilation_config
    self.piecewise_compile_index = piecewise_compile_index
    self.total_piecewise_compiles = total_piecewise_compiles
    self.vllm_backend = vllm_backend
    self.compiled_runnables = compiled_runnables
    self.submod_name = submod_name

    self.is_first_graph = piecewise_compile_index == 0
    self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1

    self.is_full_graph = total_piecewise_compiles == 1
    self.is_encoder_compilation = vllm_backend.is_encoder

    self.compile_ranges = self.compilation_config.get_compile_ranges()
    if self.is_encoder_compilation:
        # For encoder compilation we use the max int32 value
        # to set the upper bound of the compile ranges
        max_int32 = 2**31 - 1
        last_compile_range = self.compile_ranges[-1]
        assert (
            last_compile_range.end
            == vllm_config.scheduler_config.max_num_batched_tokens
        )
        self.compile_ranges[-1] = Range(
            start=last_compile_range.start, end=max_int32
        )

    log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}"
    logger.debug_once(log_string)

    self.compile_sizes = self.compilation_config.compile_sizes
    log_string = f"PiecewiseBackend: compile_sizes: {self.compile_sizes}"
    logger.debug_once(log_string)

    self.sym_shape_indices = sym_shape_indices
    self.returns_tuple = returns_tuple

    # One entry per range (and per extra single size) that this backend
    # may have to compile or load a runnable for.
    self.range_entries: dict[Range, RangeEntry] = {}

    # to_be_compiled_ranges tracks the remaining ranges to compile,
    # and updates during the compilation process, so we need to copy it
    self.to_be_compiled_ranges: set[Range] = set(self.compile_ranges)

    # We only keep compilation management inside this class directly.
    if self.compile_sizes is not None:
        for size in self.compile_sizes:
            if isinstance(size, str):
                assert size == "cudagraph_capture_sizes"
                raise NotImplementedError(
                    "cudagraph_capture_sizes not supported in compile_sizes. "
                    "This should be handled in `post_init_cudagraph_sizes`."
                )
            else:
                assert isinstance(size, int)
                # A single size is tracked as a degenerate [size, size] range.
                size_range = Range(start=size, end=size)
                if size_range not in self.compile_ranges:
                    self.range_entries[size_range] = RangeEntry(
                        compile_range=size_range,
                    )
                    self.to_be_compiled_ranges.add(size_range)

    for compile_range in self.compile_ranges:
        self.range_entries[compile_range] = RangeEntry(
            compile_range=compile_range,
        )

    # Track whether we've logged the graph for this subgraph (only log once)
    self._graph_logged = False

    # get the on_compilation_complete callback from context...
    # PiecewiseBackend is created during the first call,
    # which is when the context is set (see compilation/decorators.py)
    from vllm.compilation.backends import _on_compilation_complete_callback

    self.on_compilation_complete = _on_compilation_complete_callback.get()
|
||||
|
||||
def get_compiled_graph_wrapper(
|
||||
self, compiled_graph: Callable[..., Any]
|
||||
) -> Callable[..., Any]:
|
||||
def compiled_graph_wrapper(*args: Any) -> Any:
|
||||
graph_output = compiled_graph(*args)
|
||||
# unpack the tuple if needed
|
||||
# TODO(rzou): the implication is that we're not
|
||||
# reading the python bytecode correctly in vLLM?
|
||||
if self.returns_tuple or not isinstance(graph_output, (tuple, list)):
|
||||
return graph_output
|
||||
else:
|
||||
return graph_output[0]
|
||||
|
||||
return compiled_graph_wrapper
|
||||
|
||||
def check_for_ending_compilation(self) -> None:
    """Finish up once the last piece has no ranges left to compile: save the
    compiler manager cache, stop compile monitoring and run the completion
    callback if one was registered. Does nothing otherwise."""
    if not self.is_last_graph or self.to_be_compiled_ranges:
        return

    # no specific sizes to compile
    # save the hash of the inductor graph for the next run
    save_started = time.perf_counter()
    self.vllm_backend.compiler_manager.save_to_file()
    save_seconds = time.perf_counter() - save_started
    if save_seconds > 1:
        logger.info_once(
            "Saved compiler manager cache in %.2f seconds.",
            save_seconds,
            scope="local",
        )

    end_monitoring_torch_compile(self.vllm_config)

    # Call the completion callback (e.g., to save AOT compiled function)
    if self.on_compilation_complete is not None:
        self.on_compilation_complete()
|
||||
|
||||
def to_bytes(self) -> dict[str, bytes]:
    """Serialize the compiled runnables of this backend.

    :return: Mapping from `str(range)` to the pickled result of the
        runnable's `serialize()`, for every compiled entry whose runnable
        has a `serialize` attribute. Uncompiled entries are skipped with a
        debug log.
    """

    class StandaloneCompiledArtifactsPickler(Pickler):
        def reducer_override(self, obj: object) -> Any:
            # Everything except CachingAutotuner uses default pickling.
            if not isinstance(obj, CachingAutotuner):
                return NotImplemented
            obj.prepare_for_pickle()
            return pickle.loads, (pickle.dumps(obj),)

    def serialize(fn: Callable[..., Any]) -> bytes:
        assert hasattr(fn, "serialize"), "fn must have serialize method"
        with torch._functorch.config.patch("bundled_autograd_cache", True):
            entry = fn.serialize()

            buffer = io.BytesIO()
            StandaloneCompiledArtifactsPickler(buffer).dump(entry)
            return buffer.getvalue()

    out = {}
    for range_key, entry in self.range_entries.items():
        if not entry.compiled:
            logger.debug(
                "entry with range %s not compiled, so cannot get its bytes",
                range_key,
            )
        elif hasattr(entry.runnable, "serialize"):
            out[str(range_key)] = serialize(entry.runnable)

    return out
|
||||
|
||||
def _fakify_args(self, args: tuple[Any, ...]) -> list[Any]:
|
||||
# We need to pass fake example_inputs, otherwise torch.compile
|
||||
# will fakify the example_inputs potentially causing some non dynamic
|
||||
# dimension to be be duck shaped to other existing shapes that have hints
|
||||
# matching their values.
|
||||
# This is problem because it can lead to unintended specializations!
|
||||
# if the new wrongly dynamic dim is specialized
|
||||
# it will force specializing the whole shape
|
||||
# torch.compile probably should not accept
|
||||
# non fake tensors as example inputs!
|
||||
# See issue https://github.com/vllm-project/vllm/issues/27899
|
||||
fake_example_inputs = []
|
||||
assert self.graph is not None
|
||||
for node in self.graph.graph.nodes:
|
||||
# All place holders come first
|
||||
if node.op == "placeholder":
|
||||
fake_example_inputs.append(node.meta["example_value"])
|
||||
else:
|
||||
break
|
||||
assert len(fake_example_inputs) == len(args)
|
||||
return fake_example_inputs
|
||||
|
||||
def _log_compile_start(self, compile_range: Range):
    """Log compilation event for TORCH_TRACE/tlparse."""
    sizes = self.compile_sizes
    is_cudagraph_size = sizes is not None and compile_range.start in sizes
    subgraph_index = self.piecewise_compile_index
    submod_name = self.submod_name

    def _artifact_metadata() -> dict[str, str]:
        return {
            "name": "vllm_piecewise_compile_start",
            "encoding": "json",
        }

    def _artifact_payload() -> str:
        return json.dumps(
            {
                "piecewise_index": subgraph_index,
                "submod_name": submod_name,
                "total_piecewise_compiles": self.total_piecewise_compiles,
                "compile_range_start": compile_range.start,
                "compile_range_end": compile_range.end,
                "is_single_size": compile_range.is_single_size(),
                "is_cudagraph_capture_size": is_cudagraph_size,
            }
        )

    trace_structured(
        "artifact",
        metadata_fn=_artifact_metadata,
        payload_fn=_artifact_payload,
    )

    # Log the subgraph graph dump only once per subgraph (not per size)
    # to reduce log file size. The graph code is the same for all sizes.
    if self._graph_logged:
        return
    self._graph_logged = True
    assert self.graph is not None
    trace_structured(
        "graph_dump",
        metadata_fn=lambda: {
            "name": f"vllm_{submod_name}",
        },
        payload_fn=lambda: self.graph.print_readable(print_output=False),
    )
|
||||
|
||||
def _maybe_compile_for_range_entry(
|
||||
self, range_entry: RangeEntry, args: tuple[Any, ...]
|
||||
) -> Any:
|
||||
if not range_entry.compiled:
|
||||
if self.compiled_runnables is not None:
|
||||
range_entry.runnable = self.get_compiled_graph_wrapper(
|
||||
self.compiled_runnables[str(range_entry.compile_range)]
|
||||
)
|
||||
else:
|
||||
self._log_compile_start(range_entry.compile_range)
|
||||
|
||||
# args are real arguments
|
||||
# fakify for range, real args for concrete size.
|
||||
# For concrete size, we clear the shape env in
|
||||
# compiler_manager.compile() so no need to fakify.
|
||||
args_list = (
|
||||
self._fakify_args(args)
|
||||
if not range_entry.compile_range.is_single_size()
|
||||
else list(args)
|
||||
)
|
||||
|
||||
with (
|
||||
torch._functorch.config.patch("bundled_autograd_cache", True),
|
||||
):
|
||||
range_entry.runnable = self.vllm_backend.compiler_manager.compile(
|
||||
self.graph,
|
||||
args_list,
|
||||
self.vllm_backend.inductor_config,
|
||||
self.compilation_config,
|
||||
compile_range=range_entry.compile_range,
|
||||
graph_index=self.piecewise_compile_index,
|
||||
num_graphs=self.total_piecewise_compiles,
|
||||
)
|
||||
|
||||
range_entry.compiled = True
|
||||
self.to_be_compiled_ranges.remove(range_entry.compile_range)
|
||||
|
||||
self.check_for_ending_compilation()
|
||||
|
||||
def _find_range_for_shape(self, runtime_shape: int) -> RangeEntry | None:
|
||||
# First we try to find the range entry for the concrete compile size
|
||||
# If not found, we search for the range entry
|
||||
# that contains the runtime shape.
|
||||
if self.compile_sizes is None:
|
||||
return None
|
||||
|
||||
if runtime_shape in self.compile_sizes:
|
||||
return self.range_entries[Range(start=runtime_shape, end=runtime_shape)]
|
||||
else:
|
||||
for range in self.compile_ranges:
|
||||
if runtime_shape in range:
|
||||
return self.range_entries[range]
|
||||
return None
|
||||
|
||||
def __call__(self, *args: Any) -> Any:
|
||||
runtime_shape = args[self.sym_shape_indices[0]]
|
||||
range_entry = self._find_range_for_shape(runtime_shape)
|
||||
|
||||
assert range_entry is not None, (
|
||||
f"Shape: {runtime_shape} out of considered ranges: {self.compile_ranges}"
|
||||
)
|
||||
|
||||
self._maybe_compile_for_range_entry(range_entry, args)
|
||||
return range_entry.runnable(*args)
|
||||
321
vllm/compilation/wrapper.py
Normal file
321
vllm/compilation/wrapper.py
Normal file
@@ -0,0 +1,321 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
import sys
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Callable, Generator
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from types import CodeType
|
||||
from typing import Any, ParamSpec, TypeVar
|
||||
|
||||
import torch
|
||||
import torch._C._dynamo.guards
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config
|
||||
from vllm.config.compilation import DynamicShapesType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.nvtx_pytorch_hooks import layerwise_nvtx_marker_context
|
||||
|
||||
# Module-level logger, named after this module.
logger = init_logger(__name__)

# Generic return type and parameter specification; they annotate the callable
# accepted by TorchCompileWithNoGuardsWrapper._call_with_optional_nvtx_range.
R = TypeVar("R")
P = ParamSpec("P")
|
||||
|
||||
|
||||
def _noop_add_global_state_guard(
|
||||
self: torch._C._dynamo.guards.GuardManager, *args: Any, **kwargs: Any
|
||||
) -> None:
|
||||
"""No-op to skip the GLOBAL_STATE guard entirely"""
|
||||
pass
|
||||
|
||||
|
||||
def _noop_add_torch_function_mode_stack_guard(
|
||||
self: torch._C._dynamo.guards.GuardManager, *args: Any, **kwargs: Any
|
||||
) -> None:
|
||||
"""No-op to skip the TORCH_FUNCTION_MODE_STACK guard entirely"""
|
||||
pass
|
||||
|
||||
|
||||
@contextmanager
def _compilation_context() -> Generator[None, None, None]:
    """Temporarily apply the dynamo settings and patches used for compilation.

    While the context is active:
    1. ``torch._dynamo.config.cache_size_limit`` is 2048 and
       ``accumulated_cache_size_limit`` is 8192. (Needed for qwen2_5_vl, see
       test_qwen2_5_vl_evs_functionality.) Generally a recompilation can
       happen whenever a new backend instance is used in torch.compile.
    2. ``GuardManager.add_global_state_guard`` is a no-op, which skips
       GLOBAL_STATE guards.
    3. ``GuardManager.add_torch_function_mode_stack_guard`` is a no-op, which
       skips TORCH_FUNCTION_MODE_STACK guards.

    All four values are put back on exit, including when the body raises.
    """
    guard_manager = torch._C._dynamo.guards.GuardManager
    dynamo_config = torch._dynamo.config

    # Snapshot everything that is about to be overridden.
    saved_global_state_guard = guard_manager.add_global_state_guard
    saved_mode_stack_guard = guard_manager.add_torch_function_mode_stack_guard
    saved_cache_limit = dynamo_config.cache_size_limit
    saved_accumulated_limit = dynamo_config.accumulated_cache_size_limit

    try:
        # Raise the cache limits for the duration of the compilation.
        dynamo_config.cache_size_limit = 2048
        dynamo_config.accumulated_cache_size_limit = 8192

        # Swap in the no-op guard installers.
        guard_manager.add_global_state_guard = _noop_add_global_state_guard
        guard_manager.add_torch_function_mode_stack_guard = (
            _noop_add_torch_function_mode_stack_guard
        )
        yield
    finally:
        guard_manager.add_global_state_guard = saved_global_state_guard
        guard_manager.add_torch_function_mode_stack_guard = saved_mode_stack_guard
        dynamo_config.cache_size_limit = saved_cache_limit
        dynamo_config.accumulated_cache_size_limit = saved_accumulated_limit
|
||||
|
||||
|
||||
class TorchCompileWithNoGuardsWrapper:
|
||||
"""
|
||||
A wrapper class for torch.compile, it ensures that all guards are dropped
|
||||
when CompilationMode is not CompilationMode.STOCK_TORCH_COMPILE.
|
||||
When guards are dropped, the first time __call__ is invoked, a single
|
||||
compilation is triggered. Dynamo should never be traced again after that
|
||||
since we drop all guards.
|
||||
"""
|
||||
|
||||
def check_invariants_and_forward(self, *args: Any, **kwargs: Any) -> Any:
|
||||
assert hasattr(self, "_check_shape_invariants")
|
||||
self._check_shape_invariants(*args, **kwargs)
|
||||
|
||||
return self.forward(*args, **kwargs)
|
||||
|
||||
def _call_with_optional_nvtx_range(
|
||||
self, callable_fn: Callable[P, R], *args: P.args, **kwargs: P.kwargs
|
||||
) -> Any:
|
||||
if self.layerwise_nvtx_tracing_enabled:
|
||||
args_list = list(args)
|
||||
kwargs_dict = dict(kwargs)
|
||||
with layerwise_nvtx_marker_context(
|
||||
"Torch Compiled Module (input):{}".format(self.__class__.__name__),
|
||||
self,
|
||||
in_tensor=args_list,
|
||||
kwargs=kwargs_dict,
|
||||
) as ctx:
|
||||
ctx.result = callable_fn(*args, **kwargs)
|
||||
return ctx.result
|
||||
return callable_fn(*args, **kwargs)
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.compiled = False
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.vllm_config = vllm_config
|
||||
mode = vllm_config.compilation_config.mode
|
||||
self.layerwise_nvtx_tracing_enabled = (
|
||||
vllm_config.observability_config.enable_layerwise_nvtx_tracing
|
||||
)
|
||||
if mode is None:
|
||||
raise RuntimeError("Compilation mode cannot be NO_COMPILATION")
|
||||
|
||||
backend = vllm_config.compilation_config.init_backend(vllm_config)
|
||||
options = {}
|
||||
|
||||
if isinstance(backend, str) and backend == "inductor":
|
||||
options = vllm_config.compilation_config.inductor_compile_config
|
||||
|
||||
self.first_compile = True
|
||||
self.evaluate_guards = (
|
||||
vllm_config.compilation_config.dynamic_shapes_config.evaluate_guards
|
||||
)
|
||||
|
||||
ds_type = vllm_config.compilation_config.dynamic_shapes_config.type
|
||||
|
||||
if mode != CompilationMode.STOCK_TORCH_COMPILE:
|
||||
# Drop all the guards.
|
||||
if self.evaluate_guards:
|
||||
assert not envs.VLLM_USE_BYTECODE_HOOK, (
|
||||
"compilation_config.dynamic_shapes_config.evaluate_guards "
|
||||
"requires VLLM_USE_BYTECODE_HOOK=0. "
|
||||
)
|
||||
|
||||
options["guard_filter_fn"] = lambda x: [
|
||||
entry.guard_type == "SHAPE_ENV" for entry in x
|
||||
]
|
||||
else:
|
||||
options["guard_filter_fn"] = lambda x: [False for _ in x]
|
||||
|
||||
compiled_ptr: Any = self.forward
|
||||
# Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False
|
||||
|
||||
if ds_type == DynamicShapesType.UNBACKED:
|
||||
# reason is that bytecode does torch._dynamo.eval_frame.
|
||||
# remove_from_cache(self.original_code_object()) to force a new
|
||||
# re-compilation. And if we use
|
||||
# compiled_ptr = self.check_invariants_and_forward
|
||||
# it will reset all entries.
|
||||
assert not envs.VLLM_USE_BYTECODE_HOOK, (
|
||||
"UNBACKED dynamic shapes requires VLLM_USE_BYTECODE_HOOK=0. "
|
||||
)
|
||||
assert not self.evaluate_guards, "UNBACKED dynamic shapes do not add guards"
|
||||
|
||||
compiled_ptr = self.check_invariants_and_forward
|
||||
|
||||
aot_context = nullcontext()
|
||||
if envs.VLLM_USE_AOT_COMPILE:
|
||||
if hasattr(torch._dynamo.config, "enable_aot_compile"):
|
||||
aot_context = torch._dynamo.config.patch(enable_aot_compile=True)
|
||||
else:
|
||||
msg = "torch._dynamo.config.enable_aot_compile is not "
|
||||
msg += "available. AOT compile is disabled and please "
|
||||
msg += "upgrade PyTorch version to use AOT compile."
|
||||
logger.warning(msg)
|
||||
|
||||
with aot_context:
|
||||
self._compiled_callable = torch.compile(
|
||||
compiled_ptr,
|
||||
fullgraph=True,
|
||||
dynamic=False,
|
||||
backend=backend,
|
||||
options=options,
|
||||
)
|
||||
|
||||
if envs.VLLM_USE_BYTECODE_HOOK and mode != CompilationMode.STOCK_TORCH_COMPILE:
|
||||
torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)
|
||||
self._compiled_bytecode: CodeType | None = None
|
||||
|
||||
def aot_compile(self, *args: Any, **kwargs: Any) -> Any:
|
||||
if not hasattr(self._compiled_callable, "aot_compile"):
|
||||
raise RuntimeError(
|
||||
"aot_compile is not supported by the current configuration. "
|
||||
"Please make sure torch.compile is enabled with the latest "
|
||||
f"version of PyTorch (current using torch: {torch.__version__})"
|
||||
)
|
||||
return self._compiled_callable.aot_compile((args, kwargs))
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
if envs.VLLM_USE_BYTECODE_HOOK:
|
||||
if (
|
||||
self.vllm_config.compilation_config.mode
|
||||
== CompilationMode.STOCK_TORCH_COMPILE
|
||||
):
|
||||
return self._compiled_callable(*args, **kwargs)
|
||||
|
||||
if not self._compiled_bytecode:
|
||||
# Make sure a compilation is triggered by clearing dynamo
|
||||
# cache.
|
||||
torch._dynamo.eval_frame.remove_from_cache(self.original_code_object())
|
||||
return self._call_with_optional_nvtx_range(
|
||||
self._compiled_callable, *args, **kwargs
|
||||
)
|
||||
else:
|
||||
with self._dispatch_to_compiled_code():
|
||||
return self._call_with_optional_nvtx_range(
|
||||
self.forward, *args, **kwargs
|
||||
)
|
||||
else:
|
||||
ctx = (
|
||||
nullcontext()
|
||||
if self.first_compile or not self.evaluate_guards
|
||||
else torch.compiler.set_stance("fail_on_recompile")
|
||||
)
|
||||
self.first_compile = False
|
||||
with _compilation_context(), ctx:
|
||||
return self._call_with_optional_nvtx_range(
|
||||
self._compiled_callable, *args, **kwargs
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, *args: Any, **kwargs: Any) -> Any: ...
|
||||
|
||||
def original_code_object(self) -> CodeType:
|
||||
"""Return the original code object of the forward method."""
|
||||
return self.__class__.forward.__code__
|
||||
|
||||
def bytecode_hook(self, old_code: CodeType, new_code: CodeType) -> None:
|
||||
"""Hook to save the compiled bytecode for direct execution."""
|
||||
if old_code is not self.original_code_object():
|
||||
return
|
||||
# code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25
|
||||
frame = sys._getframe()
|
||||
while frame and frame.f_back:
|
||||
frame = frame.f_back
|
||||
code_name = frame.f_code.co_name
|
||||
file_name = frame.f_code.co_filename.split(os.path.sep)[-1]
|
||||
if code_name == "_compile" and file_name == "convert_frame.py":
|
||||
break
|
||||
frame = frame.f_locals["frame"]
|
||||
assert frame.f_code == old_code
|
||||
|
||||
if frame.f_locals["self"] is not self:
|
||||
return
|
||||
|
||||
self._compiled_bytecode = new_code
|
||||
|
||||
path = self.vllm_config.compile_debug_dump_path()
|
||||
if path:
|
||||
decompiled_file = path / "transformed_code.py"
|
||||
if not decompiled_file.exists():
|
||||
try:
|
||||
# usually the decompilation will succeed for most models,
|
||||
# as we guarantee a full-graph compilation in Dynamo.
|
||||
# but there's no 100% guarantee, since decompliation is
|
||||
# not a reversible process.
|
||||
import depyf
|
||||
|
||||
src = depyf.decompile(new_code)
|
||||
|
||||
with open(decompiled_file, "w") as f:
|
||||
f.write(src)
|
||||
|
||||
logger.debug("Dynamo transformed code saved to %s", decompiled_file)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if (
|
||||
self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
|
||||
and "update" in new_code.co_names
|
||||
):
|
||||
import depyf
|
||||
|
||||
src = depyf.decompile(new_code)
|
||||
msg = (
|
||||
"Assigning / modifying buffers of nn.Module during forward pass is not "
|
||||
"allowed when using cudagraph inside the compiler because it will "
|
||||
"cause silent errors. Please use eager mode or fix the code. The "
|
||||
"following code contains clues about which buffer is being modified "
|
||||
f"(please search for the usage of the function `update`):\n{src}"
|
||||
)
|
||||
raise RuntimeError(msg)
|
||||
|
||||
@contextmanager
|
||||
def _dispatch_to_compiled_code(self) -> Generator[None, None, None]:
|
||||
# noqa: E501
|
||||
"""
|
||||
Context manager to dispatch to internally compiled code for torch<2.8.
|
||||
Why does this work? Because Dynamo guarantees that the compiled
|
||||
bytecode has exactly the same arguments, cell variables, and free
|
||||
variables as the original code. Therefore we can directly switch
|
||||
the code object in the function and call it.
|
||||
|
||||
See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details.
|
||||
""" # noqa: E501 line too long
|
||||
original = self.original_code_object()
|
||||
assert self._compiled_bytecode is not None
|
||||
self.__class__.forward.__code__ = self._compiled_bytecode
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.__class__.forward.__code__ = original
|
||||
130
vllm/config/__init__.py
Normal file
130
vllm/config/__init__.py
Normal file
@@ -0,0 +1,130 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.config.attention import AttentionConfig
|
||||
from vllm.config.cache import CacheConfig
|
||||
from vllm.config.compilation import (
|
||||
CompilationConfig,
|
||||
CompilationMode,
|
||||
CUDAGraphMode,
|
||||
PassConfig,
|
||||
)
|
||||
from vllm.config.device import DeviceConfig
|
||||
from vllm.config.ec_transfer import ECTransferConfig
|
||||
from vllm.config.kernel import KernelConfig
|
||||
from vllm.config.kv_events import KVEventsConfig
|
||||
from vllm.config.kv_transfer import KVTransferConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.config.lora import LoRAConfig
|
||||
from vllm.config.model import (
|
||||
ModelConfig,
|
||||
iter_architecture_defaults,
|
||||
str_dtype_to_torch_dtype,
|
||||
try_match_architecture_defaults,
|
||||
)
|
||||
from vllm.config.multimodal import MultiModalConfig
|
||||
from vllm.config.observability import ObservabilityConfig
|
||||
from vllm.config.offload import (
|
||||
OffloadBackend,
|
||||
OffloadConfig,
|
||||
PrefetchOffloadConfig,
|
||||
UVAOffloadConfig,
|
||||
)
|
||||
from vllm.config.parallel import EPLBConfig, ParallelConfig
|
||||
from vllm.config.pooler import PoolerConfig
|
||||
from vllm.config.profiler import ProfilerConfig
|
||||
from vllm.config.scheduler import SchedulerConfig
|
||||
from vllm.config.speculative import SpeculativeConfig
|
||||
from vllm.config.speech_to_text import SpeechToTextConfig
|
||||
from vllm.config.structured_outputs import StructuredOutputsConfig
|
||||
from vllm.config.utils import (
|
||||
ConfigType,
|
||||
SupportsMetricsInfo,
|
||||
config,
|
||||
get_attr_docs,
|
||||
is_init_field,
|
||||
replace,
|
||||
update_config,
|
||||
)
|
||||
from vllm.config.vllm import (
|
||||
VllmConfig,
|
||||
get_cached_compilation_config,
|
||||
get_current_vllm_config,
|
||||
get_current_vllm_config_or_none,
|
||||
get_layers_from_vllm_config,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.config.weight_transfer import WeightTransferConfig
|
||||
|
||||
# __all__ should only contain classes and functions.
|
||||
# Types and globals should be imported from their respective modules.
|
||||
__all__ = [
    # From vllm.config.attention
    "AttentionConfig",
    # From vllm.config.cache
    "CacheConfig",
    # From vllm.config.compilation
    "CompilationConfig",
    "CompilationMode",
    "CUDAGraphMode",
    "PassConfig",
    # From vllm.config.device
    "DeviceConfig",
    # From vllm.config.ec_transfer
    "ECTransferConfig",
    # From vllm.config.kernel
    "KernelConfig",
    # From vllm.config.kv_events
    "KVEventsConfig",
    # From vllm.config.kv_transfer
    "KVTransferConfig",
    # From vllm.config.load
    "LoadConfig",
    # From vllm.config.lora
    "LoRAConfig",
    # From vllm.config.model
    "ModelConfig",
    "iter_architecture_defaults",
    "str_dtype_to_torch_dtype",
    "try_match_architecture_defaults",
    # From vllm.config.multimodal
    "MultiModalConfig",
    # From vllm.config.observability
    "ObservabilityConfig",
    # From vllm.config.offload
    "OffloadBackend",
    "OffloadConfig",
    "PrefetchOffloadConfig",
    "UVAOffloadConfig",
    # From vllm.config.parallel
    "EPLBConfig",
    "ParallelConfig",
    # From vllm.config.pooler
    "PoolerConfig",
    # From vllm.config.scheduler
    "SchedulerConfig",
    # From vllm.config.speculative
    "SpeculativeConfig",
    # From vllm.config.speech_to_text
    "SpeechToTextConfig",
    # From vllm.config.structured_outputs
    "StructuredOutputsConfig",
    # From vllm.config.profiler
    "ProfilerConfig",
    # From vllm.config.utils
    "ConfigType",
    "SupportsMetricsInfo",
    "config",
    "get_attr_docs",
    "is_init_field",
    "replace",
    "update_config",
    # From vllm.config.vllm
    "VllmConfig",
    "get_cached_compilation_config",
    "get_current_vllm_config",
    "get_current_vllm_config_or_none",
    "set_current_vllm_config",
    "get_layers_from_vllm_config",
    # From vllm.config.weight_transfer
    "WeightTransferConfig",
]
|
||||
69
vllm/config/attention.py
Normal file
69
vllm/config/attention.py
Normal file
@@ -0,0 +1,69 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import field_validator
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
|
||||
@config
class AttentionConfig:
    """Configuration for attention mechanisms in vLLM."""

    backend: AttentionBackendEnum | None = None
    """Attention backend to use. If None, will be selected automatically."""

    flash_attn_version: Literal[2, 3] | None = None
    """Force vllm to use a specific flash-attention version (2 or 3).
    Only valid when using the flash-attention backend."""

    use_prefill_decode_attention: bool = False
    """Use separate prefill and decode kernels for attention instead of
    the unified triton kernel."""

    flash_attn_max_num_splits_for_cuda_graph: int = 32
    """Flash Attention max number splits for cuda graph decode."""

    use_cudnn_prefill: bool = False
    """Whether to use cudnn prefill."""

    use_trtllm_ragged_deepseek_prefill: bool = True
    """Whether to use TRTLLM ragged deepseek prefill."""

    use_trtllm_attention: bool | None = None
    """If set to True/False, use or don't use the TRTLLM attention backend
    in flashinfer. If None, auto-detect the attention backend in flashinfer."""

    disable_flashinfer_prefill: bool = False
    """Whether to disable flashinfer prefill."""

    disable_flashinfer_q_quantization: bool = False
    """If set, when using fp8 kv, do not quantize Q to fp8."""

    use_prefill_query_quantization: bool = False
    """If set, quantize query for attention in prefill."""

    def compute_hash(self) -> str:
        """
        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # Imported inside the method rather than at module scope.
        # NOTE(review): presumably to avoid an import cycle; confirm.
        from vllm.config.utils import get_hash_factors, hash_factors

        # The ignore list is empty, so no field is excluded by name here.
        ignored_factors: list[str] = []
        factors = get_hash_factors(self, ignored_factors)
        return hash_factors(factors)

    @field_validator("backend", mode="before")
    @classmethod
    def validate_backend_before(cls, value: Any) -> Any:
        """Enable parsing of the `backend` enum type from string."""
        # A string is upper-cased and looked up by member name; any other
        # value (including None) is returned unchanged.
        if isinstance(value, str):
            return AttentionBackendEnum[value.upper()]
        return value
|
||||
250
vllm/config/cache.py
Normal file
250
vllm/config/cache.py
Normal file
@@ -0,0 +1,250 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import math
|
||||
from dataclasses import field
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
from pydantic import Field, SkipValidation, field_validator
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.mem_constants import GiB_bytes
|
||||
from vllm.utils.mem_utils import format_gib, get_cpu_memory
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config.parallel import ParallelConfig
|
||||
else:
|
||||
ParallelConfig = Any
|
||||
|
||||
# Module-level logger, named after this module.
logger = init_logger(__name__)

# Allowed values for `CacheConfig.block_size` (tokens per contiguous block).
BlockSize = Literal[1, 8, 16, 32, 64, 128, 256]
# Allowed values for `CacheConfig.cache_dtype`; "auto" uses the model dtype.
CacheDType = Literal[
    "auto",
    "bfloat16",
    "fp8",
    "fp8_e4m3",
    "fp8_e5m2",
    "fp8_inc",
    "fp8_ds_mla",
]
# Allowed values for `CacheConfig.mamba_cache_dtype` and
# `CacheConfig.mamba_ssm_cache_dtype`.
MambaDType = Literal["auto", "float32", "float16"]
# Allowed values for `CacheConfig.mamba_cache_mode`.
MambaCacheMode = Literal["all", "align", "none"]
# Allowed values for `CacheConfig.prefix_caching_hash_algo`.
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor", "xxhash", "xxhash_cbor"]
# Allowed values for `CacheConfig.kv_offloading_backend`.
KVOffloadingBackend = Literal["native", "lmcache"]
|
||||
|
||||
|
||||
@config
|
||||
class CacheConfig:
|
||||
"""Configuration for the KV cache."""
|
||||
|
||||
block_size: SkipValidation[BlockSize] = None # type: ignore[assignment]
|
||||
"""Size of a contiguous cache block in number of tokens. On CUDA devices,
|
||||
only block sizes up to 32 are supported.
|
||||
|
||||
This config has no static default. If left unspecified by the user, it will
|
||||
be set in `Platform.check_and_update_config()` based on the current
|
||||
platform."""
|
||||
gpu_memory_utilization: float = Field(default=0.9, gt=0, le=1)
|
||||
"""The fraction of GPU memory to be used for the model executor, which can
|
||||
range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory
|
||||
utilization. If unspecified, will use the default value of 0.9. This is a
|
||||
per-instance limit, and only applies to the current vLLM instance. It does
|
||||
not matter if you have another vLLM instance running on the same GPU. For
|
||||
example, if you have two vLLM instances running on the same GPU, you can
|
||||
set the GPU memory utilization to 0.5 for each instance."""
|
||||
swap_space: float = Field(default=4, ge=0)
|
||||
"""Size of the CPU swap space per GPU (in GiB)."""
|
||||
cache_dtype: CacheDType = "auto"
|
||||
"""Data type for kv cache storage. If "auto", will use model data type.
|
||||
CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports
|
||||
fp8 (=fp8_e4m3). Intel Gaudi (HPU) supports fp8 (using fp8_inc).
|
||||
Some models (namely DeepSeekV3.2) default to fp8, set to bfloat16 to use
|
||||
bfloat16 instead, this is an invalid option for models that do not default
|
||||
to fp8.
|
||||
"""
|
||||
is_attention_free: bool = False
|
||||
"""Whether the model is attention-free. This is primarily set in
|
||||
`ModelConfig` and that value should be manually duplicated here."""
|
||||
num_gpu_blocks_override: int | None = None
|
||||
"""Number of GPU blocks to use. This overrides the profiled `num_gpu_blocks`
|
||||
if specified. Does nothing if `None`. Used for testing preemption."""
|
||||
sliding_window: int | None = None
|
||||
"""Sliding window size for the KV cache. This is primarily set in
|
||||
`ModelConfig` and that value should be manually duplicated here."""
|
||||
enable_prefix_caching: bool = True
|
||||
"""Whether to enable prefix caching."""
|
||||
prefix_caching_hash_algo: PrefixCachingHashAlgo = "sha256"
|
||||
"""Set the hash algorithm for prefix caching:\n
|
||||
- "sha256" uses Pickle for object serialization before hashing. This is the
|
||||
current default, as SHA256 is the most secure choice to avoid potential
|
||||
hash collisions.\n
|
||||
- "sha256_cbor" provides a reproducible, cross-language compatible hash. It
|
||||
serializes objects using canonical CBOR and hashes them with SHA-256.\n
|
||||
- "xxhash" uses Pickle serialization with xxHash (128-bit) for faster,
|
||||
non-cryptographic hashing. Requires the optional ``xxhash`` package.
|
||||
IMPORTANT: Use of a hashing algorithm that is not considered
|
||||
cryptographically secure theoretically increases the risk of hash collisions,
|
||||
which can cause undefined behavior or even leak private information in
|
||||
multi-tenant environments. Even if collisions are still very unlikely, it is
|
||||
important to consider your security risk tolerance against the performance
|
||||
benefits before turning this on.\n
|
||||
- "xxhash_cbor" combines canonical CBOR serialization with xxHash for
|
||||
reproducible hashing. Requires the optional ``xxhash`` package."""
|
||||
cpu_offload_gb: float = Field(default=0, ge=0)
|
||||
"""The space in GiB to offload to CPU, per GPU. Default is 0, which means
|
||||
no offloading. Intuitively, this argument can be seen as a virtual way to
|
||||
increase the GPU memory size. For example, if you have one 24 GB GPU and
|
||||
set this to 10, virtually you can think of it as a 34 GB GPU. Then you can
|
||||
load a 13B model with BF16 weight, which requires at least 26GB GPU memory.
|
||||
Note that this requires fast CPU-GPU interconnect, as part of the model is
|
||||
loaded from CPU memory to GPU memory on the fly in each model forward pass.
|
||||
|
||||
DEPRECATED: This field is deprecated and will be removed in v0.16.
|
||||
Please use OffloadConfig.uva.cpu_offload_gb instead.
|
||||
"""
|
||||
cpu_offload_params: set[str] = Field(default_factory=set)
|
||||
"""The set of parameter name segments to target for CPU offloading.
|
||||
|
||||
DEPRECATED: This field is deprecated and will be removed in v0.16.
|
||||
Please use OffloadConfig.uva.cpu_offload_params instead.
|
||||
"""
|
||||
calculate_kv_scales: bool = False
|
||||
"""This enables dynamic calculation of `k_scale` and `v_scale` when
|
||||
kv_cache_dtype is fp8. If `False`, the scales will be loaded from the model
|
||||
checkpoint if available. Otherwise, the scales will default to 1.0."""
|
||||
cpu_kvcache_space_bytes: int | None = None
|
||||
"""(CPU backend only) CPU key-value cache space."""
|
||||
mamba_page_size_padded: int | None = None
|
||||
""" Optional override for mamba page size; used by hybrid mamba/attention
|
||||
models to ensure exact alignment with attention page size."""
|
||||
mamba_block_size: int | None = Field(default=None, gt=0)
|
||||
"""Size of a contiguous cache block in number of tokens for mamba cache.
|
||||
Can be set only when prefix caching is enabled.
|
||||
Value must be a multiple of 8 to align with causal_conv1d kernel."""
|
||||
mamba_cache_dtype: MambaDType = "auto"
|
||||
"""The data type to use for the Mamba cache (both the conv as well as the
|
||||
ssm state). If set to 'auto', the data type will be inferred from the model
|
||||
config."""
|
||||
mamba_ssm_cache_dtype: MambaDType = "auto"
|
||||
"""The data type to use for the Mamba cache (ssm state only, conv state will
|
||||
still be controlled by mamba_cache_dtype). If set to 'auto', the data type
|
||||
for the ssm state will be determined by mamba_cache_dtype."""
|
||||
mamba_cache_mode: MambaCacheMode = "none"
|
||||
"""The cache strategy for Mamba layers.
|
||||
- "none": set when prefix caching is disabled.
|
||||
- "all": cache the mamba state of all tokens at position i * block_size. This is
|
||||
the default behavior (for models that support it) when prefix caching is
|
||||
enabled.
|
||||
- "align": only cache the mamba state of the last token of each scheduler step and
|
||||
when the token is at position i * block_size.
|
||||
"""
|
||||
|
||||
# Will be set after profiling.
|
||||
num_gpu_blocks: int | None = field(default=None, init=False)
|
||||
"""The number of blocks to allocate for GPU memory."""
|
||||
num_cpu_blocks: int | None = field(default=None, init=False)
|
||||
"""The number of blocks to allocate for CPU memory."""
|
||||
|
||||
kv_sharing_fast_prefill: bool = False
|
||||
"""This feature is work in progress and no prefill optimization takes place
|
||||
with this flag enabled currently.
|
||||
|
||||
In some KV sharing setups, e.g. YOCO (https://arxiv.org/abs/2405.05254),
|
||||
some layers can skip tokens corresponding to prefill. This flag enables
|
||||
attention metadata for eligible layers to be overridden with metadata
|
||||
necessary for implementing this optimization in some models (e.g. Gemma3n)
|
||||
"""
|
||||
|
||||
kv_cache_memory_bytes: int | None = None
|
||||
"""Size of KV Cache per GPU in bytes. By default, this is set to None
|
||||
and vllm can automatically infer the kv cache size based on
|
||||
gpu_memory_utilization. However, users may want to manually specify
|
||||
the kv cache memory size. kv_cache_memory_bytes allows more fine-grain
|
||||
control of how much memory gets used when compared with using
|
||||
gpu_memory_utilization. Note that kv_cache_memory_bytes
|
||||
(when not-None) ignores gpu_memory_utilization"""
|
||||
|
||||
kv_offloading_size: float | None = None
|
||||
"""Size of the KV cache offloading buffer in GiB. When TP > 1, this is
|
||||
the total buffer size summed across all TP ranks. By default, this is set
|
||||
to None, which means no KV offloading is enabled. When set, vLLM will
|
||||
enable KV cache offloading to CPU using the kv_offloading_backend."""
|
||||
|
||||
kv_offloading_backend: KVOffloadingBackend = "native"
|
||||
"""The backend to use for KV cache offloading. Supported backends include
|
||||
'native' (vLLM native CPU offloading), 'lmcache'.
|
||||
KV offloading is only activated when kv_offloading_size is set."""
|
||||
|
||||
def compute_hash(self) -> str:
    """
    Hash the cache settings that influence the computation graph.

    WARNING: when a new field is added to this config, decide whether it
    affects the computation graph. Fields that do not must be listed in
    `ignored_factors` below; everything else is handed to
    `get_hash_factors` together with this instance.

    The hash is meant to identify the configs that shape the graph from
    input ids/embeddings to the final hidden states, excluding anything
    before the former and after the latter.

    Returns:
        The string produced by `hash_factors` for the collected factors.
    """
    from vllm.config.utils import get_hash_factors, hash_factors

    ignored_factors = {
        # Runtime/derived knobs that don't affect compiled graph shape
        "gpu_memory_utilization",
        "swap_space",
        "is_attention_free",
        "num_gpu_blocks_override",
        "enable_prefix_caching",
        "prefix_caching_hash_algo",
        "cpu_kvcache_space_bytes",
        "mamba_page_size_padded",
        # Post-init/derived counters
        "num_gpu_blocks",
        "num_cpu_blocks",
        # WIP feature toggle not impacting compiled graph shape
        "kv_sharing_fast_prefill",
    }

    return hash_factors(get_hash_factors(self, ignored_factors))
|
||||
|
||||
def metrics_info(self):
    """Return the instance attributes as a `dict[str, str]`.

    Every value in `self.__dict__` is converted with `str()` so the
    result can be used as Prometheus metrics info.
    """
    info = {}
    for name, raw in self.__dict__.items():
        info[name] = str(raw)
    return info
|
||||
|
||||
@field_validator("cache_dtype", mode="after")
@classmethod
def _validate_cache_dtype(cls, cache_dtype: CacheDType) -> CacheDType:
    """Log an informational note when an fp8 KV-cache dtype is selected.

    The value itself is never modified; it is returned unchanged.
    """
    if not cache_dtype.startswith("fp8"):
        return cache_dtype

    logger.info(
        "Using fp8 data type to store kv cache. It reduces the GPU "
        "memory footprint and boosts the performance. "
        "Meanwhile, it may cause accuracy drop without a proper "
        "scaling factor."
    )
    return cache_dtype
|
||||
|
||||
def verify_with_parallel_config(
    self,
    parallel_config: ParallelConfig,
) -> None:
    """Check the requested swap space against the total CPU memory.

    The per-GPU swap space (`self.swap_space` GiB, rounded up to whole
    bytes) is multiplied by `parallel_config.tensor_parallel_size`.

    Raises:
        ValueError: if that product exceeds 70% of the value returned by
            `get_cpu_memory()`. Above 40% only a warning is logged.
    """
    per_gpu_swap = math.ceil(self.swap_space * GiB_bytes)
    host_memory = get_cpu_memory()
    # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel
    # group are in the same node. However, the GPUs may span multiple nodes.
    gpus_on_node = parallel_config.tensor_parallel_size
    requested = per_gpu_swap * gpus_on_node

    summary = (
        f"{format_gib(requested)} GiB out of the "
        f"{format_gib(host_memory)} GiB total CPU memory "
        "is allocated for the swap space."
    )

    if requested > 0.7 * host_memory:
        raise ValueError("Too large swap space. " + summary)
    if requested > 0.4 * host_memory:
        logger.warning("Possibly too large swap space. %s", summary)
|
||||
1196
vllm/config/compilation.py
Normal file
1196
vllm/config/compilation.py
Normal file
File diff suppressed because it is too large
Load Diff
73
vllm/config/device.py
Normal file
73
vllm/config/device.py
Normal file
@@ -0,0 +1,73 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import field
|
||||
from typing import Any, Literal
|
||||
|
||||
import torch
|
||||
from pydantic import ConfigDict, SkipValidation
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"]
|
||||
|
||||
|
||||
@config(config=ConfigDict(arbitrary_types_allowed=True))
|
||||
class DeviceConfig:
|
||||
"""Configuration for the device to use for vLLM execution."""
|
||||
|
||||
device: SkipValidation[Device | torch.device | None] = "auto"
|
||||
"""Device type for vLLM execution.
|
||||
This parameter is deprecated and will be
|
||||
removed in a future release.
|
||||
It will now be set automatically based
|
||||
on the current platform."""
|
||||
device_type: str = field(init=False)
|
||||
"""Device type from the current platform. This is set in
|
||||
`__post_init__`."""
|
||||
|
||||
def compute_hash(self) -> str:
    """
    Return the hash contribution of this config.

    WARNING: when a new field is added to this config, include it in the
    factor list below if it affects the computation graph (the path from
    input ids/embeddings to the final hidden states).

    The factor list is currently empty, so the result is the `safe_hash`
    hex digest of the string form of an empty list.
    """
    # no factors to consider.
    # the device/platform information will be summarized
    # by torch/vllm automatically.
    no_factors: list[Any] = []
    encoded = str(no_factors).encode()
    return safe_hash(encoded, usedforsecurity=False).hexdigest()
|
||||
|
||||
def __post_init__(self):
|
||||
if self.device == "auto":
|
||||
# Automated device type detection
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
self.device_type = current_platform.device_type
|
||||
if not self.device_type:
|
||||
raise RuntimeError(
|
||||
"Failed to infer device type, please set "
|
||||
"the environment variable `VLLM_LOGGING_LEVEL=DEBUG` "
|
||||
"to turn on verbose logging to help debug the issue."
|
||||
)
|
||||
else:
|
||||
# Device type is assigned explicitly
|
||||
if isinstance(self.device, str):
|
||||
self.device_type = self.device
|
||||
elif isinstance(self.device, torch.device):
|
||||
self.device_type = self.device.type
|
||||
|
||||
# Some device types require processing inputs on CPU
|
||||
if self.device_type in ["tpu"]:
|
||||
self.device = None
|
||||
else:
|
||||
# Set device with device type
|
||||
self.device = torch.device(self.device_type)
|
||||
107
vllm/config/ec_transfer.py
Normal file
107
vllm/config/ec_transfer.py
Normal file
@@ -0,0 +1,107 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import hashlib
|
||||
import uuid
|
||||
from dataclasses import field
|
||||
from typing import Any, Literal, get_args
|
||||
|
||||
from vllm.config.utils import config
|
||||
|
||||
ECProducer = Literal["ec_producer", "ec_both"]
|
||||
ECConsumer = Literal["ec_consumer", "ec_both"]
|
||||
ECRole = Literal[ECProducer, ECConsumer]
|
||||
|
||||
|
||||
@config
|
||||
class ECTransferConfig:
|
||||
"""Configuration for distributed EC cache transfer."""
|
||||
|
||||
ec_connector: str | None = None
|
||||
"""The EC connector for vLLM to transmit EC caches between vLLM instances.
|
||||
"""
|
||||
|
||||
engine_id: str | None = None
|
||||
"""The engine id for EC transfers."""
|
||||
|
||||
ec_buffer_device: str | None = "cuda"
|
||||
"""The device used by ec connector to buffer the EC cache.
|
||||
Currently only support 'cuda'."""
|
||||
|
||||
ec_buffer_size: float = 1e9
|
||||
"""The buffer size for TorchDistributedConnector. Measured in number of
|
||||
bytes. Recommended value: 1e9 (about 1GB)."""
|
||||
|
||||
ec_role: ECRole | None = None
|
||||
"""Whether this vLLM instance produces, consumes EC cache, or both. Choices
|
||||
are 'ec_producer', 'ec_consumer', 'ec_both'."""
|
||||
|
||||
ec_rank: int | None = None
|
||||
"""The rank of this vLLM instance in the EC cache transfer. Typical value:
|
||||
0 for encoder, 1 for pd instance.
|
||||
Currently only 1P1D is supported."""
|
||||
|
||||
ec_parallel_size: int = 1
|
||||
"""The number of parallel instances for EC cache transfer. For
|
||||
PyNcclConnector, this should be 2."""
|
||||
|
||||
ec_ip: str = "127.0.0.1"
|
||||
"""The EC connector ip, used to build distributed connection."""
|
||||
|
||||
ec_port: int = 14579
|
||||
"""The EC connector port, used to build distributed connection."""
|
||||
|
||||
ec_connector_extra_config: dict[str, Any] = field(default_factory=dict)
|
||||
"""any extra config that the connector may need."""
|
||||
|
||||
ec_connector_module_path: str | None = None
|
||||
"""The Python module path to dynamically load the EC connector from.
|
||||
Only supported in V1."""
|
||||
|
||||
def compute_hash(self) -> str:
    """
    Return the hash contribution of this config.

    WARNING: when a new field is added to this config, include it in the
    factor list below if it affects the computation graph (the path from
    input ids/embeddings to the final hidden states).

    The factor list is currently empty, so the result is the MD5 hex
    digest of the string form of an empty list.
    """
    # no factors to consider.
    # this config will not affect the computation graph.
    # NOTE(review): sibling configs in this package hash via
    # vllm.utils.hashing.safe_hash; this one calls hashlib.md5 directly.
    no_factors: list[Any] = []
    digest = hashlib.md5(str(no_factors).encode(), usedforsecurity=False)
    return digest.hexdigest()
|
||||
|
||||
def __post_init__(self) -> None:
    """Fill in a default engine id and validate the role settings.

    A random UUID4 string is assigned when `engine_id` is None.

    Raises:
        ValueError: if `ec_role` is set to a value outside `ECRole`, or
            if `ec_connector` is set while `ec_role` is None.
    """
    if self.engine_id is None:
        self.engine_id = str(uuid.uuid4())

    has_role = self.ec_role is not None

    if has_role and self.ec_role not in get_args(ECRole):
        raise ValueError(
            f"Unsupported ec_role: {self.ec_role}. "
            f"Supported roles are {get_args(ECRole)}"
        )

    if not has_role and self.ec_connector is not None:
        raise ValueError(
            "Please specify ec_role when ec_connector "
            f"is set, supported roles are {get_args(ECRole)}"
        )
|
||||
|
||||
@property
def is_ec_transfer_instance(self) -> bool:
    """True when a connector is configured and `ec_role` is any `ECRole` value."""
    if self.ec_connector is None:
        return False
    return self.ec_role in get_args(ECRole)
|
||||
|
||||
@property
def is_ec_producer(self) -> bool:
    """True when a connector is configured and `ec_role` is an `ECProducer` value."""
    if self.ec_connector is None:
        return False
    return self.ec_role in get_args(ECProducer)
|
||||
|
||||
@property
def is_ec_consumer(self) -> bool:
    """True when a connector is configured and `ec_role` is an `ECConsumer` value."""
    if self.ec_connector is None:
        return False
    return self.ec_role in get_args(ECConsumer)
|
||||
|
||||
def get_from_extra_config(self, key, default) -> Any:
    """Look up `key` in `ec_connector_extra_config`, returning `default` if absent."""
    extra = self.ec_connector_extra_config
    return extra.get(key, default)
|
||||
76
vllm/config/kernel.py
Normal file
76
vllm/config/kernel.py
Normal file
@@ -0,0 +1,76 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
MoEBackend = Literal[
|
||||
"auto",
|
||||
"triton",
|
||||
"deep_gemm",
|
||||
"cutlass",
|
||||
"flashinfer_trtllm",
|
||||
"flashinfer_cutlass",
|
||||
"flashinfer_cutedsl",
|
||||
"marlin",
|
||||
"aiter",
|
||||
]
|
||||
|
||||
|
||||
@config
|
||||
class KernelConfig:
|
||||
"""Configuration for kernel selection and warmup behavior."""
|
||||
|
||||
enable_flashinfer_autotune: bool = Field(default=None)
|
||||
"""If True, run FlashInfer autotuning during kernel warmup."""
|
||||
|
||||
moe_backend: MoEBackend = "auto"
|
||||
"""Backend for MoE expert computation kernels. Available options:
|
||||
|
||||
- "auto": Automatically select the best backend based on model and hardware\n
|
||||
- "triton": Use Triton-based fused MoE kernels\n
|
||||
- "deep_gemm": Use DeepGEMM kernels (FP8 block-quantized only)\n
|
||||
- "cutlass": Use vLLM CUTLASS kernels\n
|
||||
- "flashinfer_trtllm": Use FlashInfer with TRTLLM-GEN kernels\n
|
||||
- "flashinfer_cutlass": Use FlashInfer with CUTLASS kernels\n
|
||||
- "flashinfer_cutedsl": Use FlashInfer with CuteDSL kernels (FP4 only)\n
|
||||
- "marlin": Use Marlin kernels (weight-only quantization)\n
|
||||
- "aiter": Use AMD AITer kernels (ROCm only)"""
|
||||
|
||||
@field_validator("moe_backend", mode="before")
@classmethod
def _normalize_moe_backend(cls, value: Any) -> Any:
    """Lowercase string inputs and turn hyphens into underscores.

    Non-string values are passed through untouched.
    """
    if not isinstance(value, str):
        return value
    lowered = value.lower()
    return lowered.replace("-", "_")
|
||||
|
||||
def compute_hash(self) -> str:
    """
    Return the hash contribution of this config.

    WARNING: when a new field is added to this config, include it in the
    factor list below if it affects the computation graph (the path from
    input ids/embeddings to the final hidden states).

    The factor list is currently empty, so the result is the `safe_hash`
    hex digest of the string form of an empty list.
    """
    # no factors to consider.
    # this config will not affect the computation graph.
    no_factors: list[Any] = []
    encoded = str(no_factors).encode()
    return safe_hash(encoded, usedforsecurity=False).hexdigest()
|
||||
|
||||
@field_validator("enable_flashinfer_autotune", mode="wrap")
@classmethod
def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
    """Skip validation if the value is `None` when initialization is delayed.

    Any other value is forwarded to the wrapped `handler`.
    """
    return value if value is None else handler(value)
|
||||
54
vllm/config/kv_events.py
Normal file
54
vllm/config/kv_events.py
Normal file
@@ -0,0 +1,54 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from vllm.config.utils import config
|
||||
|
||||
|
||||
@config
|
||||
class KVEventsConfig:
|
||||
"""Configuration for KV event publishing."""
|
||||
|
||||
enable_kv_cache_events: bool = False
|
||||
"""If True, enable KV cache events for tracking block storage and removal.
|
||||
Events can be published externally by zmq using the event publisher config.
|
||||
"""
|
||||
|
||||
publisher: Literal["null", "zmq"] = Field(default=None)
|
||||
"""The publisher to use for publishing kv events. Can be "null", "zmq".
|
||||
"""
|
||||
|
||||
endpoint: str = "tcp://*:5557"
|
||||
"""The zmq endpoint to use for publishing kv events.
|
||||
"""
|
||||
|
||||
replay_endpoint: str | None = None
|
||||
"""The zmq endpoint to use for replaying kv events.
|
||||
"""
|
||||
|
||||
buffer_steps: int = 10_000
|
||||
"""The number of steps to cache for replay endpoint. Will only save
|
||||
events from the last N steps for the replay endpoint.
|
||||
"""
|
||||
|
||||
hwm: int = 100_000
|
||||
"""The zmq high water mark for the event publisher. After queueing N events,
|
||||
events will start dropping if the consumer is not keeping up.
|
||||
"""
|
||||
|
||||
max_queue_size: int = 100_000
|
||||
"""The maximum number of events to queue while waiting for publishing.
|
||||
"""
|
||||
|
||||
topic: str = ""
|
||||
"""The topic to use for the event publisher. Consumers can subscribe to
|
||||
this topic to receive events.
|
||||
"""
|
||||
|
||||
def __post_init__(self):
|
||||
if self.publisher is None:
|
||||
self.publisher = "zmq" if self.enable_kv_cache_events else "null"
|
||||
116
vllm/config/kv_transfer.py
Normal file
116
vllm/config/kv_transfer.py
Normal file
@@ -0,0 +1,116 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import uuid
|
||||
from dataclasses import field
|
||||
from typing import Any, Literal, get_args
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
KVProducer = Literal["kv_producer", "kv_both"]
|
||||
KVConsumer = Literal["kv_consumer", "kv_both"]
|
||||
KVRole = Literal[KVProducer, KVConsumer]
|
||||
|
||||
|
||||
@config
|
||||
class KVTransferConfig:
|
||||
"""Configuration for distributed KV cache transfer."""
|
||||
|
||||
kv_connector: str | None = None
|
||||
"""The KV connector for vLLM to transmit KV caches between vLLM instances.
|
||||
"""
|
||||
|
||||
engine_id: str | None = None
|
||||
"""The engine id for KV transfers."""
|
||||
|
||||
kv_buffer_device: str = "cuda"
|
||||
"""The device used by kv connector to buffer the KV cache. Choices are
|
||||
'cuda' and 'cpu'."""
|
||||
|
||||
kv_buffer_size: float = 1e9
|
||||
"""The buffer size for TorchDistributedConnector. Measured in number of
|
||||
bytes. Recommended value: 1e9 (about 1GB)."""
|
||||
|
||||
kv_role: KVRole | None = None
|
||||
"""Whether this vLLM instance produces, consumes KV cache, or both. Choices
|
||||
are 'kv_producer', 'kv_consumer', and 'kv_both'."""
|
||||
|
||||
kv_rank: int | None = None
|
||||
"""The rank of this vLLM instance in the KV cache transfer. Typical value:
|
||||
0 for prefill instance, 1 for decode instance.
|
||||
Currently only 1P1D is supported."""
|
||||
|
||||
kv_parallel_size: int = 1
|
||||
"""The number of parallel instances for KV cache transfer. For
|
||||
P2pNcclConnector, this should be 2."""
|
||||
|
||||
kv_ip: str = "127.0.0.1"
|
||||
"""The KV connector ip, used to build distributed connection."""
|
||||
|
||||
kv_port: int = 14579
|
||||
"""The KV connector port, used to build distributed connection."""
|
||||
|
||||
kv_connector_extra_config: dict[str, Any] = field(default_factory=dict)
|
||||
"""any extra config that the connector may need."""
|
||||
|
||||
kv_connector_module_path: str | None = None
|
||||
"""The Python module path to dynamically load the KV connector from.
|
||||
Only supported in V1."""
|
||||
|
||||
enable_permute_local_kv: bool = False
|
||||
"""Experiment feature flag to enable HND to NHD KV Transfer"""
|
||||
|
||||
kv_load_failure_policy: Literal["recompute", "fail"] = "fail"
|
||||
"""Policy for handling KV cache load failures.
|
||||
'recompute': reschedule the request to recompute failed blocks
|
||||
'fail': immediately fail the request with an error finish reason (default)"""
|
||||
|
||||
def compute_hash(self) -> str:
    """
    Return the hash contribution of this config.

    WARNING: when a new field is added to this config, include it in the
    factor list below if it affects the computation graph (the path from
    input ids/embeddings to the final hidden states).

    The factor list is currently empty, so the result is the `safe_hash`
    hex digest of the string form of an empty list.
    """
    # no factors to consider.
    # this config will not affect the computation graph.
    no_factors: list[Any] = []
    encoded = str(no_factors).encode()
    return safe_hash(encoded, usedforsecurity=False).hexdigest()
|
||||
|
||||
def __post_init__(self) -> None:
    """Fill in a default engine id and validate the role settings.

    A random UUID4 string is assigned when `engine_id` is None.

    Raises:
        ValueError: if `kv_role` is set to a value outside `KVRole`, or
            if `kv_connector` is set while `kv_role` is None.
    """
    if self.engine_id is None:
        self.engine_id = str(uuid.uuid4())

    has_role = self.kv_role is not None

    if has_role and self.kv_role not in get_args(KVRole):
        raise ValueError(
            f"Unsupported kv_role: {self.kv_role}. "
            f"Supported roles are {get_args(KVRole)}"
        )

    if not has_role and self.kv_connector is not None:
        raise ValueError(
            "Please specify kv_role when kv_connector "
            f"is set, supported roles are {get_args(KVRole)}"
        )
|
||||
|
||||
@property
def is_kv_transfer_instance(self) -> bool:
    """True when a connector is configured and `kv_role` is any `KVRole` value."""
    if self.kv_connector is None:
        return False
    return self.kv_role in get_args(KVRole)
|
||||
|
||||
@property
def is_kv_producer(self) -> bool:
    """True when a connector is configured and `kv_role` is a `KVProducer` value."""
    if self.kv_connector is None:
        return False
    return self.kv_role in get_args(KVProducer)
|
||||
|
||||
@property
def is_kv_consumer(self) -> bool:
    """True when a connector is configured and `kv_role` is a `KVConsumer` value."""
    if self.kv_connector is None:
        return False
    return self.kv_role in get_args(KVConsumer)
|
||||
|
||||
def get_from_extra_config(self, key, default) -> Any:
    """Look up `key` in `kv_connector_extra_config`, returning `default` if absent."""
    extra = self.kv_connector_extra_config
    return extra.get(key, default)
|
||||
122
vllm/config/load.py
Normal file
122
vllm/config/load.py
Normal file
@@ -0,0 +1,122 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.model_loader import LoadFormats
|
||||
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
||||
else:
|
||||
LoadFormats = Any
|
||||
TensorizerConfig = Any
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@config
|
||||
class LoadConfig:
|
||||
"""Configuration for loading the model weights."""
|
||||
|
||||
load_format: str | LoadFormats = "auto"
|
||||
"""The format of the model weights to load:\n
|
||||
- "auto" will try to load the weights in the safetensors format and fall
|
||||
back to the pytorch bin format if safetensors format is not available.\n
|
||||
- "pt" will load the weights in the pytorch bin format.\n
|
||||
- "safetensors" will load the weights in the safetensors format.\n
|
||||
- "npcache" will load the weights in pytorch format and store a numpy cache
|
||||
to speed up the loading.\n
|
||||
- "dummy" will initialize the weights with random values, which is mainly
|
||||
for profiling.\n
|
||||
- "tensorizer" will use CoreWeave's tensorizer library for fast weight
|
||||
loading. See the Tensorize vLLM Model script in the Examples section for
|
||||
more information.\n
|
||||
- "runai_streamer" will load the Safetensors weights using Run:ai Model
|
||||
Streamer.\n
|
||||
- "runai_streamer_sharded" will load weights from pre-sharded checkpoint
|
||||
files using Run:ai Model Streamer.\n
|
||||
- "bitsandbytes" will load the weights using bitsandbytes quantization.\n
|
||||
- "sharded_state" will load weights from pre-sharded checkpoint files,
|
||||
supporting efficient loading of tensor-parallel models.\n
|
||||
- "gguf" will load weights from GGUF format files (details specified in
|
||||
https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).\n
|
||||
- "mistral" will load weights from consolidated safetensors files used by
|
||||
Mistral models.
|
||||
- Other custom values can be supported via plugins."""
|
||||
download_dir: str | None = None
|
||||
"""Directory to download and load the weights, default to the default
|
||||
cache directory of Hugging Face."""
|
||||
safetensors_load_strategy: str = "lazy"
|
||||
"""Specifies the loading strategy for safetensors weights.
|
||||
- "lazy" (default): Weights are memory-mapped from the file. This enables
|
||||
on-demand loading and is highly efficient for models on local storage.
|
||||
- "eager": The entire file is read into CPU memory upfront before loading.
|
||||
This is recommended for models on network filesystems (e.g., Lustre, NFS)
|
||||
as it avoids inefficient random reads, significantly speeding up model
|
||||
initialization. However, it uses more CPU RAM.
|
||||
- "torchao": Weights are loaded in upfront and then reconstructed
|
||||
into torchao tensor subclasses. This is used when the checkpoint
|
||||
was quantized using torchao and saved using safetensors.
|
||||
Needs torchao >= 0.14.0
|
||||
"""
|
||||
model_loader_extra_config: dict | TensorizerConfig = Field(default_factory=dict)
|
||||
"""Extra config for model loader. This will be passed to the model loader
|
||||
corresponding to the chosen load_format."""
|
||||
device: str | None = None
|
||||
"""Device to which model weights will be loaded, default to
|
||||
device_config.device"""
|
||||
ignore_patterns: list[str] | str = Field(default_factory=lambda: ["original/**/*"])
|
||||
"""The list of patterns to ignore when loading the model. Default to
|
||||
"original/**/*" to avoid repeated loading of llama's checkpoints."""
|
||||
use_tqdm_on_load: bool = True
|
||||
"""Whether to enable tqdm for showing progress bar when loading model
|
||||
weights."""
|
||||
pt_load_map_location: str | dict[str, str] = "cpu"
|
||||
"""
|
||||
pt_load_map_location: the map location for loading pytorch checkpoint, to
|
||||
support loading checkpoints can only be loaded on certain devices like
|
||||
"cuda", this is equivalent to {"": "cuda"}. Another supported format is
|
||||
mapping from different devices like from GPU 1 to GPU 0:
|
||||
{"cuda:1": "cuda:0"}. Note that when passed from command line, the strings
|
||||
in dictionary needs to be double quoted for json parsing. For more details,
|
||||
see original doc for `map_location` in https://pytorch.org/docs/stable/generated/torch.load.html
|
||||
"""
|
||||
|
||||
def compute_hash(self) -> str:
    """
    Return the hash contribution of this config.

    WARNING: when a new field is added to this config, include it in the
    factor list below if it affects the computation graph (the path from
    input ids/embeddings to the final hidden states).

    The factor list is currently empty, so the result is the `safe_hash`
    hex digest of the string form of an empty list.
    """
    # no factors to consider.
    # this config will not affect the computation graph.
    no_factors: list[Any] = []
    encoded = str(no_factors).encode()
    return safe_hash(encoded, usedforsecurity=False).hexdigest()
|
||||
|
||||
@field_validator("load_format", mode="after")
@classmethod
def _lowercase_load_format(cls, load_format: str) -> str:
    """Normalize the validated `load_format` value to lowercase.

    Declared as an explicit classmethod, matching the other field
    validators in this package and the signature (`cls` first).
    """
    return load_format.lower()
|
||||
|
||||
@field_validator("ignore_patterns", mode="after")
@classmethod
def _validate_ignore_patterns(
    cls, ignore_patterns: list[str] | str
) -> list[str] | str:
    """Log the ignore patterns when they differ from the default.

    Nothing is logged for the default `["original/**/*"]` or for an empty
    value. The patterns are returned unchanged in every case.

    Declared as an explicit classmethod, matching the other field
    validators in this package and the signature (`cls` first).
    """
    if ignore_patterns != ["original/**/*"] and len(ignore_patterns) > 0:
        logger.info(
            "Ignoring the following patterns when downloading weights: %s",
            ignore_patterns,
        )

    return ignore_patterns
|
||||
107
vllm/config/lora.py
Normal file
107
vllm/config/lora.py
Normal file
@@ -0,0 +1,107 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
import torch
|
||||
from pydantic import ConfigDict, Field, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config.cache import CacheConfig
|
||||
else:
|
||||
ModelConfig = Any
|
||||
CacheConfig = Any
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
LoRADType = Literal["auto", "float16", "bfloat16"]
|
||||
MaxLoRARanks = Literal[1, 8, 16, 32, 64, 128, 256, 320, 512]
|
||||
LoRAExtraVocabSize = Literal[256, 512]
|
||||
|
||||
|
||||
@config(config=ConfigDict(arbitrary_types_allowed=True))
|
||||
class LoRAConfig:
|
||||
"""Configuration for LoRA."""
|
||||
|
||||
max_lora_rank: MaxLoRARanks = 16
|
||||
"""Max LoRA rank."""
|
||||
max_loras: int = Field(default=1, ge=1)
|
||||
"""Max number of LoRAs in a single batch."""
|
||||
fully_sharded_loras: bool = False
|
||||
"""By default, only half of the LoRA computation is sharded with tensor
|
||||
parallelism. Enabling this will use the fully sharded layers. At high
|
||||
sequence length, max rank or tensor parallel size, this is likely faster.
|
||||
"""
|
||||
max_cpu_loras: int | None = None
|
||||
"""Maximum number of LoRAs to store in CPU memory. Must be >= than
|
||||
`max_loras`."""
|
||||
lora_dtype: torch.dtype | LoRADType = "auto"
|
||||
"""Data type for LoRA. If auto, will default to base model dtype."""
|
||||
default_mm_loras: dict[str, str] | None = None
|
||||
"""Dictionary mapping specific modalities to LoRA model paths; this field
|
||||
is only applicable to multimodal models and should be leveraged when a
|
||||
model always expects a LoRA to be active when a given modality is present.
|
||||
Note that currently, if a request provides multiple additional
|
||||
modalities, each of which have their own LoRA, we do NOT apply
|
||||
default_mm_loras because we currently only support one lora adapter
|
||||
per prompt. When run in offline mode, the lora IDs for n modalities
|
||||
will be automatically assigned to 1-n with the names of the modalities
|
||||
in alphabetic order."""
|
||||
enable_tower_connector_lora: bool = False
|
||||
"""If `True`, LoRA support for the tower (vision encoder) and connector
|
||||
of multimodal models will be enabled. This is an experimental feature and
|
||||
currently only supports some MM models such as the Qwen VL series. The default
|
||||
is False."""
|
||||
specialize_active_lora: bool = False
|
||||
"""Whether to construct lora kernel grid by the number of active LoRA adapters.
|
||||
When set to True, separate cuda graphs will be captured for different counts
|
||||
of active LoRAs (powers of 2 up to max_loras), which can improve performance
|
||||
for variable LoRA usage patterns at the cost of increased startup time and
|
||||
memory usage. Only takes effect when cudagraph_specialize_lora is True.
|
||||
"""
|
||||
|
||||
def compute_hash(self) -> str:
    """
    Hash the LoRA settings listed in the factor list below.

    WARNING: when a new field is added to this config, include it in the
    factor list if it affects the computation graph (the path from input
    ids/embeddings to the final hidden states).

    Returns:
        The `safe_hash` hex digest of the string form of the factor list.
    """
    factors: list[Any] = [
        self.max_lora_rank,
        self.max_loras,
        self.fully_sharded_loras,
        self.lora_dtype,
        self.enable_tower_connector_lora,
    ]

    encoded = str(factors).encode()
    return safe_hash(encoded, usedforsecurity=False).hexdigest()
|
||||
|
||||
@model_validator(mode="after")
def _validate_lora_config(self) -> Self:
    """Default `max_cpu_loras` to `max_loras` and enforce the lower bound.

    Raises:
        ValueError: if an explicit `max_cpu_loras` is smaller than
            `max_loras`.
    """
    cpu_slots = self.max_cpu_loras

    if cpu_slots is None:
        self.max_cpu_loras = self.max_loras
        return self

    if cpu_slots < self.max_loras:
        raise ValueError(
            f"max_cpu_loras ({self.max_cpu_loras}) must be >= "
            f"max_loras ({self.max_loras})."
        )

    return self
|
||||
|
||||
def verify_with_model_config(self, model_config: ModelConfig):
    """Resolve `lora_dtype` in place.

    `None` or "auto" takes the dtype from `model_config.dtype`. Any other
    string is looked up as an attribute of the `torch` module. Values
    that are not strings are left unchanged.
    """
    requested = self.lora_dtype

    if requested in (None, "auto"):
        self.lora_dtype = model_config.dtype
        return

    if isinstance(requested, str):
        self.lora_dtype = getattr(torch, requested)
|
||||
2056
vllm/config/model.py
Normal file
2056
vllm/config/model.py
Normal file
File diff suppressed because it is too large
Load Diff
57
vllm/config/model_arch.py
Normal file
57
vllm/config/model_arch.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any
|
||||
|
||||
from pydantic import ConfigDict
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class ModelArchitectureConfig:
    """
    Configuration for the model architecture that is required by the vLLM
    runtime.
    """

    architectures: list[str] | None
    """List of model architecture class names (e.g., ['LlamaForCausalLM']).
    It can be None upon calling `vllm_config.with_hf_config(config.text_config)`"""

    model_type: str
    """Model type identifier (e.g., 'llama', 'gpt_oss')."""

    text_model_type: str | None
    """Text model type identifier (e.g., 'llama4_text')."""

    hidden_size: int
    """Hidden size of the model."""

    total_num_hidden_layers: int
    """Number of hidden layers in the model."""

    total_num_attention_heads: int
    """Number of attention heads in the model."""

    head_size: int
    """Head dimension of the model."""

    vocab_size: int
    """Vocabulary size of the model."""

    total_num_kv_heads: int
    """Number of key value heads in the model."""

    num_experts: int
    """Number of experts in the model."""

    quantization_config: dict[str, Any] | None
    """Quantization configuration dictionary containing quantization parameters."""

    is_deepseek_mla: bool
    """Whether the model is a DeepSeek MLA model."""

    derived_max_model_len_and_key: tuple[float, str | None]
    """Derived maximum model length and key from the hf config."""
|
||||
281
vllm/config/multimodal.py
Normal file
281
vllm/config/multimodal.py
Normal file
@@ -0,0 +1,281 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Literal, TypeAlias, TypedDict, final
|
||||
|
||||
from pydantic import ConfigDict, Field, field_validator, model_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.utils.hashing import safe_hash
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
|
||||
@dataclass
class BaseDummyOptions:
    """Base options for generating dummy data during profiling."""

    # Item count for the modality; must be >= 0 and defaults to 999.
    count: int = Field(default=999, ge=0)
|
||||
|
||||
|
||||
@dataclass(config=ConfigDict(extra="forbid"))
class VideoDummyOptions(BaseDummyOptions):
    """Options for generating dummy video data during profiling."""

    # Each option is optional; when provided it must be strictly positive.
    num_frames: int | None = Field(default=None, gt=0)
    width: int | None = Field(default=None, gt=0)
    height: int | None = Field(default=None, gt=0)
|
||||
|
||||
|
||||
@dataclass(config=ConfigDict(extra="forbid"))
class ImageDummyOptions(BaseDummyOptions):
    """Options for generating dummy image data during profiling."""

    # Each option is optional; when provided it must be strictly positive.
    width: int | None = Field(default=None, gt=0)
    height: int | None = Field(default=None, gt=0)
|
||||
|
||||
|
||||
@dataclass(config=ConfigDict(extra="forbid"))
class AudioDummyOptions(BaseDummyOptions):
    """Options for generating dummy audio data during profiling."""

    # Optional; when provided it must be strictly positive.
    length: int | None = Field(default=None, gt=0)
|
||||
|
||||
|
||||
@final
class MultiModalDummyOptionsBuiltins(TypedDict, total=False):
    """Type annotations for modality types predefined by vLLM."""

    # total=False: every key below may be omitted.

    image: ImageDummyOptions
    """Options for dummy images."""

    video: VideoDummyOptions
    """Options for dummy videos."""

    audio: AudioDummyOptions
    """Options for dummy audios."""
|
||||
|
||||
|
||||
# Accepted values for `MultiModalConfig.mm_encoder_tp_mode`.
MMEncoderTPMode = Literal["weights", "data"]
# Accepted values for `MultiModalConfig.mm_processor_cache_type`.
MMCacheType = Literal["shm", "lru"]
# Modality name -> dummy-data options for that modality.
MMDummyOptions: TypeAlias = dict[str, BaseDummyOptions]
"""
A dictionary containing an entry for each modality type of dummy data.

The built-in modalities are defined by
[`MultiModalDummyOptionsBuiltins`][vllm.config.multimodal.MultiModalDummyOptionsBuiltins].
"""
|
||||
|
||||
|
||||
@config
class MultiModalConfig:
    """Controls the behavior of multimodal models."""

    language_model_only: bool = False
    """If True, disables all multimodal inputs by setting all modality limits to 0.
    Equivalent to setting `--limit-mm-per-prompt` to 0 for every modality."""
    limit_per_prompt: MMDummyOptions = Field(default_factory=dict)
    """The maximum number of input items and options allowed per
    prompt for each modality.

    Defaults to 999 for each modality.

    Legacy format (count only):
        {"image": 16, "video": 2}

    Configurable format (with options):
        {"video": {"count": 1, "num_frames": 32, "width": 512, "height": 512},
        "image": {"count": 5, "width": 512, "height": 512}}

    Mixed format (combining both):
        {"image": 16, "video": {"count": 1, "num_frames": 32, "width": 512,
        "height": 512}}
    """
    enable_mm_embeds: bool = False
    """If `True`, enables passing multimodal embeddings:
    for `LLM` class, this refers to tensor inputs under `multi_modal_data`;
    for the OpenAI-compatible server, this refers to chat messages with content
    `"type": "*_embeds"`.

    When enabled with `--limit-mm-per-prompt` set to 0 for a modality,
    precomputed embeddings skip count validation for that modality,
    saving memory by not loading encoder modules while still enabling
    embeddings as an input. Limits greater than 0 still apply to embeddings.

    WARNING: The vLLM engine may crash if incorrect shape of embeddings is passed.
    Only enable this flag for trusted users!"""
    media_io_kwargs: dict[str, dict[str, Any]] = Field(default_factory=dict)
    """Additional args passed to process media inputs, keyed by modalities.
    For example, to set num_frames for video, set
    `--media-io-kwargs '{"video": {"num_frames": 40} }'`"""
    mm_processor_kwargs: dict[str, object] | None = None
    """Arguments to be forwarded to the model's processor for multi-modal data,
    e.g., image processor. Overrides for the multi-modal processor obtained
    from `transformers.AutoProcessor.from_pretrained`.

    The available overrides depend on the model that is being run.

    For example, for Phi-3-Vision:
    `{"num_crops": 4}`."""
    mm_processor_cache_gb: float = Field(default=4, ge=0)
    """The size (in GiB) of the multi-modal processor cache, which is used to
    avoid re-processing past multi-modal inputs.

    This cache is duplicated for each API process and engine core process,
    resulting in a total memory usage of
    `mm_processor_cache_gb * (api_server_count + data_parallel_size)`.

    Set to `0` to disable this cache completely (not recommended)."""
    mm_processor_cache_type: MMCacheType = "lru"
    """Type of cache to use for the multi-modal preprocessor/mapper. If `shm`,
    use shared memory FIFO cache. If `lru`, use mirrored LRU cache."""
    mm_shm_cache_max_object_size_mb: int = Field(default=128, ge=0)
    """Size limit (in MiB) for each object stored in the multi-modal processor
    shared memory cache. Only effective when `mm_processor_cache_type` is
    `"shm"`."""
    mm_encoder_only: bool = False
    """
    When enabled, skips the language component of the model.

    This is usually only valid in disaggregated Encoder process.
    """
    mm_encoder_tp_mode: MMEncoderTPMode = "weights"
    """Indicates how to optimize multi-modal encoder inference using tensor
    parallelism (TP).

    - `"weights"`: Within the same vLLM engine, split the weights of
        each layer across TP ranks. (default TP behavior)\n
    - `"data"`: Within the same vLLM engine, split the batched input data
        across TP ranks to process the data in parallel, while hosting
        the full weights on each TP rank.
        This batch-level DP is not to be confused with API request-level
        DP (which is controlled by `--data-parallel-size`).
        This is only supported on a per-model basis and falls back to
        `"weights"` if the encoder does not support DP."""
    mm_encoder_attn_backend: AttentionBackendEnum | None = None
    """Optional override for the multi-modal encoder attention backend when
    using vision transformers. Accepts any value from
    `vllm.v1.attention.backends.registry.AttentionBackendEnum` (e.g. `FLASH_ATTN`)."""
    interleave_mm_strings: bool = False
    """Enable fully interleaved support for multimodal prompts, while using
    --chat-template-content-format=string."""
    skip_mm_profiling: bool = False
    """When enabled, skips multimodal memory profiling and only profiles with
    language backbone model during engine initialization.

    This reduces engine startup time but shifts the responsibility to users for
    estimating the peak memory usage of the activation of multimodal encoder and
    embedding cache."""
    video_pruning_rate: float | None = Field(default=None, ge=0.0, lt=1.0)
    """Sets pruning rate for video pruning via Efficient Video Sampling.
    Value sits in range [0;1) and determines fraction of media tokens
    from each video to be pruned.
    """

    @field_validator("limit_per_prompt", mode="before")
    @classmethod
    def _validate_limit_per_prompt(
        cls,
        value: dict[str, int | dict[str, int] | BaseDummyOptions],
    ) -> MMDummyOptions:
        """Normalize raw `limit_per_prompt` input into dummy-option objects.

        Each entry may be a bare integer (legacy count-only format), a dict
        of option fields, or an already constructed `BaseDummyOptions`
        instance. Instances are kept as they are, so a mapping that already
        holds option objects is accepted instead of failing on `**`
        unpacking.

        Args:
            value: Mapping from modality name to its limit specification.

        Returns:
            A new dict mapping each modality to a dummy-options object.
        """
        # Built-in modalities get their dedicated options class; any other
        # modality falls back to the base class.
        builtin_options: dict[str, type[BaseDummyOptions]] = {
            "video": VideoDummyOptions,
            "image": ImageDummyOptions,
            "audio": AudioDummyOptions,
        }
        out: MMDummyOptions = {}

        for modality, spec in value.items():
            # Already an options object: nothing to convert.
            if isinstance(spec, BaseDummyOptions):
                out[modality] = spec
                continue

            # Handle legacy format where only count is specified
            if isinstance(spec, int):
                spec = {"count": spec}

            # Convert to the appropriate DummyOptions subclass
            options_cls = builtin_options.get(modality, BaseDummyOptions)
            out[modality] = options_cls(**spec)

        return out

    @field_validator("mm_encoder_attn_backend", mode="before")
    @classmethod
    def _validate_mm_encoder_attn_backend(
        cls, value: str | AttentionBackendEnum | None
    ) -> AttentionBackendEnum | None:
        """Convert a backend name into an `AttentionBackendEnum` member.

        `None` and enum members are returned unchanged. Strings are
        upper-cased and looked up by member name.

        Raises:
            ValueError: If the removed `XFORMERS` backend is requested.
        """
        if isinstance(value, str) and value.upper() == "XFORMERS":
            raise ValueError(
                "Attention backend 'XFORMERS' has been removed (See PR #29262 for "
                "details). Please select a supported attention backend."
            )

        if value is None or isinstance(value, AttentionBackendEnum):
            return value

        assert isinstance(value, str), (
            "mm_encoder_attn_backend must be a string or an AttentionBackendEnum."
        )
        return AttentionBackendEnum[value.upper()]

    @model_validator(mode="after")
    def _validate_multimodal_config(self):
        """Reject a custom shm object size limit when the cache is not `shm`.

        Raises:
            ValueError: If `mm_shm_cache_max_object_size_mb` differs from the
                class-level value while `mm_processor_cache_type` is not
                `"shm"`.
        """
        if self.mm_processor_cache_type != "shm" and (
            self.mm_shm_cache_max_object_size_mb
            != MultiModalConfig.mm_shm_cache_max_object_size_mb
        ):
            raise ValueError(
                "'mm_shm_cache_max_object_size_mb' should only be set when "
                "'mm_processor_cache_type' is 'shm'."
            )
        return self

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # Only the encoder attention backend (by name) and the encoder TP
        # mode contribute to the hash.
        factors: list[Any] = [
            self.mm_encoder_attn_backend.name
            if self.mm_encoder_attn_backend is not None
            else None,
            self.mm_encoder_tp_mode,
        ]
        hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
        return hash_str

    def get_limit_per_prompt(self, modality: str) -> int:
        """
        Get the maximum number of input items allowed per prompt
        for the given modality (backward compatible).
        """
        if self.language_model_only:
            return 0

        limit_data = self.limit_per_prompt.get(modality)

        if limit_data is None:
            # Unspecified modality is set to 999 by default
            return 999

        return limit_data.count

    def merge_mm_processor_kwargs(
        self,
        inference_kwargs: Mapping[str, object],
    ) -> dict[str, object]:
        """
        Get the keyword arguments to pass to the multi-modal processor
        according to the extra arguments passed during inference.

        Keys in `inference_kwargs` take precedence over the configured
        `mm_processor_kwargs`.
        """
        kwargs = self.mm_processor_kwargs or {}
        return kwargs | dict(inference_kwargs)

    def is_multimodal_pruning_enabled(self):
        """Return True when `video_pruning_rate` is set and greater than 0."""
        return self.video_pruning_rate is not None and self.video_pruning_rate > 0
|
||||
152
vllm/config/observability.py
Normal file
152
vllm/config/observability.py
Normal file
@@ -0,0 +1,152 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from functools import cached_property
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from packaging.version import parse
|
||||
from pydantic import Field, field_validator, model_validator
|
||||
|
||||
from vllm import version
|
||||
from vllm.config.utils import config
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
DetailedTraceModules = Literal["model", "worker", "all"]
|
||||
|
||||
|
||||
@config
class ObservabilityConfig:
    """Configuration for observability - metrics and tracing."""

    show_hidden_metrics_for_version: str | None = None
    """Enable deprecated Prometheus metrics that have been hidden since the
    specified version. For example, if a previously deprecated metric has been
    hidden since the v0.7.0 release, you use
    `--show-hidden-metrics-for-version=0.7` as a temporary escape hatch while
    you migrate to new metrics. The metric is likely to be removed completely
    in an upcoming release."""

    @cached_property
    def show_hidden_metrics(self) -> bool:
        """Check if the hidden metrics should be shown."""
        if self.show_hidden_metrics_for_version is None:
            return False
        return version._prev_minor_version_was(self.show_hidden_metrics_for_version)

    otlp_traces_endpoint: str | None = None
    """Target URL to which OpenTelemetry traces will be sent."""

    collect_detailed_traces: list[DetailedTraceModules] | None = None
    """It makes sense to set this only if `--otlp-traces-endpoint` is set. If
    set, it will collect detailed traces for the specified modules. This
    involves use of possibly costly and or blocking operations and hence might
    have a performance impact.

    Note that collecting detailed timing information for each request can be
    expensive."""

    kv_cache_metrics: bool = False
    """Enable KV cache residency metrics (lifetime, idle time, reuse gaps).
    Uses sampling to minimize overhead.
    Requires log stats to be enabled (i.e., --disable-log-stats not set)."""

    kv_cache_metrics_sample: float = Field(default=0.01, gt=0, le=1)
    """Sampling rate for KV cache metrics (0.0, 1.0]. Default 0.01 = 1% of blocks."""

    cudagraph_metrics: bool = False
    """Enable CUDA graph metrics (number of padded/unpadded tokens, runtime cudagraph
    dispatch modes, and their observed frequencies at every logging interval)."""

    enable_layerwise_nvtx_tracing: bool = False
    """Enable layerwise NVTX tracing. This traces the execution of each layer or
    module in the model and attach informations such as input/output shapes to
    nvtx range markers. Noted that this doesn't work with CUDA graphs enabled."""

    enable_mfu_metrics: bool = False
    """Enable Model FLOPs Utilization (MFU) metrics."""

    enable_mm_processor_stats: bool = False
    """Enable collection of timing statistics for multimodal processor operations.
    This is for internal use only (e.g., benchmarks) and is not exposed as a CLI
    argument."""

    enable_logging_iteration_details: bool = False
    """Enable detailed logging of iteration details.
    If set, vllm EngineCore will log iteration details
    This includes number of context/generation requests and tokens
    and the elapsed cpu time for the iteration."""

    @cached_property
    def collect_model_forward_time(self) -> bool:
        """Whether to collect model forward time for the request."""
        return self.collect_detailed_traces is not None and (
            "model" in self.collect_detailed_traces
            or "all" in self.collect_detailed_traces
        )

    @cached_property
    def collect_model_execute_time(self) -> bool:
        """Whether to collect model execute time for the request."""
        return self.collect_detailed_traces is not None and (
            "worker" in self.collect_detailed_traces
            or "all" in self.collect_detailed_traces
        )

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # no factors to consider.
        # this config will not affect the computation graph.
        factors: list[Any] = []
        hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
        return hash_str

    @field_validator("show_hidden_metrics_for_version")
    @classmethod
    def _validate_show_hidden_metrics_for_version(cls, value: str | None) -> str | None:
        """Check that the value, when given, parses as a version string."""
        if value is not None:
            # Raises an exception if the string is not a valid version.
            parse(value)
        return value

    @field_validator("otlp_traces_endpoint")
    @classmethod
    def _validate_otlp_traces_endpoint(cls, value: str | None) -> str | None:
        """Require OpenTelemetry to be available when an endpoint is set.

        Raises:
            ValueError: If an endpoint is given but `is_tracing_available()`
                returns False.
        """
        if value is not None:
            # Imported here rather than at module level.
            from vllm.tracing import is_tracing_available, otel_import_error_traceback

            if not is_tracing_available():
                raise ValueError(
                    "OpenTelemetry is not available. Unable to configure "
                    "'otlp_traces_endpoint'. Ensure OpenTelemetry packages are "
                    f"installed. Original error:\n{otel_import_error_traceback}"
                )
        return value

    @field_validator("collect_detailed_traces", mode="before")
    @classmethod
    def _validate_collect_detailed_traces(
        cls, value: list[DetailedTraceModules] | None
    ) -> list[DetailedTraceModules] | None:
        """Handle the legacy case where users might provide a comma-separated
        string instead of a list of strings.

        This runs in "before" mode so the split happens ahead of the field's
        own type validation. In the default "after" mode every element has
        already been checked against `DetailedTraceModules`, none of whose
        values contains a comma, so the legacy branch could never trigger.

        Args:
            value: Raw, not yet validated input for the field.

        Returns:
            The input with a single comma-separated element expanded into one
            element per module; any other input is returned unchanged.
        """
        if (
            isinstance(value, (list, tuple))
            and len(value) == 1
            and isinstance(value[0], str)
            and "," in value[0]
        ):
            # Surrounding whitespace is dropped so "model, worker" also works.
            value = cast(
                list[DetailedTraceModules],
                [module.strip() for module in value[0].split(",")],
            )
        return value

    @model_validator(mode="after")
    def _validate_tracing_config(self):
        """Require `otlp_traces_endpoint` when detailed traces are requested.

        Raises:
            ValueError: If `collect_detailed_traces` is non-empty while
                `otlp_traces_endpoint` is unset or empty.
        """
        if self.collect_detailed_traces and not self.otlp_traces_endpoint:
            raise ValueError(
                "collect_detailed_traces requires `--otlp-traces-endpoint` to be set."
            )
        return self
|
||||
153
vllm/config/offload.py
Normal file
153
vllm/config/offload.py
Normal file
@@ -0,0 +1,153 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Configuration for model weight offloading."""
|
||||
|
||||
import warnings
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field, model_validator
|
||||
|
||||
from vllm.config.utils import config
|
||||
|
||||
OffloadBackend = Literal["auto", "uva", "prefetch"]
|
||||
|
||||
|
||||
@config
class UVAOffloadConfig:
    """Configuration for UVA (Unified Virtual Addressing) CPU offloading.

    Uses zero-copy access from CPU-pinned memory. Simple but requires
    fast CPU-GPU interconnect.
    """

    # Non-negative; the default 0 means no offloading.
    cpu_offload_gb: float = Field(default=0, ge=0)
    """The space in GiB to offload to CPU, per GPU. Default is 0, which means
    no offloading. Intuitively, this argument can be seen as a virtual way to
    increase the GPU memory size. For example, if you have one 24 GB GPU and
    set this to 10, virtually you can think of it as a 34 GB GPU. Then you can
    load a 13B model with BF16 weight, which requires at least 26GB GPU memory.
    Note that this requires fast CPU-GPU interconnect, as part of the model is
    loaded from CPU memory to GPU memory on the fly in each model forward pass.
    This uses UVA (Unified Virtual Addressing) for zero-copy access.
    """

    # Empty by default.
    cpu_offload_params: set[str] = Field(default_factory=set)
    """The set of parameter name segments to target for CPU offloading.
    Unmatched parameters are not offloaded. If this set is empty, parameters
    are offloaded non-selectively until the memory limit defined by
    `cpu_offload_gb` is reached.
    Examples:
        - For parameter name "mlp.experts.w2_weight":
            - "experts" or "experts.w2_weight" will match.
            - "expert" or "w2" will NOT match (must be exact segments).
    This allows distinguishing parameters like "w2_weight" and "w2_weight_scale".
    """
|
||||
|
||||
|
||||
@config
class PrefetchOffloadConfig:
    """Configuration for prefetch-based CPU offloading.

    Groups layers and uses async H2D prefetch to hide transfer latency.
    """

    # Non-negative; the default 0 leaves prefetch offloading disabled.
    offload_group_size: int = Field(default=0, ge=0)
    """Group every N layers together. Offload last `offload_num_in_group`
    layers of each group. Default is 0 (disabled).
    Example: group_size=8, num_in_group=2 offloads layers 6,7,14,15,22,23,...
    Unlike cpu_offload_gb, this uses explicit async prefetching to hide transfer
    latency.
    """

    # At least 1.
    offload_num_in_group: int = Field(default=1, ge=1)
    """Number of layers to offload per group.
    Must be <= offload_group_size. Default is 1."""

    # The field itself only requires a non-negative value.
    offload_prefetch_step: int = Field(default=1, ge=0)
    """Number of layers to prefetch ahead.
    Higher values hide more latency but use more GPU memory. Default is 1."""

    # Empty by default.
    offload_params: set[str] = Field(default_factory=set)
    """The set of parameter name segments to target for prefetch offloading.
    Unmatched parameters are not offloaded. If this set is empty, ALL
    parameters of each offloaded layer are offloaded.
    Uses segment matching: "w13_weight" matches "mlp.experts.w13_weight"
    but not "mlp.experts.w13_weight_scale".
    """
|
||||
|
||||
|
||||
@config
class OffloadConfig:
    """Configuration for model weight offloading to reduce GPU memory usage."""

    offload_backend: OffloadBackend = "auto"
    """The backend for weight offloading. Options:
    - "auto": Selects based on which sub-config has non-default values
      (prefetch if offload_group_size > 0, uva if cpu_offload_gb > 0).
    - "uva": UVA (Unified Virtual Addressing) zero-copy offloading.
    - "prefetch": Async prefetch with group-based layer offloading.
    """

    uva: UVAOffloadConfig = Field(default_factory=UVAOffloadConfig)
    """Parameters for UVA offloading backend."""

    prefetch: PrefetchOffloadConfig = Field(default_factory=PrefetchOffloadConfig)
    """Parameters for prefetch offloading backend."""

    @model_validator(mode="after")
    def validate_offload_config(self) -> "OffloadConfig":
        """Validate offload configuration constraints.

        The prefetch constraints are enforced when the backend is
        "prefetch" or `offload_group_size` is positive. A warning is
        emitted afterwards when settings of a backend that is not selected
        are present.

        Raises:
            ValueError: If `offload_num_in_group` exceeds
                `offload_group_size`, or `offload_prefetch_step` is below 1,
                while the prefetch constraints apply.
        """
        backend = self.offload_backend
        prefetch_cfg = self.prefetch

        if backend == "prefetch" or prefetch_cfg.offload_group_size > 0:
            if prefetch_cfg.offload_num_in_group > prefetch_cfg.offload_group_size:
                raise ValueError(
                    f"offload_num_in_group ({prefetch_cfg.offload_num_in_group})"
                    f" must be <= offload_group_size"
                    f" ({prefetch_cfg.offload_group_size})"
                )
            if prefetch_cfg.offload_prefetch_step < 1:
                raise ValueError(
                    f"offload_prefetch_step"
                    f" ({prefetch_cfg.offload_prefetch_step})"
                    f" must be >= 1 when prefetch offloading is enabled"
                    f" (offload_group_size > 0)"
                )

        # Warn if both backends have non-default values
        uva_active = self.uva.cpu_offload_gb > 0
        prefetch_active = prefetch_cfg.offload_group_size > 0

        message: str | None = None
        if backend == "uva" and prefetch_active:
            message = (
                "Prefetch offload fields are set but offload_backend='uva'. "
                "Prefetch settings will be ignored."
            )
        elif backend == "prefetch" and uva_active:
            message = (
                "UVA offload fields are set but offload_backend='prefetch'. "
                "UVA settings will be ignored."
            )
        elif backend == "auto" and uva_active and prefetch_active:
            message = (
                "Both UVA and prefetch offload fields are set with "
                "offload_backend='auto'. Prefetch backend will be selected. "
                "Set offload_backend explicitly to suppress this warning."
            )

        if message is not None:
            # Issued directly from this method so stacklevel=2 is unchanged.
            warnings.warn(message, stacklevel=2)
        return self

    def compute_hash(self) -> str:
        """
        Provide a hash that uniquely identifies all the offload configs.

        All fields are included because PrefetchOffloader patches module
        forwards and inserts custom ops (wait_prefetch, start_prefetch)
        into the computation graph. Changing any offload setting can
        alter which layers are hooked and how prefetch indices are
        computed, so the compilation cache must distinguish them.
        """
        from vllm.config.utils import get_hash_factors, hash_factors

        # Nothing is ignored: every field contributes to the hash.
        return hash_factors(get_hash_factors(self, ignored_factors=set()))
|
||||
713
vllm/config/parallel.py
Normal file
713
vllm/config/parallel.py
Normal file
@@ -0,0 +1,713 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
import torch
|
||||
from pydantic import Field, field_validator, model_validator
|
||||
from torch.distributed import ProcessGroup, ReduceOp
|
||||
from typing_extensions import Self
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config.utils import config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.network_utils import get_open_ports_list
|
||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.runtime_env import RuntimeEnv
|
||||
from ray.util.placement_group import PlacementGroup
|
||||
|
||||
from vllm.v1.executor import Executor
|
||||
else:
|
||||
RuntimeEnv = Any
|
||||
PlacementGroup = Any
|
||||
Executor = Any
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# String-literal choice types used by the configs in this module.
ExpertPlacementStrategy = Literal["linear", "round_robin"]
DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"]
DataParallelBackend = Literal["ray", "mp"]
# Only one EPLB policy is defined.
EPLBPolicyOption = Literal["default"]
All2AllBackend = Literal[
    "naive",
    "pplx",
    "deepep_high_throughput",
    "deepep_low_latency",
    "mori",
    "allgather_reducescatter",
    "flashinfer_all2allv",
]
|
||||
|
||||
|
||||
@config
class EPLBConfig:
    """Configuration for Expert Parallel Load Balancing (EPLB)."""

    window_size: int = 1000
    """Window size for expert load recording."""
    step_interval: int = 3000
    """
    Interval for rearranging experts in expert parallelism.

    Note that if this is greater than the EPLB window size, only the metrics
    of the last `lb_window_size` steps will be used for rearranging experts.
    """

    num_redundant_experts: int = Field(default=0, ge=0)
    """Number of redundant experts to use for expert parallelism."""

    log_balancedness: bool = False
    """
    Log the balancedness each step of expert parallelism.
    This is turned off by default since it will cause communication overhead.
    """
    log_balancedness_interval: int = 1
    """
    Interval for logging the balancedness.
    """
    use_async: bool = False
    """
    Whether to use non-blocking EPLB.
    """

    policy: EPLBPolicyOption = "default"
    """The policy type for expert parallel load balancing (EPLB)."""

    @model_validator(mode="after")
    def _validate_eplb_config(self) -> Self:
        """Cross-field checks for the EPLB settings.

        Raises:
            ValueError: If async EPLB is requested with a non-default policy,
                or balancedness logging is on with a non-positive interval.
        """
        async_with_other_policy = self.use_async and self.policy != "default"
        if async_with_other_policy:
            raise ValueError("Async EPLB is only supported with the default policy.")

        # The interval is only checked when logging is actually enabled.
        invalid_log_interval = (
            self.log_balancedness and self.log_balancedness_interval <= 0
        )
        if invalid_log_interval:
            raise ValueError("log_balancedness_interval must be greater than 0.")

        return self
|
||||
|
||||
|
||||
@config
class ParallelConfig:
    """Configuration for the distributed execution."""

    pipeline_parallel_size: int = 1
    """Number of pipeline parallel groups."""
    tensor_parallel_size: int = 1
    """Number of tensor parallel groups."""
    prefill_context_parallel_size: int = 1
    """Number of prefill context parallel groups."""
    data_parallel_size: int = 1
    """Number of data parallel groups. MoE layers will be sharded according to
    the product of the tensor parallel size and data parallel size."""
    data_parallel_size_local: int = 1
    """Number of local data parallel groups."""
    data_parallel_rank: int = 0
    """Rank of the data parallel group."""
    data_parallel_rank_local: int | None = None
    """Local rank of the data parallel group,
    set only in SPMD mode."""
    data_parallel_master_ip: str = "127.0.0.1"
    """IP of the data parallel master."""
    data_parallel_rpc_port: int = 29550
    """Port for data parallel messaging."""
    data_parallel_master_port: int = 29500
    """Port of the data parallel master."""
    data_parallel_backend: DataParallelBackend = "mp"
    """Backend to use for data parallel, either "mp" or "ray"."""
    data_parallel_external_lb: bool = False
    """Whether to use "external" DP LB mode. Applies only to online serving
    and when data_parallel_size > 1. This is useful for a "one-pod-per-rank"
    wide-EP setup in Kubernetes. Set implicitly when --data-parallel-rank
    is provided explicitly to vllm serve."""
    data_parallel_hybrid_lb: bool = False
    """Whether to use "hybrid" DP LB mode. Applies only to online serving
    and when data_parallel_size > 0. Enables running an AsyncLLM
    and API server on a "per-node" basis where vLLM load balances
    between local data parallel ranks, but an external LB balances
    between vLLM nodes/replicas. Set explicitly in conjunction with
    --data-parallel-start-rank."""
    is_moe_model: bool | None = None
    """Whether the deployed model is MoE (if known)."""
    enable_expert_parallel: bool = False
    """Use expert parallelism instead of tensor parallelism for MoE layers."""
    enable_eplb: bool = False
    """Enable expert parallelism load balancing for MoE layers."""
    eplb_config: EPLBConfig = Field(default_factory=EPLBConfig)
    """Expert parallelism configuration."""
    expert_placement_strategy: ExpertPlacementStrategy = "linear"
    """The expert placement strategy for MoE layers:\n
    - "linear": Experts are placed in a contiguous manner. For example, with 4
      experts and 2 ranks, rank 0 will have experts [0, 1] and rank 1 will have
      experts [2, 3].\n
    - "round_robin": Experts are placed in a round-robin manner. For example,
      with 4 experts and 2 ranks, rank 0 will have experts [0, 2] and rank 1
      will have experts [1, 3]. This strategy can help improve load balancing
      for grouped expert models with no redundant experts."""
    all2all_backend: All2AllBackend = "allgather_reducescatter"
    """All2All backend for MoE expert parallel communication. Available options:

    - "naive": Naive all2all implementation using broadcasts\n
    - "allgather_reducescatter": All2all based on allgather and reducescatter\n
    - "pplx": Use pplx kernels\n
    - "deepep_high_throughput": Use deepep high-throughput kernels\n
    - "deepep_low_latency": Use deepep low-latency kernels\n
    - "mori": Use mori kernels\n
    - "flashinfer_all2allv": Use flashinfer alltoallv kernels for mnnvl"""

    max_parallel_loading_workers: int | None = None
    """Maximum number of parallel loading workers when loading model
    sequentially in multiple batches. To avoid RAM OOM when using tensor
    parallel and large models."""

    disable_custom_all_reduce: bool = False
    """Disable the custom all-reduce kernel and fall back to NCCL."""

    enable_dbo: bool = False
    """Enable dual batch overlap for the model executor."""
    ubatch_size: int = 0
    """Number of ubatch size."""

    dbo_decode_token_threshold: int = 32
    """The threshold for dual batch overlap for batches only containing decodes.
    If the number of tokens in the request is greater than this threshold,
    microbatching will be used. Otherwise, the request will be processed in a
    single batch."""
    dbo_prefill_token_threshold: int = 512  # TODO(lucas): tune
    """The threshold for dual batch overlap for batches that contain one or more
    prefills. If the number of tokens in the request is greater than this
    threshold, microbatching will be used. Otherwise, the request will be
    processed in a single batch."""

    disable_nccl_for_dp_synchronization: bool | None = Field(default=None)
    """Forces the dp synchronization logic in vllm/v1/worker/dp_utils.py
    to use Gloo instead of NCCL for its all reduce.

    Defaults to True when async scheduling is enabled, False otherwise.
    """

    ray_workers_use_nsight: bool = False
    """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""

    ray_runtime_env: RuntimeEnv | None = None
    """Ray runtime environment to pass to distributed workers."""

    placement_group: PlacementGroup | None = None
    """ray distributed model workers placement group."""

    distributed_executor_backend: (
        str | DistributedExecutorBackend | type[Executor] | None
    ) = None
    """Backend to use for distributed model workers, either "ray" or "mp"
    (multiprocessing). If the product of pipeline_parallel_size and tensor_parallel_size
    is less than or equal to the number of GPUs available, "mp" will be used to
    keep processing on a single host. Otherwise, an error will be raised. To use "mp"
    you must also set nnodes, and to use "ray" you must manually set
    distributed_executor_backend to "ray".

    Note that tpu only supports Ray for distributed inference."""

    worker_cls: str = "auto"
    """The full name of the worker class to use. If "auto", the worker class
    will be determined based on the platform."""
    sd_worker_cls: str = "auto"
    """The full name of the worker class to use for speculative decoding.
    If "auto", the worker class will be determined based on the platform."""
    worker_extension_cls: str = ""
    """The full name of the worker extension class to use. The worker extension
    class is dynamically inherited by the worker class. This is used to inject
    new attributes and methods to the worker class for use in collective_rpc
    calls."""
    master_addr: str = "127.0.0.1"
    """distributed master address for multi-node distributed
    inference when distributed_executor_backend is mp."""
    master_port: int = 29501
    """distributed master port for multi-node distributed
    inference when distributed_executor_backend is mp."""
    node_rank: int = 0
    """distributed node rank for multi-node distributed
    inference when distributed_executor_backend is mp."""
    nnodes: int = 1
    """num of nodes for multi-node distributed
    inference when distributed_executor_backend is mp."""

    world_size: int = Field(init=False)
    """world_size is TPxPPxPCP (additionally multiplied by DP when the
    external launcher is used); it affects the number of workers we create."""

    rank: int = 0
    """Global rank in distributed setup."""

    _data_parallel_master_port_list: list[int] = Field(default_factory=list)
    """List of open port auto-queried for data parallel messaging.
    Set to be private as it's not intended to be configured by users.
    """

    decode_context_parallel_size: int = 1
    """Number of decode context parallel groups, because the world size does
    not change by dcp, it simply reuses the GPUs of TP group, and tp_size
    needs to be divisible by dcp_size."""

    dcp_kv_cache_interleave_size: int = 1
    """
    Interleave size of kv_cache storage while using DCP.
    dcp_kv_cache_interleave_size has been replaced by cp_kv_cache_interleave_size,
    and will be deprecated when PCP is fully supported.

    """
    cp_kv_cache_interleave_size: int = 1
    """Interleave size of kv_cache storage while using DCP or PCP.
    For `total_cp_rank = pcp_rank * dcp_world_size + dcp_rank`,
        and `total_cp_world_size = pcp_world_size * dcp_world_size`.
    store interleave_size tokens on total_cp_rank i,
    then store next interleave_size tokens on total_cp_rank i+1.
    Interleave_size=1: token-level alignment, where token `i` is stored on
        total_cp_rank `i % total_cp_world_size`.
    Interleave_size=block_size: block-level alignment, where tokens are
        first populated to the preceding ranks. Tokens are then stored
        in (rank i+1, block j) only after (rank i, block j) is fully occupied.
    Block_size should be greater than or equal to cp_kv_cache_interleave_size.
    Block_size should be divisible by cp_kv_cache_interleave_size.
    """

    data_parallel_index: int = Field(init=False)
    """Equal to the data parallel rank but not used for torch process groups
    and not overridden for dense models."""

    _api_process_count: int = Field(default=1, gt=0)
    """
    The number of API processes initialized.

    Note:
        This is an internal config that is only valid for and
        should only be set by API server scale-out.
    """

    _api_process_rank: int = Field(default=0, ge=-1)
    """
    The rank of this API process, or `-1` for engine core processes
    under API server scale-out.

    Note:
        This is an internal config that is only valid for and
        should only be set by API server scale-out.
    """

    @field_validator("disable_nccl_for_dp_synchronization", mode="wrap")
    @classmethod
    def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
        """Skip validation if the value is `None` when initialisation is delayed."""
        return None if value is None else handler(value)

    @model_validator(mode="after")
    def _validate_parallel_config(self) -> Self:
        """Cross-field validation of the parallel settings.

        Raises:
            ValueError: if the API process rank is out of range, if the local
                DP size exceeds the DP size, if external LB is requested
                without data parallelism, if the EPLB prerequisites are not
                met, or if tp_size is not divisible by dcp_size.
        """
        if self._api_process_rank >= self._api_process_count:
            raise ValueError(
                "Invalid value of `_api_process_rank`. "
                f"Expected to be `-1` or `[0, {self._api_process_count})`, "
                f"but found: {self._api_process_rank}"
            )

        if self.data_parallel_size_local > self.data_parallel_size:
            raise ValueError(
                f"data_parallel_size_local ({self.data_parallel_size_local}) "
                f"must be <= data_parallel_size ({self.data_parallel_size})"
            )

        if self.data_parallel_size <= 1 and self.data_parallel_external_lb:
            raise ValueError(
                "data_parallel_external_lb can only be set when data_parallel_size > 1"
            )

        if self.enable_eplb:
            if not current_platform.is_cuda_alike():
                raise ValueError(
                    "Expert parallelism load balancing is only supported on "
                    "CUDA devices or ROCm devices now."
                )
            if not self.enable_expert_parallel:
                raise ValueError("enable_expert_parallel must be True to use EPLB.")
            if self.tensor_parallel_size * self.data_parallel_size <= 1:
                raise ValueError(
                    "EPLB requires tensor_parallel_size or data_parallel_size "
                    f"to be greater than 1, but got "
                    f"TP={self.tensor_parallel_size},DP={self.data_parallel_size}."
                )
        else:
            # Redundant experts only make sense together with EPLB.
            if self.eplb_config.num_redundant_experts != 0:
                raise ValueError(
                    "num_redundant_experts is set to "
                    f"{self.eplb_config.num_redundant_experts} but EPLB is not "
                    "enabled. Either enable EPLB or unset "
                    "num_redundant_experts."
                )

        # Note(hc): In the current implementation of decode context
        # parallel(DCP), tp_size needs to be divisible by dcp_size,
        # because the world size does not change by dcp, it simply
        # reuses the GPUs of TP group, and split one TP group into
        # tp_size//dcp_size DCP groups.
        if self.tensor_parallel_size % self.decode_context_parallel_size != 0:
            # The trailing space keeps "by" and "dcp_size" from running
            # together in the rendered message.
            raise ValueError(
                f"tp_size={self.tensor_parallel_size} must be divisible by "
                f"dcp_size={self.decode_context_parallel_size}."
            )

        return self

    @property
    def world_size_across_dp(self) -> int:
        """world_size_across_dp is world_size x DP, it is the size of the world
        including data parallelism."""
        return self.world_size * self.data_parallel_size

    @property
    def use_ubatching(self) -> bool:
        """Whether micro-batching is active (DBO enabled or ubatch_size > 1)."""
        return self.enable_dbo or self.ubatch_size > 1

    @property
    def num_ubatches(self) -> int:
        """Number of micro-batches: 2 when DBO is enabled, else `ubatch_size`."""
        return 2 if self.enable_dbo else self.ubatch_size

    @property
    def local_engines_only(self) -> bool:
        """
        Client manages local+remote EngineCores in pure internal LB case.
        Client manages local EngineCores in hybrid and external LB case.
        """
        return self.data_parallel_external_lb or self.data_parallel_hybrid_lb

    def get_next_dp_init_port(self) -> int:
        """
        We might need to initialize process groups in multiple
        processes that are related to data parallelism,
        e.g. both in the worker and in the engine, which
        can live in different processes. To avoid port conflicts, we
        pop a new port from the prepared port list each time we need to
        initialize a new process group related to data parallelism.
        """
        if self._data_parallel_master_port_list:
            answer = self._data_parallel_master_port_list.pop()
        else:
            # Prepared list exhausted: hand out the master port and bump it
            # so the next caller gets a different one.
            answer = self.data_parallel_master_port
            self.data_parallel_master_port += 1

        return answer

    def stateless_init_dp_group(self) -> ProcessGroup:
        """Create the data parallel process group, retrying on port clashes.

        Raises:
            DistNetworkError: immediately for any error other than
                EADDRINUSE, or after `max_retries` EADDRINUSE failures.
        """
        # NOTE: In high-concurrency scenarios multiple processes
        # can pick the same (currently free) port through a race
        # condition when calling `get_open_port()`. When the first
        # process binds the port the others will subsequently fail
        # with `torch.distributed.DistNetworkError: EADDRINUSE`.
        # To make the initialization more robust we retry a few times
        # with a fresh port whenever this specific error is observed.
        from torch.distributed import DistNetworkError

        from vllm.distributed.utils import (
            stateless_init_torch_distributed_process_group,
        )

        max_retries = 5
        last_exc: Exception | None = None
        for _ in range(max_retries):
            try:
                # The backend is whatever the current platform reports as
                # its distributed backend.
                return stateless_init_torch_distributed_process_group(
                    self.data_parallel_master_ip,
                    self.get_next_dp_init_port(),
                    self.data_parallel_rank,
                    self.data_parallel_size,
                    backend=current_platform.dist_backend,
                )
            except DistNetworkError as e:
                # We only want to retry when the root cause is EADDRINUSE.
                if "EADDRINUSE" in str(e):
                    logger.warning("Address already in use. Retrying with a new port.")
                    last_exc = e
                    continue  # try again with a new port
                raise

        # If we get here all retries have failed.
        assert last_exc is not None
        raise last_exc

    # The all_reduce at the end of attention (during o_proj) means that
    # inputs are replicated across each rank of the tensor parallel group.
    # If using expert-parallelism with DeepEP All2All ops, replicated
    # tokens results in useless duplicate computation and communication.
    #
    # In this case, ensure the input to the experts is sequence parallel
    # to avoid the excess work.
    #
    # Not needed for pplx-kernels as it can handle duplicate input tokens.
    @property
    def use_sequence_parallel_moe(self) -> bool:
        return (
            self.all2all_backend
            in (
                "allgather_reducescatter",
                "naive",
                "deepep_high_throughput",
                "deepep_low_latency",
                "mori",
            )
            and self.enable_expert_parallel
            and self.tensor_parallel_size > 1
            and self.data_parallel_size > 1
        )

    @property
    def node_rank_within_dp(self) -> int:
        """Node rank taken modulo the number of nodes per DP replica."""
        return self.node_rank % self.nnodes_within_dp

    @property
    def nnodes_within_dp(self) -> int:
        """Number of nodes per DP replica (1 for single-node runs)."""
        if self.nnodes == 1:
            return 1
        # NOTE(review): `__post_init__` treats data_parallel_size_local == 0
        # as a valid value; combined with nnodes > 1 this division would
        # raise ZeroDivisionError. Confirm that combination cannot occur.
        data_parallel_node_size = (
            self.data_parallel_size // self.data_parallel_size_local
        )
        return self.nnodes // data_parallel_node_size

    @property
    def local_world_size(self) -> int:
        """world_size divided by the number of nodes per DP replica."""
        return self.world_size // self.nnodes_within_dp

    @staticmethod
    def has_unfinished_dp(dp_group: ProcessGroup, has_unfinished: bool) -> bool:
        """Return True if any rank in `dp_group` reports unfinished work."""
        tensor = torch.tensor([has_unfinished], dtype=torch.int32, device="cpu")
        # dp rank 0: has_unfinished_seqs=True
        # dp rank 1: has_unfinished_seqs=False
        # aggregated: has_unfinished_seqs=True
        # so this is an OR operation, i.e. MAX in integers
        torch.distributed.all_reduce(tensor, op=ReduceOp.MAX, group=dp_group)
        aggregated_has_unfinished = bool(tensor.item())
        return aggregated_has_unfinished

    @staticmethod
    def sync_kv_cache_memory_size(dp_group: ProcessGroup, kv_cache_memory: int) -> int:
        """Return the minimum `kv_cache_memory` across `dp_group`.

        A value of -1 is replaced by the int64 maximum before the reduction,
        so it never wins the MIN.
        """
        if kv_cache_memory == -1:
            kv_cache_memory = torch.iinfo(torch.int64).max
        tensor = torch.tensor([kv_cache_memory], dtype=torch.int64, device="cpu")
        # we cannot use broadcast for stateless dp group since it depends
        # on global rank
        torch.distributed.all_reduce(tensor, op=ReduceOp.MIN, group=dp_group)
        return tensor.item()

    def compute_hash(self):
        """
        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.

        This hash is also used for DP worker configuration validation
        to prevent hangs from mismatched collective communication patterns.
        """
        ignored_factors = {
            # Derived/runtime topology, networking, or launch details
            "data_parallel_rank",
            "data_parallel_rank_local",
            "data_parallel_size_local",
            "data_parallel_index",
            "data_parallel_backend",
            "data_parallel_external_lb",
            "data_parallel_hybrid_lb",
            "data_parallel_master_ip",
            "data_parallel_master_port",
            "_data_parallel_master_port_list",
            "data_parallel_rpc_port",
            "rank",
            "master_addr",
            "master_port",
            "node_rank",
            "nnodes",
            "max_parallel_loading_workers",
            "disable_custom_all_reduce",
            "ray_workers_use_nsight",
            "ray_runtime_env",
            "placement_group",
            "distributed_executor_backend",
            "worker_cls",
            "sd_worker_cls",
            "worker_extension_cls",
            "_api_process_count",
            "_api_process_rank",
        }

        from vllm.config.utils import get_hash_factors, hash_factors

        factors = get_hash_factors(self, ignored_factors)
        return hash_factors(factors)

    def __post_init__(self) -> None:
        """Derive world size, DP settings and the default executor backend."""
        # Continue with the rest of the initialization
        self.world_size = (
            self.pipeline_parallel_size
            * self.tensor_parallel_size
            * self.prefill_context_parallel_size
        )

        if self.distributed_executor_backend == "external_launcher":
            logger.info("Using external launcher for distributed inference.")
            self.world_size *= self.data_parallel_size

        if self.data_parallel_size > 1 or self.data_parallel_size_local == 0:
            # Data parallel was specified in the engine args.
            if self.distributed_executor_backend == "external_launcher":
                # For external launcher,
                # we need to set the data parallel rank automatically
                self.data_parallel_rank = int(os.environ["RANK"]) // (
                    self.world_size // self.data_parallel_size
                )
                logger.info(
                    "Set data_parallel_rank to %d automatically.",
                    self.data_parallel_rank,
                )
            if not self._data_parallel_master_port_list:
                self._data_parallel_master_port_list = get_open_ports_list(5)
            self.data_parallel_master_port = self._data_parallel_master_port_list.pop()

            if not (0 <= self.data_parallel_rank < self.data_parallel_size):
                raise ValueError(
                    f"data_parallel_rank ({self.data_parallel_rank})"
                    f" must be in the range [0, {self.data_parallel_size})"
                )
        else:
            # Otherwise fall back to env vars (e.g. for offline SPMD case).
            self.data_parallel_size = envs.VLLM_DP_SIZE
            self.data_parallel_rank = envs.VLLM_DP_RANK
            self.data_parallel_rank_local = envs.VLLM_DP_RANK_LOCAL
            self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
            self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT

            if self.data_parallel_size > 1 and self.is_moe_model is False:
                raise ValueError(
                    "Offline data parallel mode is not supported/useful"
                    " for dense models."
                )

        self.data_parallel_index = self.data_parallel_rank

        if self.distributed_executor_backend == "external_launcher":
            os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
            logger.info("Disabling V1 multiprocessing for external launcher.")

        if self.distributed_executor_backend is None and self.world_size > 1:
            # We use multiprocessing by default if world_size fits on the
            # current node and we aren't in a ray placement group.

            from vllm.v1.executor import ray_utils

            backend: DistributedExecutorBackend = "mp"
            ray_found = ray_utils.ray_is_available()
            if current_platform.is_tpu() and envs.VLLM_XLA_USE_SPMD:
                backend = "uni"
            elif current_platform.is_cuda() and self.nnodes > 1:
                backend = "mp"
            elif (
                current_platform.is_cuda()
                and cuda_device_count_stateless() < self.world_size
            ):
                gpu_count = cuda_device_count_stateless()
                raise ValueError(
                    f"World size ({self.world_size}) is larger than the number of "
                    f"available GPUs ({gpu_count}) in this node. If this is "
                    "intentional and you are using:\n"
                    "- ray, set '--distributed-executor-backend ray'.\n"
                    "- multiprocessing, set '--nnodes' appropriately."
                )
            elif self.data_parallel_backend == "ray":
                logger.info(
                    "Using ray distributed inference because "
                    "data_parallel_backend is ray"
                )
                backend = "ray"
            elif ray_found:
                if self.placement_group:
                    backend = "ray"
                else:
                    from ray import is_initialized as ray_is_initialized

                    if ray_is_initialized():
                        from ray.util import get_current_placement_group

                        if get_current_placement_group():
                            backend = "ray"
            self.distributed_executor_backend = backend
            logger.debug("Defaulting to use %s for distributed inference", backend)

        if self.distributed_executor_backend is None and self.world_size == 1:
            self.distributed_executor_backend = "uni"

        if self.max_parallel_loading_workers is not None:
            logger.warning(
                "max_parallel_loading_workers is currently "
                "not supported and will be ignored."
            )
        allowed_backends = ("mp", "uni", "external_launcher")
        if (
            self.distributed_executor_backend not in allowed_backends
            and self.nnodes > 1
        ):
            raise ValueError(
                "nnodes > 1 can only be set when distributed executor "
                "backend is mp, uni or external_launcher."
            )

    @property
    def use_ray(self) -> bool:
        """True for the "ray" backend or an Executor class with `uses_ray` set."""
        return self.distributed_executor_backend == "ray" or (
            isinstance(self.distributed_executor_backend, type)
            and getattr(self.distributed_executor_backend, "uses_ray", False)
        )

    @model_validator(mode="after")
    def _verify_args(self) -> Self:
        """Validate the executor backend and settle custom all-reduce usage.

        Raises:
            ValueError: if the backend is neither a string nor an Executor
                subclass, or if nsight profiling is requested without Ray.
        """
        # Lazy import to avoid circular import
        from vllm.v1.executor import Executor

        # Enable batch invariance settings if requested
        if vllm_is_batch_invariant():
            self.disable_custom_all_reduce = True

        if (
            self.distributed_executor_backend is not None
            and not isinstance(self.distributed_executor_backend, str)
            and not (
                isinstance(self.distributed_executor_backend, type)
                and issubclass(self.distributed_executor_backend, Executor)
            )
        ):
            raise ValueError(
                "Unrecognized distributed executor backend "
                f"{self.distributed_executor_backend}. Supported "
                "values are 'ray', 'mp', 'uni', 'external_launcher', "
                "custom Executor subclass or its import path."
            )
        if self.use_ray:
            from vllm.v1.executor import ray_utils

            ray_utils.assert_ray_available()

        if not current_platform.use_custom_allreduce():
            self.disable_custom_all_reduce = True
            logger.debug(
                "Disabled the custom all-reduce kernel because it is not "
                "supported on current platform."
            )
        if self.nnodes > 1:
            self.disable_custom_all_reduce = True
            logger.debug(
                "Disabled the custom all-reduce since we are running on multi-node."
            )
        if self.ray_workers_use_nsight and not self.use_ray:
            raise ValueError(
                "Unable to use nsight profiling unless workers run with Ray."
            )

        return self
|
||||
146
vllm/config/pooler.py
Normal file
146
vllm/config/pooler.py
Normal file
@@ -0,0 +1,146 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Literal, get_args
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
logger = init_logger(__name__)

# Pooling methods accepted for sequence pooling, plus the same values as a
# tuple for membership checks.
SequencePoolingType = Literal["CLS", "LAST", "MEAN"]
SEQ_POOLING_TYPES: tuple[SequencePoolingType, ...] = get_args(SequencePoolingType)

# Pooling methods accepted for tokenwise pooling, plus the same values as a
# tuple for membership checks.
TokenPoolingType = Literal["ALL", "STEP"]
TOK_POOLING_TYPES: tuple[TokenPoolingType, ...] = get_args(TokenPoolingType)
|
||||
|
||||
|
||||
@config
class PoolerConfig:
    """Controls the behavior of output pooling in pooling models."""

    pooling_type: SequencePoolingType | TokenPoolingType | None = None
    """
    The pooling method used for pooling.

    If set, `seq_pooling_type` or `tok_pooling_type` are automatically populated
    with this field. Alternatively, users can set `seq_pooling_type` and
    `tok_pooling_type` explicitly.

    This field is mainly for user convenience. Internal code should always use
    `seq_pooling_type` or `tok_pooling_type` instead of `pooling_type`.
    """

    seq_pooling_type: SequencePoolingType | None = None
    """
    The pooling method used for sequence pooling.
    """

    tok_pooling_type: TokenPoolingType | None = None
    """
    The pooling method used for tokenwise pooling.
    """

    use_activation: bool | None = None
    """
    Whether to apply activation function to the pooler outputs.
    `None` uses the pooler's default, which is `True` in most cases.
    """

    ## for embedding models
    dimensions: int | None = None
    """
    Reduce the dimensions of embeddings if model
    support matryoshka representation. Defaults to None.
    """
    enable_chunked_processing: bool = False
    """
    Whether to enable chunked processing for long inputs that exceed the model's
    maximum position embeddings. When enabled, long inputs will be split into
    chunks, processed separately, and then aggregated using weighted averaging.
    This allows embedding models to handle arbitrarily long text without CUDA
    errors. Defaults to False.
    """
    max_embed_len: int | None = None
    """
    Maximum input length allowed for embedding generation. When set, allows
    inputs longer than max_embed_len to be accepted for embedding models.
    When an input exceeds max_embed_len, it will be handled according to
    the original max_model_len validation logic.
    Defaults to None (i.e. set to max_model_len).
    """

    ## for classification models
    logit_bias: float | None = None
    """
    If provided, apply classification logit biases. Defaults to None.
    """

    ## for reward models
    step_tag_id: int | None = None
    """
    If set, only the score corresponding to the `step_tag_id` in the
    generated sentence should be returned. Otherwise, the scores for all tokens
    are returned.
    """
    returned_token_ids: list[int] | None = None
    """
    A list of indices for the vocabulary dimensions to be extracted,
    such as the token IDs of `good_token` and `bad_token` in the
    `math-shepherd-mistral-7b-prm` model.
    """

    def __post_init__(self) -> None:
        """Copy the convenience `pooling_type` into the matching specific field.

        Raises:
            ValueError: if `pooling_type` is combined with either
                `seq_pooling_type` or `tok_pooling_type`.
            NotImplementedError: if `pooling_type` is in neither known set.
        """
        requested = self.pooling_type
        if not requested:
            # Nothing to resolve; the specific fields are used as given.
            return

        # The shorthand is mutually exclusive with both specific fields.
        if self.seq_pooling_type is not None:
            raise ValueError(
                "Cannot set both `pooling_type` and `seq_pooling_type`"
            )
        if self.tok_pooling_type is not None:
            raise ValueError(
                "Cannot set both `pooling_type` and `tok_pooling_type`"
            )

        if requested in SEQ_POOLING_TYPES:
            logger.debug(
                "Resolved `pooling_type=%r` to `seq_pooling_type=%r`.",
                requested,
                requested,
            )
            self.seq_pooling_type = requested
            return

        if requested in TOK_POOLING_TYPES:
            logger.debug(
                "Resolved `pooling_type=%r` to `tok_pooling_type=%r`.",
                requested,
                requested,
            )
            self.tok_pooling_type = requested
            return

        raise NotImplementedError(requested)

    def get_seq_pooling_type(self) -> SequencePoolingType:
        """Return `seq_pooling_type`, asserting that it has been set."""
        resolved = self.seq_pooling_type
        assert resolved is not None, "Should be resolved by ModelConfig"
        return resolved

    def get_tok_pooling_type(self) -> TokenPoolingType:
        """Return `tok_pooling_type`, asserting that it has been set."""
        resolved = self.tok_pooling_type
        assert resolved is not None, "Should be resolved by ModelConfig"
        return resolved

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # This config does not affect the computation graph, so the factor
        # list is intentionally empty.
        factors: list[Any] = []
        encoded = str(factors).encode()
        return safe_hash(encoded, usedforsecurity=False).hexdigest()
|
||||
124
vllm/config/profiler.py
Normal file
124
vllm/config/profiler.py
Normal file
@@ -0,0 +1,124 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import Field, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Profiler backends selectable through `ProfilerConfig.profiler`:
# "torch" is the PyTorch profiler and "cuda" is the CUDA profiler.
ProfilerKind = Literal["torch", "cuda"]
|
||||
|
||||
|
||||
def _is_uri_path(path: str) -> bool:
|
||||
"""Check if path is a URI (scheme://...), excluding Windows drive letters.
|
||||
|
||||
Supports custom URI schemes like gs://, s3://, hdfs://, etc.
|
||||
These paths should not be converted to absolute paths.
|
||||
"""
|
||||
if "://" in path:
|
||||
scheme = path.split("://")[0]
|
||||
# Windows drive letters are single characters (e.g., C://)
|
||||
# Valid URI schemes have more than one character
|
||||
return len(scheme) > 1
|
||||
return False
|
||||
|
||||
|
||||
@config
class ProfilerConfig:
    """Dataclass which contains profiler config for the engine."""

    profiler: ProfilerKind | None = None
    """Which profiler to use. Defaults to None. Options are:

    - 'torch': Use PyTorch profiler.\n
    - 'cuda': Use CUDA profiler."""

    torch_profiler_dir: str = ""
    """Directory to save torch profiler traces. Both AsyncLLM's CPU traces and
    worker's traces (CPU & GPU) will be saved under this directory. Note that
    it must be an absolute path."""

    torch_profiler_with_stack: bool = True
    """If `True`, enables stack tracing in the torch profiler. Enabled by default."""

    torch_profiler_with_flops: bool = False
    """If `True`, enables FLOPS counting in the torch profiler. Disabled by default."""

    torch_profiler_use_gzip: bool = True
    """If `True`, saves torch profiler traces in gzip format. Enabled by default"""

    torch_profiler_dump_cuda_time_total: bool = True
    """If `True`, dumps total CUDA time in torch profiler traces. Enabled by default."""

    torch_profiler_record_shapes: bool = False
    """If `True`, records tensor shapes in the torch profiler. Disabled by default."""

    torch_profiler_with_memory: bool = False
    """If `True`, enables memory profiling in the torch profiler.
    Disabled by default."""

    ignore_frontend: bool = False
    """If `True`, disables the front-end profiling of AsyncLLM when using the
    'torch' profiler. This is needed to reduce overhead when using delay/limit options,
    since the front-end profiling does not track iterations and will capture the
    entire range.
    """

    delay_iterations: int = Field(default=0, ge=0)
    """Number of engine iterations to skip before starting profiling.
    Defaults to 0, meaning profiling starts immediately after receiving /start_profile.
    """

    max_iterations: int = Field(default=0, ge=0)
    """Maximum number of engine iterations to profile after starting profiling.
    Defaults to 0, meaning no limit.
    """

    def compute_hash(self) -> str:
        """
        Return a hash of the settings that shape the computation graph.

        WARNING: When a new field is added to this config, append it to
        `factors` if it affects the computation graph.

        The hash is meant to identify every setting that changes the
        structure of the graph running from input ids/embeddings to the
        final hidden states; anything before the inputs or after the final
        hidden states is out of scope.
        """
        # Profiling options never alter the computation graph, so the
        # factor list is intentionally left empty.
        factors: list[Any] = []
        serialized = str(factors).encode()
        return safe_hash(serialized, usedforsecurity=False).hexdigest()

    @model_validator(mode="after")
    def _validate_profiler_config(self) -> Self:
        """Cross-check the profiler fields and normalise the trace directory.

        Raises:
            ValueError: If `torch_profiler_dir` is set while the profiler is
                not 'torch', or if the profiler is 'torch' but no directory
                was given.
        """
        uses_torch = self.profiler == "torch"
        iteration_window_set = self.delay_iterations > 0 or self.max_iterations > 0
        if uses_torch and iteration_window_set and not self.ignore_frontend:
            logger.warning_once(
                "Using 'torch' profiler with delay_iterations or max_iterations "
                "while ignore_frontend is False may result in high overhead."
            )

        trace_dir = self.torch_profiler_dir
        if trace_dir:
            if not uses_torch:
                raise ValueError(
                    "torch_profiler_dir is only applicable when profiler is set to 'torch'"
                )
        elif uses_torch:
            raise ValueError("torch_profiler_dir must be set when profiler is 'torch'")

        # URI-style locations (gs://, s3://, hdfs://, ...) are kept verbatim;
        # only local paths are expanded and made absolute.
        if trace_dir and not _is_uri_path(trace_dir):
            expanded = os.path.expanduser(trace_dir)
            self.torch_profiler_dir = os.path.abspath(expanded)

        return self
|
||||
300
vllm/config/scheduler.py
Normal file
300
vllm/config/scheduler.py
Normal file
@@ -0,0 +1,300 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import InitVar
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.hashing import safe_hash
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.interface import SchedulerInterface
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Runner kinds accepted by `SchedulerConfig.runner_type`.
RunnerType = Literal["generate", "pooling", "draft"]
# Request-ordering policies accepted by `SchedulerConfig.policy`
# ("fcfs" = first come first served, "priority" = by request priority).
SchedulerPolicy = Literal["fcfs", "priority"]
|
||||
|
||||
|
||||
@config
class SchedulerConfig:
    """Scheduler configuration."""

    max_model_len: InitVar[int]
    """Maximum length of a sequence (including prompt and generated text).

    Note: This is stored in the ModelConfig, and is used only here to
    provide fallbacks and validate other attributes."""

    is_encoder_decoder: InitVar[bool]
    """True if the model is an encoder-decoder model.

    Note: This is stored in the ModelConfig, and is used only here to
    disable chunked prefill and prefix caching for encoder-decoder models.
    """

    DEFAULT_MAX_NUM_BATCHED_TOKENS: ClassVar[int] = 2048
    DEFAULT_MAX_NUM_SEQS: ClassVar[int] = 128

    runner_type: RunnerType = "generate"
    """The runner type to launch for the model."""

    max_num_batched_tokens: int = Field(default=DEFAULT_MAX_NUM_BATCHED_TOKENS, ge=1)
    """Maximum number of tokens that can be processed in a single iteration.

    The default value here is mainly for convenience when testing.
    In real usage, this should be set in `EngineArgs.create_engine_config`.
    """

    max_num_scheduled_tokens: int | None = Field(default=None)
    """Maximum number of tokens that the scheduler may issue in a single iteration.

    This is usually equal to max_num_batched_tokens, but can be smaller in cases
    when the model might append tokens into the batch (such as speculative decoding).
    Defaults to max_num_batched_tokens."""

    max_num_seqs: int = Field(default=DEFAULT_MAX_NUM_SEQS, ge=1)
    """Maximum number of sequences to be processed in a single iteration.

    The default value here is mainly for convenience when testing.
    In real usage, this should be set in `EngineArgs.create_engine_config`.
    """

    max_num_partial_prefills: int = Field(default=1, ge=1)
    """For chunked prefill, the maximum number of sequences that can be
    partially prefilled concurrently."""

    max_long_partial_prefills: int = Field(default=1, ge=1)
    """For chunked prefill, the maximum number of prompts longer than
    long_prefill_token_threshold that will be prefilled concurrently. Setting
    this less than max_num_partial_prefills will allow shorter prompts to jump
    the queue in front of longer prompts in some cases, improving latency."""

    long_prefill_token_threshold: int = 0
    """For chunked prefill, a request is considered long if the prompt is
    longer than this number of tokens."""

    enable_chunked_prefill: bool = True
    """If True, prefill requests can be chunked based
    on the remaining `max_num_batched_tokens`.

    The default value here is mainly for convenience when testing.
    In real usage, this should be set in `EngineArgs.create_engine_config`.
    """

    is_multimodal_model: bool = False
    """True if the model is multimodal."""

    # TODO (ywang96): Make this configurable.
    max_num_encoder_input_tokens: int = Field(init=False)
    """Multimodal encoder compute budget, only used in V1.

    NOTE: This is not currently configurable. It will be overridden by
    max_num_batched_tokens in case max multimodal embedding size is larger."""

    # TODO (ywang96): Make this configurable.
    encoder_cache_size: int = Field(init=False)
    """Multimodal encoder cache size, only used in V1.

    NOTE: This is not currently configurable. It will be overridden by
    max_num_batched_tokens in case max multimodal embedding size is larger."""

    policy: SchedulerPolicy = "fcfs"
    """The scheduling policy to use:\n
    - "fcfs" means first come first served, i.e. requests are handled in order
    of arrival.\n
    - "priority" means requests are handled based on given priority (lower
    value means earlier handling) and time of arrival deciding any ties)."""

    disable_chunked_mm_input: bool = False
    """If set to true and chunked prefill is enabled, we do not want to
    partially schedule a multimodal item. Only used in V1
    This ensures that if a request has a mixed prompt
    (like text tokens TTTT followed by image tokens IIIIIIIIII) where only
    some image tokens can be scheduled (like TTTTIIIII, leaving IIIII),
    it will be scheduled as TTTT in one step and IIIIIIIIII in the next."""

    # scheduler class or path. "vllm.v1.core.sched.scheduler.Scheduler"
    # (default) or "mod.custom_class".
    scheduler_cls: str | type[object] | None = Field(default=None)
    """The scheduler class to use. "vllm.v1.core.sched.scheduler.Scheduler" is
    the default scheduler. Can be a class directly or the path to a class of
    form "mod.custom_class"."""

    disable_hybrid_kv_cache_manager: bool | None = None
    """If set to True, KV cache manager will allocate the same size of KV cache
    for all attention layers even if there are multiple type of attention layers
    like full attention and sliding window attention.
    If set to None, the default value will be determined based on the environment
    and starting configuration.
    """

    async_scheduling: bool | None = Field(default=None)
    """If set to False, disable async scheduling. Async scheduling helps to
    avoid gaps in GPU utilization, leading to better latency and throughput.
    """

    stream_interval: int = Field(default=1, ge=1)
    """The interval (or buffer size) for streaming in terms of token length.
    A smaller value (1) makes streaming smoother by sending each token immediately,
    while a larger value (e.g., 10) reduces host overhead and may increase throughput
    by batching multiple tokens before sending."""

    @staticmethod
    def default_factory(**kwargs):
        """
        Build a `SchedulerConfig`, supplying defaults for the two `InitVar`s
        (`max_model_len=8192`, `is_encoder_decoder=False`) when the caller
        did not pass them.
        """
        kwargs.setdefault("max_model_len", 8192)
        kwargs.setdefault("is_encoder_decoder", False)
        return SchedulerConfig(**kwargs)

    def get_scheduler_cls(self) -> type["SchedulerInterface"]:
        """Resolve the scheduler class this config points at.

        With no custom `scheduler_cls`, returns `AsyncScheduler` when
        `async_scheduling` is truthy and `Scheduler` otherwise. A custom
        value is returned as-is when it is already a class, or resolved from
        its qualified name when it is a string.
        """
        custom_cls = self.scheduler_cls
        if custom_cls is not None:
            # This warning can be removed once the Scheduler interface is
            # finalized and we can maintain support for scheduler classes
            # that implement it
            logger.warning_once(
                "Using custom scheduler class %s. This scheduler interface is "
                "not public and compatibility may not be maintained.",
                custom_cls,
            )
            if isinstance(custom_cls, str):
                return resolve_obj_by_qualname(custom_cls)
            return cast(type["SchedulerInterface"], custom_cls)

        # Imports stay local to this method, as in the original layout.
        if self.async_scheduling:
            from vllm.v1.core.sched.async_scheduler import AsyncScheduler

            return AsyncScheduler

        from vllm.v1.core.sched.scheduler import Scheduler

        return Scheduler

    def compute_hash(self) -> str:
        """
        Return a hash of the settings that shape the computation graph.

        WARNING: When a new field is added to this config, append it to
        `factors` if it affects the computation graph.

        The hash is meant to identify every setting that changes the
        structure of the graph running from input ids/embeddings to the
        final hidden states; anything before the inputs or after the final
        hidden states is out of scope.
        """
        # max_num_batched_tokens need to be included in the hash due
        # to two reasons:
        # 1. LoRA creates static buffers based on max_num_batched_tokens.
        #   The tensor sizes and strides get captured in the torch.compile
        #   graph explicitly.
        # 2. Inductor decides whether using 32-bit or 64-bit indexing integer
        #   based on the data sizes. `max_num_batched_tokens` has an
        #   impact on that. For more details, please check
        #   https://github.com/vllm-project/vllm/issues/29585
        factors: list[Any] = [self.max_num_batched_tokens]
        digest = safe_hash(str(factors).encode(), usedforsecurity=False)
        return digest.hexdigest()

    @field_validator("scheduler_cls", "async_scheduling", mode="wrap")
    @classmethod
    def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
        """Skip validation if the value is `None` when initialisation is delayed."""
        if value is None:
            return None
        return handler(value)

    def __post_init__(self, max_model_len: int, is_encoder_decoder: bool) -> None:
        """Derive dependent fields from the `InitVar`s and validate the result."""
        if is_encoder_decoder:
            # Chunked prefill should be disabled for encoder-decoder models.
            self.disable_chunked_mm_input = True
            self.enable_chunked_prefill = False
            self.long_prefill_token_threshold = 0
            logger.info(
                "Encoder-decoder models do not support chunked prefill nor"
                " prefix caching; disabling both."
            )

        # Both encoder budgets mirror the batched-token budget.
        token_budget = self.max_num_batched_tokens
        self.max_num_encoder_input_tokens = token_budget
        self.encoder_cache_size = token_budget

        if self.enable_chunked_prefill:
            logger.info(
                "Chunked prefill is enabled with max_num_batched_tokens=%d.",
                token_budget,
            )

        if self.max_num_partial_prefills > 1:
            # A zero threshold means "not set": fall back to 4% of the
            # maximum model length.
            if self.long_prefill_token_threshold == 0:
                self.long_prefill_token_threshold = int(max_model_len * 0.04)

            logger.info(
                "Concurrent partial prefills enabled with "
                "max_num_partial_prefills=%d, max_long_partial_prefills=%d, "
                "long_prefill_token_threshold=%d",
                self.max_num_partial_prefills,
                self.max_long_partial_prefills,
                self.long_prefill_token_threshold,
            )

        self.verify_max_model_len(max_model_len)

    def verify_max_model_len(self, max_model_len: int) -> Self:
        """Check the scheduler limits against `max_model_len`.

        Returns:
            This config, unchanged, when every check passes.

        Raises:
            ValueError: If the batched-token budget is below `max_model_len`
                without chunked prefill, below `max_num_seqs`, or if the
                partial-prefill settings are inconsistent.
        """
        if (
            self.max_num_batched_tokens < max_model_len
            and not self.enable_chunked_prefill
        ):
            raise ValueError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
                f"smaller than max_model_len ({max_model_len}). "
                "This effectively limits the maximum sequence length to "
                "max_num_batched_tokens and makes vLLM reject longer "
                "sequences. Please increase max_num_batched_tokens or "
                "decrease max_model_len."
            )

        if self.max_num_batched_tokens < self.max_num_seqs:
            raise ValueError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
                "be greater than or equal to max_num_seqs "
                f"({self.max_num_seqs})."
            )

        # Upper bound on tokens that max_num_seqs full-length sequences hold.
        seq_token_capacity = self.max_num_seqs * max_model_len
        if self.max_num_batched_tokens > seq_token_capacity:
            logger.warning(
                "max_num_batched_tokens (%d) exceeds max_num_seqs "
                "* max_model_len (%d). This may lead to unexpected behavior.",
                self.max_num_batched_tokens,
                seq_token_capacity,
            )

        concurrent_prefills = self.max_num_partial_prefills > 1
        if concurrent_prefills and not self.enable_chunked_prefill:
            raise ValueError(
                "Chunked prefill must be enabled to set "
                "max_num_partial_prefills > 1."
            )

        if concurrent_prefills and self.long_prefill_token_threshold > max_model_len:
            raise ValueError(
                "long_prefill_token_threshold "
                f"({self.long_prefill_token_threshold}) cannot be greater "
                f"than the max_model_len ({max_model_len})."
            )

        if self.max_long_partial_prefills > self.max_num_partial_prefills:
            raise ValueError(
                f"{self.max_long_partial_prefills=} must be less than or equal to "
                f"{self.max_num_partial_prefills=}."
            )

        return self
|
||||
789
vllm/config/speculative.py
Normal file
789
vllm/config/speculative.py
Normal file
@@ -0,0 +1,789 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import ast
|
||||
from typing import TYPE_CHECKING, Any, Literal, get_args
|
||||
|
||||
from pydantic import Field, SkipValidation, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from vllm.config import LoadConfig
|
||||
from vllm.config.model import ModelConfig
|
||||
from vllm.config.parallel import ParallelConfig
|
||||
from vllm.config.utils import config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.config import get_hf_text_config
|
||||
from vllm.utils.hashing import safe_hash
|
||||
from vllm.utils.import_utils import LazyLoader, has_arctic_inference
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
import vllm.model_executor.layers.quantization as me_quant
|
||||
else:
|
||||
PretrainedConfig = Any
|
||||
|
||||
me_quant = LazyLoader(
|
||||
"model_executor", globals(), "vllm.model_executor.layers.quantization"
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Model-specific MTP method names plus the generic "mtp".
# `SpeculativeConfig.__post_init__` logs a deprecation warning for every
# entry other than "mtp" and rewrites it to "mtp".
MTPModelTypes = Literal[
    "deepseek_mtp",
    "mimo_mtp",
    "glm4_moe_mtp",
    "glm4_moe_lite_mtp",
    "glm_ocr_mtp",
    "ernie_mtp",
    "nemotron_h_mtp",
    "exaone_moe_mtp",
    "qwen3_next_mtp",
    "qwen3_5_mtp",
    "longcat_flash_mtp",
    "mtp",
    "pangu_ultra_moe_mtp",
    "step3p5_mtp",
]
# "eagle" / "eagle3" together with every MTP method name above.
EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes]
# Every value accepted by `SpeculativeConfig.method`.
SpeculativeMethod = Literal[
    "ngram",
    "medusa",
    "mlp_speculator",
    "draft_model",
    "suffix",
    EagleModelTypes,
]
|
||||
|
||||
|
||||
@config
|
||||
class SpeculativeConfig:
|
||||
"""Configuration for speculative decoding."""
|
||||
|
||||
enforce_eager: bool | None = None
|
||||
"""Override the default enforce_eager from model_config"""
|
||||
# General speculative decoding control
|
||||
num_speculative_tokens: int = Field(default=None, gt=0)
|
||||
"""The number of speculative tokens, if provided. It will default to the
|
||||
number in the draft model config if present, otherwise, it is required."""
|
||||
model: str | None = None
|
||||
"""The name of the draft model, eagle head, or additional weights, if
|
||||
provided."""
|
||||
method: SpeculativeMethod | None = None
|
||||
"""The name of the speculative method to use. If users provide and set the
|
||||
`model` param, the speculative method type will be detected automatically
|
||||
if possible, if `model` param is not provided, the method name must be
|
||||
provided.
|
||||
|
||||
If using `ngram` method, the related configuration `prompt_lookup_max` and
|
||||
`prompt_lookup_min` should be considered."""
|
||||
draft_tensor_parallel_size: int | None = Field(default=None, ge=1)
|
||||
"""The degree of the tensor parallelism for the draft model. Can only be 1
|
||||
or the same as the target model's tensor parallel size."""
|
||||
tensor_parallel_size: int | None = None
|
||||
"""Users should pass "draft_tensor_parallel_size". This parameter's purpose is to
|
||||
warn users when they mistakenly provide the wrong argument."""
|
||||
|
||||
# Draft model configuration
|
||||
quantization: me_quant.QuantizationMethods | None = None
|
||||
"""Quantization method that was used to quantize the draft model weights.
|
||||
If `None`, we assume the model weights are not quantized. Note that it only
|
||||
takes effect when using the draft model-based speculative method."""
|
||||
max_model_len: int | None = Field(default=None, ge=1)
|
||||
"""The maximum model length of the draft model. Used when testing the
|
||||
ability to skip speculation for some sequences."""
|
||||
revision: str | None = None
|
||||
"""The specific model version to use for the draft model. It can be a
|
||||
branch name, a tag name, or a commit id. If unspecified, will use the
|
||||
default version."""
|
||||
code_revision: str | None = None
|
||||
"""The specific revision to use for the draft model code on Hugging Face
|
||||
Hub. It can be a branch name, a tag name, or a commit id. If unspecified,
|
||||
will use the default version."""
|
||||
|
||||
# Advanced control
|
||||
disable_padded_drafter_batch: bool = False
|
||||
"""Disable input padding for speculative decoding. If set to True,
|
||||
speculative input batches can contain sequences of different lengths,
|
||||
which may only be supported by certain attention backends. This currently
|
||||
only affects the EAGLE method of speculation."""
|
||||
use_local_argmax_reduction: bool = False
|
||||
"""Use vocab-parallel local argmax instead of all-gathering full logits
|
||||
for draft token generation. Reduces communication from O(vocab_size) to
|
||||
O(2 * tp_size) per token. Only applies to greedy draft selection in
|
||||
non-tree speculation."""
|
||||
|
||||
# Ngram proposer configuration
|
||||
prompt_lookup_max: int | None = Field(default=None, ge=1)
|
||||
"""Maximum size of ngram token window when using Ngram proposer, required
|
||||
when method is set to ngram."""
|
||||
prompt_lookup_min: int | None = Field(default=None, ge=1)
|
||||
"""Minimum size of ngram token window when using Ngram proposer, if
|
||||
provided. Defaults to 1."""
|
||||
|
||||
# Alternative drafting strategies
|
||||
speculative_token_tree: str | None = None
|
||||
"""Specifies the tree structure for speculative token generation.
|
||||
"""
|
||||
parallel_drafting: bool = False
|
||||
"""Enable parallel drafting, where all speculative tokens are generated
|
||||
in parallel rather than sequentially. This can improve performance but
|
||||
requires the speculative model be trained to support parallel drafting.
|
||||
Only compatible with EAGLE and draft model methods."""
|
||||
|
||||
# required configuration params passed from engine
|
||||
target_model_config: SkipValidation[ModelConfig] = None # type: ignore
|
||||
"""The configuration of the target model."""
|
||||
target_parallel_config: SkipValidation[ParallelConfig] = None # type: ignore
|
||||
"""The parallel configuration for the target model."""
|
||||
|
||||
# params generated in the post-init stage
|
||||
draft_model_config: SkipValidation[ModelConfig] = None # type: ignore
|
||||
"""The configuration of the draft model initialized internal."""
|
||||
draft_parallel_config: SkipValidation[ParallelConfig] = None # type: ignore
|
||||
"""The parallel configuration for the draft model initialized internal."""
|
||||
|
||||
# Suffix decoding configuration
|
||||
suffix_decoding_max_tree_depth: int = 24
|
||||
"""The maximum depth of the suffix decoding global and prompt trees. The
|
||||
tree depth limits the sum of the prefix match and speculation lengths."""
|
||||
|
||||
suffix_decoding_max_cached_requests: int = 10000
|
||||
"""The maximum number of requests to cache in the global suffix tree. If
|
||||
exceeded, will trigger eviction in FIFO order. If set to 0, the global
|
||||
suffix tree is disabled and past responses are not cached (prompt trees
|
||||
are still used)."""
|
||||
|
||||
suffix_decoding_max_spec_factor: float = 1.0
|
||||
"""The maximum spec factor for suffix decoding. The spec factor controls
|
||||
speculation lengths based on the prefix match length: max_spec_tokens =
|
||||
max_spec_factor * prefix_match_length."""
|
||||
|
||||
suffix_decoding_min_token_prob: float = 0.1
|
||||
"""The minimum token probability for suffix decoding. Will only speculate
|
||||
tokens with estimated probability (based on frequency counts) greater than
|
||||
or equal to this value."""
|
||||
|
||||
draft_load_config: LoadConfig | None = None
|
||||
"""Load config for the draft model. If not specified, will use the load
|
||||
config from the target model."""
|
||||
|
||||
def compute_hash(self) -> str:
    """
    Return a hash of the settings that shape the computation graph.

    WARNING: When a new field is added to this config, append it to
    `factors` if it affects the computation graph.

    The hash is meant to identify every setting that changes the
    structure of the graph running from input ids/embeddings to the
    final hidden states; anything before the inputs or after the final
    hidden states is out of scope.
    """
    # Eagle3 affects the computation graph because it returns intermediate
    # hidden states in addition to the final hidden state.
    uses_eagle3 = self.method == "eagle3"
    factors: list[Any] = [uses_eagle3]
    payload = str(factors).encode()
    return safe_hash(payload, usedforsecurity=False).hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
|
||||
initial_architecture = hf_config.architectures[0]
|
||||
if hf_config.model_type in ("deepseek_v3", "deepseek_v32", "glm_moe_dsa"):
|
||||
hf_config.model_type = "deepseek_mtp"
|
||||
if hf_config.model_type == "deepseek_mtp":
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||
hf_config.update(
|
||||
{"n_predict": n_predict, "architectures": ["DeepSeekMTPModel"]}
|
||||
)
|
||||
if hf_config.model_type in ("pangu_ultra_moe"):
|
||||
hf_config.model_type = "pangu_ultra_moe_mtp"
|
||||
if hf_config.model_type == "pangu_ultra_moe_mtp":
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||
hf_config.update(
|
||||
{"n_predict": n_predict, "architectures": ["OpenPanguMTPModel"]}
|
||||
)
|
||||
|
||||
if hf_config.architectures[0] == "MiMoForCausalLM":
|
||||
hf_config.model_type = "mimo_mtp"
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||
hf_config.update(
|
||||
{
|
||||
"num_hidden_layers": 0,
|
||||
"n_predict": n_predict,
|
||||
"architectures": ["MiMoMTPModel"],
|
||||
}
|
||||
)
|
||||
|
||||
if hf_config.architectures[0] == "Glm4MoeForCausalLM":
|
||||
hf_config.model_type = "glm4_moe_mtp"
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||
hf_config.update(
|
||||
{
|
||||
"n_predict": n_predict,
|
||||
"architectures": ["Glm4MoeMTPModel"],
|
||||
}
|
||||
)
|
||||
|
||||
if hf_config.architectures[0] == "Glm4MoeLiteForCausalLM":
|
||||
hf_config.model_type = "glm4_moe_lite_mtp"
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||
hf_config.update(
|
||||
{
|
||||
"num_hidden_layers": 0,
|
||||
"n_predict": n_predict,
|
||||
"architectures": ["Glm4MoeLiteMTPModel"],
|
||||
}
|
||||
)
|
||||
|
||||
if hf_config.architectures[0] == "GlmOcrForConditionalGeneration":
|
||||
hf_config.model_type = "glm_ocr_mtp"
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||
hf_config.update(
|
||||
{
|
||||
"num_hidden_layers": 0,
|
||||
"n_predict": n_predict,
|
||||
"architectures": ["GlmOcrMTPModel"],
|
||||
}
|
||||
)
|
||||
|
||||
if hf_config.model_type == "ernie4_5_moe":
|
||||
hf_config.model_type = "ernie_mtp"
|
||||
if hf_config.model_type == "ernie_mtp":
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||
hf_config.update(
|
||||
{"n_predict": n_predict, "architectures": ["ErnieMTPModel"]}
|
||||
)
|
||||
|
||||
if (
|
||||
hf_config.model_type == "nemotron_h"
|
||||
and hasattr(hf_config, "num_nextn_predict_layers")
|
||||
and hf_config.num_nextn_predict_layers > 0
|
||||
):
|
||||
# Check if this is an MTP variant
|
||||
hf_config.model_type = "nemotron_h_mtp"
|
||||
if hf_config.model_type == "nemotron_h_mtp":
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
|
||||
hf_config.update(
|
||||
{"n_predict": n_predict, "architectures": ["NemotronHMTPModel"]}
|
||||
)
|
||||
|
||||
if hf_config.model_type == "qwen3_next":
|
||||
hf_config.model_type = "qwen3_next_mtp"
|
||||
if hf_config.model_type == "qwen3_next_mtp":
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||
hf_config.update(
|
||||
{"n_predict": n_predict, "architectures": ["Qwen3NextMTP"]}
|
||||
)
|
||||
|
||||
if hf_config.model_type == "exaone_moe":
|
||||
hf_config.model_type = "exaone_moe_mtp"
|
||||
if hf_config.model_type == "exaone_moe_mtp":
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||
hf_config.update(
|
||||
{"n_predict": n_predict, "architectures": ["ExaoneMoeMTP"]}
|
||||
)
|
||||
|
||||
if hf_config.model_type in ("qwen3_5", "qwen3_5_moe"):
|
||||
is_moe = hf_config.model_type == "qwen3_5_moe"
|
||||
hf_config.model_type = "qwen3_5_mtp"
|
||||
n_predict = getattr(hf_config, "mtp_num_hidden_layers", None)
|
||||
hf_config.update(
|
||||
{
|
||||
"n_predict": n_predict,
|
||||
"architectures": ["Qwen3_5MoeMTP" if is_moe else "Qwen3_5MTP"],
|
||||
}
|
||||
)
|
||||
if hf_config.model_type == "longcat_flash":
|
||||
hf_config.model_type = "longcat_flash_mtp"
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
|
||||
hf_config.update(
|
||||
{"n_predict": n_predict, "architectures": ["LongCatFlashMTPModel"]}
|
||||
)
|
||||
|
||||
if hf_config.model_type == "step3p5":
|
||||
hf_config.model_type = "step3p5_mtp"
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
|
||||
hf_config.update({"n_predict": n_predict, "architectures": ["Step3p5MTP"]})
|
||||
|
||||
if initial_architecture == "MistralLarge3ForCausalLM":
|
||||
hf_config.update({"architectures": ["EagleMistralLarge3ForCausalLM"]})
|
||||
|
||||
return hf_config
|
||||
|
||||
def __post_init__(self):
    """Resolve the speculative method and derive the draft configuration.

    The steps run in this order:

    1. Infer `method` from `model` when it was not given, and fold
       deprecated MTP method names into "mtp".
    2. When only `num_speculative_tokens` was given, pick the implied
       `model` ("ngram", "suffix", or the target model itself for MTP).
    3. For "ngram": fill in and validate the prompt-lookup window, then
       reuse the target model/parallel configs as the draft configs.
    4. For "suffix": delegate to `_validate_suffix_decoding`.
    5. Otherwise, when a draft model is set: build its `ModelConfig`,
       auto-detect the method from the draft model, wrap the HF config for
       EAGLE, reconcile `num_speculative_tokens` with `n_predict`,
       normalise the token tree, and derive the draft TP size, max model
       length and parallel config.

    Returns:
        `self`.

    Raises:
        ValueError: on missing or inconsistent arguments (see messages).
        NotImplementedError: if no supported method matches the draft model.
    """
    # Note: "method" is a new parameter that helps to extend the
    # configuration of non-model-based proposers, and the "model" parameter
    # will be used to set the draft model, eagle head, or additional weight
    # when needed. If users do not specify "method", the speculative method
    # will be detected automatically if possible. If the speculative method
    # can not be detected, it will be considered as the "draft_model" by
    # default.

    # infer method from user args
    if self.method is None:
        if self.model in ("ngram", "[ngram]"):
            self.method = "ngram"
        else:
            self.method = "draft_model"

    if self.method in get_args(MTPModelTypes) and self.method != "mtp":
        logger.warning(
            "method `%s` is deprecated and replaced with mtp.", self.method
        )
        self.method = "mtp"

    if self.model is None and self.num_speculative_tokens is not None:
        if self.method == "mtp":
            if self.target_model_config is None:
                raise ValueError("target_model_config must be present for mtp")
            if self.target_model_config.hf_text_config.model_type == "deepseek_v32":
                # FIXME(luccafong): cudagraph with v32 MTP is not supported,
                # remove this when the issue is fixed.
                self.enforce_eager = True
            # use the draft model from the same model:
            self.model = self.target_model_config.model
            # Align the quantization of draft model for cases such as
            # --quantization fp8 with a bf16 checkpoint.
            if not self.quantization:
                self.quantization = self.target_model_config.quantization
        elif self.method in ("ngram", "[ngram]"):
            self.model = "ngram"
        elif self.method == "suffix":
            self.model = "suffix"
        else:
            raise ValueError(
                "num_speculative_tokens was provided but without speculative model."
            )

    if self.method in ("ngram", "[ngram]"):
        # Unified to "ngram" internally
        self.method = "ngram"
        # Set default values if not provided. The three branches are
        # mutually exclusive, so whichever bound is missing is simply
        # mirrored from the one that was supplied.
        if self.prompt_lookup_min is None and self.prompt_lookup_max is None:
            # TODO(woosuk): Tune these values. They are arbitrarily chosen.
            self.prompt_lookup_min = 5
            self.prompt_lookup_max = 5
        elif self.prompt_lookup_min is None:
            self.prompt_lookup_min = self.prompt_lookup_max
        elif self.prompt_lookup_max is None:
            self.prompt_lookup_max = self.prompt_lookup_min

        # Validate values
        if self.prompt_lookup_min > self.prompt_lookup_max:
            raise ValueError(
                f"prompt_lookup_min={self.prompt_lookup_min} must "
                f"be <= prompt_lookup_max={self.prompt_lookup_max}"
            )

        # TODO: current we still need extract vocab_size from target model
        # config, in future, we may try refactor it out, and set
        # draft related config as None here.
        self.draft_model_config = self.target_model_config
        self.draft_parallel_config = self.target_parallel_config
    elif self.method == "suffix":
        self._validate_suffix_decoding()
    else:
        self.prompt_lookup_max = 0
        self.prompt_lookup_min = 0

        if self.model is not None:
            self.draft_model_config = ModelConfig(
                model=self.model,
                runner="draft",
                tokenizer=self.target_model_config.tokenizer,
                tokenizer_mode=self.target_model_config.tokenizer_mode,
                trust_remote_code=self.target_model_config.trust_remote_code,
                allowed_local_media_path=self.target_model_config.allowed_local_media_path,
                allowed_media_domains=self.target_model_config.allowed_media_domains,
                dtype=self.target_model_config.dtype,
                seed=self.target_model_config.seed,
                revision=self.revision,
                code_revision=self.code_revision,
                tokenizer_revision=self.target_model_config.tokenizer_revision,
                spec_target_max_model_len=self.target_model_config.max_model_len,
                quantization=self.quantization,
                enforce_eager=self.target_model_config.enforce_eager,
                max_logprobs=self.target_model_config.max_logprobs,
                hf_overrides=SpeculativeConfig.hf_config_override,
                config_format=self.target_model_config.config_format,
            )

            # num_speculative_tokens may still be None at this point; it is
            # only defaulted from n_predict further below, so the "> 1"
            # warnings must not compare None against an int.
            multiple_spec_tokens = (
                self.num_speculative_tokens is not None
                and self.num_speculative_tokens > 1
            )

            # Automatically detect the method
            if self.method in ("eagle", "eagle3"):
                pass
            # examples:
            # yuhuili/EAGLE-LLaMA3-Instruct-8B
            # yuhuili/EAGLE3-LLaMA3.1-Instruct-8B
            # AngelSlim/Qwen3-8B_eagle3
            elif "eagle-" in self.draft_model_config.model.lower():
                self.method = "eagle"
            elif "eagle3" in self.draft_model_config.model.lower():
                self.method = "eagle3"
            elif self.draft_model_config.hf_config.model_type == "medusa":
                self.method = "medusa"
            elif self.draft_model_config.hf_config.model_type == "mlp_speculator":
                self.method = "mlp_speculator"
            elif self.draft_model_config.hf_config.model_type in get_args(
                MTPModelTypes
            ):
                self.method = "mtp"
                if multiple_spec_tokens:
                    logger.warning(
                        "Enabling num_speculative_tokens > 1 will run "
                        "multiple times of forward on same MTP layer"
                        ",which may result in lower acceptance rate"
                    )
            # Exact comparison: `x in ("longcat_flash_mtp")` is a substring
            # test against a plain string, not a tuple membership test.
            elif self.draft_model_config.hf_config.model_type == "longcat_flash_mtp":
                self.method = "longcat_flash_mtp"
                if multiple_spec_tokens:
                    logger.warning(
                        "LongCat MTP models only have "
                        "one layer. Might need some code changes "
                        "to support multiple layers."
                    )
            elif self.method == "draft_model":
                pass
            else:
                raise NotImplementedError(
                    f"Unsupported speculative method: '{self.method}'"
                )

            # Replace hf_config for EAGLE draft_model
            if self.method in ("eagle", "eagle3"):
                from vllm.transformers_utils.configs import SpeculatorsConfig
                from vllm.transformers_utils.configs.eagle import EAGLEConfig

                if isinstance(
                    self.draft_model_config.hf_config,
                    (EAGLEConfig, SpeculatorsConfig),
                ):
                    pass
                else:
                    eagle_config = EAGLEConfig(
                        self.draft_model_config.hf_config,
                        method=self.method,
                        model_type="eagle",
                    )
                    # EAGLEConfig primarily updates architectures, so update
                    # all architectures-related fields in draft_model_config
                    self.draft_model_config.hf_config = eagle_config
                    self.draft_model_config.hf_text_config = get_hf_text_config(
                        self.draft_model_config.hf_config
                    )
                    self.draft_model_config.model_arch_config = (
                        self.draft_model_config.get_model_arch_config()
                    )
                    model_info, arch = (
                        self.draft_model_config.registry.inspect_model_cls(
                            self.draft_model_config.architectures,
                            self.draft_model_config,
                        )
                    )
                    self.draft_model_config._model_info = model_info
                    self.draft_model_config._architecture = arch

            if self.num_speculative_tokens is not None and hasattr(
                self.draft_model_config.hf_config, "num_lookahead_tokens"
            ):
                self.draft_model_config.hf_config.num_lookahead_tokens = (
                    self.num_speculative_tokens
                )

            n_predict = getattr(
                self.draft_model_config.hf_config, "n_predict", None
            )
            if n_predict is not None:
                if self.num_speculative_tokens is None:
                    # Default to max value defined in draft model config.
                    self.num_speculative_tokens = n_predict
                elif (
                    self.num_speculative_tokens > n_predict
                    and self.num_speculative_tokens % n_predict != 0
                ):
                    # Ensure divisibility for MTP module reuse.
                    raise ValueError(
                        f"num_speculative_tokens:{self.num_speculative_tokens}"
                        f" must be divisible by {n_predict=}"
                    )

            if self.speculative_token_tree is None:
                if self.num_speculative_tokens is None:
                    raise ValueError(
                        "A speculative model was provided, but neither "
                        "`speculative_token_tree` nor `num_speculative_tokens` "
                        "was provided"
                    )

                # Generate chain of tokens.
                self.speculative_token_tree = str(
                    [(i + 1) * (0,) for i in range(self.num_speculative_tokens)]
                )
            else:
                # Sort the token tree breadth-first.
                tree_choices = ast.literal_eval(self.speculative_token_tree)
                self.speculative_token_tree = str(
                    sorted(tree_choices, key=lambda t: (len(t), t))
                )

            self.draft_tensor_parallel_size = (
                SpeculativeConfig._verify_and_get_draft_tp(
                    self.target_parallel_config,
                    self.draft_tensor_parallel_size,
                    self.draft_model_config.hf_config,
                )
            )

            self.draft_model_config.max_model_len = (
                SpeculativeConfig._maybe_override_draft_max_model_len(
                    self.max_model_len,
                    self.draft_model_config.max_model_len,
                    self.target_model_config.max_model_len,
                )
            )

            self.draft_parallel_config = (
                SpeculativeConfig.create_draft_parallel_config(
                    self.target_parallel_config, self.draft_tensor_parallel_size
                )
            )
    return self
|
||||
|
||||
def _validate_suffix_decoding(self):
    """Check prerequisites and parameter ranges for suffix decoding.

    When `num_speculative_tokens` is unset it is defaulted to
    `suffix_decoding_max_tree_depth` (and a warning is logged) before the
    range checks run.

    Raises:
        ImportError: if `has_arctic_inference()` reports the package missing.
        ValueError: if any suffix-decoding parameter is out of range.
    """
    arctic_available = has_arctic_inference()
    if not arctic_available:
        raise ImportError(
            "Arctic Inference is required for suffix decoding. "
            "Install via `pip install arctic-inference==0.1.1`."
        )

    # Suffix decoding picks the actual number of speculative tokens
    # dynamically; num_speculative_tokens only acts as an upper limit.
    if self.num_speculative_tokens is None:
        self.num_speculative_tokens = self.suffix_decoding_max_tree_depth
        logger.warning(
            "Defaulted num_speculative_tokens to %s for suffix decoding.",
            self.num_speculative_tokens,
        )

    # Range checks, in the order the parameters are validated.
    if self.suffix_decoding_max_tree_depth < 1:
        raise ValueError(
            f"suffix_decoding_max_tree_depth="
            f"{self.suffix_decoding_max_tree_depth} must be >= 1"
        )

    if self.suffix_decoding_max_cached_requests < 0:
        raise ValueError(
            f"suffix_decoding_max_cached_requests="
            f"{self.suffix_decoding_max_cached_requests} must be >= 0"
        )

    if self.suffix_decoding_max_spec_factor < 0:
        raise ValueError(
            f"suffix_decoding_max_spec_factor="
            f"{self.suffix_decoding_max_spec_factor} must be >= 0"
        )

    prob_in_unit_interval = 0 <= self.suffix_decoding_min_token_prob <= 1
    if not prob_in_unit_interval:
        raise ValueError(
            f"suffix_decoding_min_token_prob="
            f"{self.suffix_decoding_min_token_prob} must be in [0, 1]"
        )
|
||||
|
||||
@staticmethod
|
||||
def _maybe_override_draft_max_model_len(
|
||||
speculative_max_model_len: int | None,
|
||||
draft_max_model_len: int,
|
||||
target_max_model_len: int,
|
||||
) -> int:
|
||||
"""Determine the max sequence len for the draft model. This is usually
|
||||
the draft_max_model_len, but may be the target_max_model_len if it is
|
||||
less than the draft_max_model_len, or may be speculative_max_model_len
|
||||
if it is specified.
|
||||
|
||||
This is necessary so that sequences do not exceed the capacity of the
|
||||
draft model or the target model.
|
||||
|
||||
speculative_max_model_len is mainly used for testing that sequences can
|
||||
skip speculation.
|
||||
"""
|
||||
|
||||
if speculative_max_model_len is not None:
|
||||
if speculative_max_model_len > draft_max_model_len:
|
||||
raise ValueError(
|
||||
f"{speculative_max_model_len=} cannot be "
|
||||
f"larger than {draft_max_model_len=}"
|
||||
)
|
||||
|
||||
if speculative_max_model_len > target_max_model_len:
|
||||
raise ValueError(
|
||||
f"{speculative_max_model_len=} cannot be "
|
||||
f"larger than {target_max_model_len=}"
|
||||
)
|
||||
|
||||
return speculative_max_model_len
|
||||
|
||||
return min(
|
||||
draft_max_model_len,
|
||||
target_max_model_len,
|
||||
)
|
||||
|
||||
@staticmethod
def _verify_and_get_draft_tp(
    target_parallel_config: ParallelConfig,
    speculative_draft_tensor_parallel_size: int | None,
    draft_hf_config: PretrainedConfig,
) -> int:
    """
    Return the tensor parallel size to use for the draft model.

    An explicitly requested size is accepted only if it is 1 or equal to
    the target model's tensor parallel size. When no size is requested,
    "mlp_speculator" drafts get 1 (with a warning if the target uses
    tp > 1) and every other draft inherits the target's size.

    Raises:
        ValueError: if the requested size is neither 1 nor the target's
            tensor parallel size.
    """
    # Explicit request: verify it and hand it back unchanged.
    if speculative_draft_tensor_parallel_size is not None:
        permitted_sizes = (1, target_parallel_config.tensor_parallel_size)
        if speculative_draft_tensor_parallel_size not in permitted_sizes:
            raise ValueError(
                f"{speculative_draft_tensor_parallel_size=} cannot be "
                f"other value than 1 or target model tensor_parallel_size"
            )
        return speculative_draft_tensor_parallel_size

    # No request: default to the target's size unless the draft is an
    # mlp_speculator, which is pinned to 1.
    if draft_hf_config.model_type != "mlp_speculator":
        return target_parallel_config.tensor_parallel_size

    if target_parallel_config.tensor_parallel_size > 1:
        logger.warning(
            "%s cannot currently be run with tp>1; "
            "setting speculative_draft_tensor_parallel_size=1",
            draft_hf_config.model_type,
        )
    return 1
|
||||
|
||||
@staticmethod
def create_draft_parallel_config(
    target_parallel_config: ParallelConfig,
    speculative_draft_tensor_parallel_size: int,
) -> ParallelConfig:
    """Build the `ParallelConfig` used by the draft worker.

    The listed fields are copied from the target parallel config; only
    `tensor_parallel_size` is taken from
    `speculative_draft_tensor_parallel_size` instead.
    """
    target = target_parallel_config
    return ParallelConfig(
        pipeline_parallel_size=target.pipeline_parallel_size,
        tensor_parallel_size=speculative_draft_tensor_parallel_size,
        distributed_executor_backend=target.distributed_executor_backend,
        max_parallel_loading_workers=target.max_parallel_loading_workers,
        disable_custom_all_reduce=target.disable_custom_all_reduce,
        ray_workers_use_nsight=target.ray_workers_use_nsight,
        placement_group=target.placement_group,
    )
|
||||
|
||||
@model_validator(mode="after")
def _verify_args(self) -> Self:
    """Validate the fully-populated speculative config.

    Checks, in order: the unsupported `tensor_parallel_size` argument is
    not set, `num_speculative_tokens` is present and positive, the draft
    model config (if any) is compatible with the draft parallel config,
    the target model type is supported when the method is "eagle3", and
    target/draft vocabulary sizes match for the "draft_model" method.

    Returns:
        `self`.

    Raises:
        ValueError: if any of the checks above fails.
    """
    if self.tensor_parallel_size is not None:
        raise ValueError(
            "'tensor_parallel_size' is not a valid argument in the "
            "speculative_config. Please pass 'draft_tensor_parallel_size' instead."
        )

    spec_tokens = self.num_speculative_tokens
    if spec_tokens is None:
        raise ValueError(
            "num_speculative_tokens must be provided with "
            "speculative model unless the draft model config contains an "
            "n_predict parameter."
        )
    if spec_tokens <= 0:
        raise ValueError(
            "Expected num_speculative_tokens to be greater "
            f"than zero ({self.num_speculative_tokens})."
        )

    if self.draft_model_config:
        self.draft_model_config.verify_with_parallel_config(
            self.draft_parallel_config
        )

    # Entries are matched as substrings of the target's model_type.
    eagle3_target_supported = [
        "llama",
        "qwen",
        "minicpm",
        "gpt_oss",
        "hunyuan_vl",
        "hunyuan_v1_dense",
        "afmoe",
        "nemotron_h",
    ]
    if self.method == "eagle3" and self.target_model_config:
        target_is_supported = any(
            supported_model in self.target_model_config.hf_text_config.model_type
            for supported_model in eagle3_target_supported
        )
        if not target_is_supported:
            raise ValueError(
                f"Eagle3 is only supported for {eagle3_target_supported} models. "  # noqa: E501
                f"Got {self.target_model_config.hf_text_config.model_type=}"
            )

    self.verify_equal_vocab_size_if_draft_model()
    return self
|
||||
|
||||
def verify_equal_vocab_size_if_draft_model(self):
    """Ensure target and draft models share a vocabulary size.

    The check only applies when the method is "draft_model" and both the
    target and the draft model configs are set; otherwise nothing happens.

    Raises:
        ValueError: if the two `get_vocab_size()` results differ.
    """
    if self.method != "draft_model":
        return
    if self.target_model_config is None:
        return
    if self.draft_model_config is None:
        return

    target_vocab_size = self.target_model_config.get_vocab_size()
    draft_vocab_size = self.draft_model_config.get_vocab_size()
    if target_vocab_size == draft_vocab_size:
        return

    raise ValueError(
        f"Target and draft model should have the same vocabulary size. "
        f"Target model vocab_size={target_vocab_size}. "
        f"Draft model vocab_size={draft_vocab_size}. "
        f"Using models with different tokenizers can cause out-of-bounds "
        f"errors during speculative decoding."
    )
|
||||
|
||||
@property
def max_num_new_slots_for_drafting(self) -> int:
    """
    Maximum number of extra slots per request that drafting may add to the
    batch.

    Parallel drafting contributes one slot per 'masked' token
    (`num_speculative_tokens - 1`); draft-model based speculation adds one
    more slot per request because the draft tokens are not sliced. Serial
    methods without a draft model need no extra slots.
    """
    parallel_slots = self.num_speculative_tokens - 1 if self.parallel_drafting else 0
    draft_model_slots = 1 if self.uses_draft_model() else 0
    return parallel_slots + draft_model_slots
|
||||
|
||||
def use_eagle(self) -> bool:
    """Return True when the method is "eagle", "eagle3" or "mtp"."""
    eagle_style_methods = ("eagle", "eagle3", "mtp")
    return self.method in eagle_style_methods
|
||||
|
||||
def uses_draft_model(self) -> bool:
    """Return True when the speculative method is "draft_model"."""
    current_method = self.method
    return current_method == "draft_model"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
method = self.method
|
||||
model = None if method in ("ngram", "suffix") else self.draft_model_config.model
|
||||
num_spec_tokens = self.num_speculative_tokens
|
||||
return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})"
|
||||
39
vllm/config/speech_to_text.py
Normal file
39
vllm/config/speech_to_text.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
from vllm.config.utils import config
|
||||
|
||||
|
||||
@config
class SpeechToTextConfig:
    """Settings that control audio pre-processing for speech-to-text models."""

    sample_rate: float = 16_000
    """Sample rate (Hz) to resample input audio to. Most speech models expect
    16kHz audio input. The input audio will be automatically resampled to this
    rate before processing."""

    max_audio_clip_s: int | None = 30
    """Maximum duration in seconds for a single audio clip without chunking.
    Audio longer than this will be split into smaller chunks if
    `allow_audio_chunking` evaluates to True, otherwise it will be rejected.
    `None` means audio duration can be unlimited and won't be chunked."""

    overlap_chunk_second: int = 1
    """Overlap duration in seconds between consecutive audio chunks when
    splitting long audio. This helps maintain context across chunk boundaries
    and improves transcription quality at split points."""

    min_energy_split_window_size: int | None = 1600
    """Window size in samples for finding low-energy (quiet) regions to split
    audio chunks. The algorithm looks for the quietest moment within this
    window to minimize cutting through speech. Default 1600 samples ≈ 100ms
    at 16kHz. If None, no chunking will be done."""

    @property
    def allow_audio_chunking(self) -> bool:
        """True only when both a split window size and a maximum clip
        duration are configured (neither is None)."""
        if self.min_energy_split_window_size is None:
            return False
        return self.max_audio_clip_s is not None
|
||||
76
vllm/config/structured_outputs.py
Normal file
76
vllm/config/structured_outputs.py
Normal file
@@ -0,0 +1,76 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
# Closed set of backend names accepted by `StructuredOutputsConfig.backend`.
StructuredOutputsBackend = Literal[
    "auto", "xgrammar", "guidance", "outlines", "lm-format-enforcer"
]
|
||||
|
||||
|
||||
@config
class StructuredOutputsConfig:
    """Engine-level configuration for structured outputs."""

    backend: StructuredOutputsBackend = "auto"
    """Which engine will be used for structured outputs (e.g. JSON schema,
    regex, etc) by default. With "auto", we will make opinionated choices
    based on request contents and what the backend libraries currently support,
    so the behavior is subject to change in each release."""
    disable_fallback: bool = False
    """If `True`, vLLM will not fallback to a different backend on error."""
    disable_any_whitespace: bool = False
    """If `True`, json output will always be compact without any whitespace.
    If `False`, the model may generate whitespace between JSON fields,
    which is still valid JSON. This is only supported for xgrammar
    and guidance backends."""
    disable_additional_properties: bool = False
    """If `True`, the `guidance` backend will not use `additionalProperties`
    in the JSON schema. This is only supported for the `guidance` backend and
    is used to better align its behaviour with `outlines` and `xgrammar`."""
    reasoning_parser: str = ""
    """Select the reasoning parser depending on the model that you're using.
    This is used to parse the reasoning content into OpenAI API format."""
    reasoning_parser_plugin: str = ""
    """Path to a dynamically reasoning parser plugin that can be dynamically
    loaded and registered."""
    enable_in_reasoning: bool = False
    """Whether to use structured input for reasoning."""

    def compute_hash(self) -> str:
        """
        Return the hash of the settings that shape the computation graph.

        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        The hash is meant to uniquely identify every setting that affects
        the structure of the computation graph from input ids/embeddings to
        the final hidden states, excluding anything before the former and
        after the latter.
        """
        # No field of this config affects the computation graph, so the
        # factor list is empty and the hash is constant.
        factors: list[Any] = []
        encoded_factors = str(factors).encode()
        digest = safe_hash(encoded_factors, usedforsecurity=False)
        return digest.hexdigest()

    @model_validator(mode="after")
    def _validate_structured_output_config(self) -> Self:
        """Reject option/backend combinations that are not supported.

        Raises:
            ValueError: if `disable_any_whitespace` is set with a backend
                other than "xgrammar" or "guidance", or if
                `disable_additional_properties` is set with a backend other
                than "guidance".
        """
        whitespace_capable_backends = ("xgrammar", "guidance")
        if (
            self.disable_any_whitespace
            and self.backend not in whitespace_capable_backends
        ):
            raise ValueError(
                "disable_any_whitespace is only supported for "
                "xgrammar and guidance backends."
            )

        if self.disable_additional_properties and self.backend != "guidance":
            raise ValueError(
                "disable_additional_properties is only supported "
                "for the guidance backend."
            )

        return self
|
||||
447
vllm/config/utils.py
Normal file
447
vllm/config/utils.py
Normal file
@@ -0,0 +1,447 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Utility functions for vLLM config dataclasses."""
|
||||
|
||||
import ast
|
||||
import enum
|
||||
import hashlib
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import pathlib
|
||||
import textwrap
|
||||
from collections.abc import Callable, Mapping, Sequence, Set
|
||||
from dataclasses import MISSING, field, fields, is_dataclass
|
||||
from itertools import pairwise
|
||||
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast
|
||||
|
||||
import torch
|
||||
from pydantic import ConfigDict
|
||||
from pydantic.dataclasses import dataclass
|
||||
from pydantic.fields import Field as PydanticField
|
||||
from pydantic.fields import FieldInfo
|
||||
from typing_extensions import dataclass_transform, runtime_checkable
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
|
||||
# Logger for this module, created through vLLM's `init_logger`.
logger = init_logger(__name__)

# `DataclassInstance` is imported from `_typeshed` only while type checking;
# at runtime the name is bound to `Any` so the aliases below can still be
# evaluated.
if TYPE_CHECKING:
    from _typeshed import DataclassInstance
else:
    DataclassInstance = Any

# A config class object (the class itself, not an instance of it).
ConfigType = type[DataclassInstance]
# Type variable bound to `DataclassInstance`, used for config instances.
ConfigT = TypeVar("ConfigT", bound=DataclassInstance)
|
||||
|
||||
|
||||
@dataclass_transform(field_specifiers=(PydanticField,))
def config(
    cls: type[ConfigT] | None = None,
    *,
    config: ConfigDict | None = None,
    **kwargs: Any,
) -> type[ConfigT] | Callable[[type[ConfigT]], type[ConfigT]]:
    """Turn a class into a pydantic dataclass that forbids extra fields.

    All config classes in vLLM should use this decorator. It works both
    bare (`@config`) and with arguments (`@config(config=...)`).

    Args:
        cls: The class to decorate. `None` when the decorator is used with
            arguments, in which case a decorator is returned instead.
        config: Optional pydantic ConfigDict. Its entries are merged on top
            of the default config (`extra="forbid"`).
        **kwargs: Additional arguments forwarded to `pydantic.dataclass`.

    Returns:
        The pydantic dataclass, or a decorator producing it.
    """
    # Start from the default (extra fields forbidden), then let the
    # caller's entries override it.
    merged_config = ConfigDict(extra="forbid")
    if config is not None:
        merged_config.update(config)

    def _as_pydantic_dataclass(target):
        return dataclass(target, config=merged_config, **kwargs)

    # `cls is None` means "@config(...)" was used; otherwise plain "@config".
    return _as_pydantic_dataclass if cls is None else _as_pydantic_dataclass(cls)
|
||||
|
||||
|
||||
def get_field(cls: ConfigType, name: str) -> Any:
    """Build a `dataclasses.field` mirroring the named field of `cls`.

    The default, default factory and `init` flag of the named field are
    copied. When the field's default is a pydantic `FieldInfo`, the values
    stored on it take precedence. Used for getting default factory fields
    in `EngineArgs`.

    Raises:
        TypeError: if `cls` is not a dataclass.
        ValueError: if `cls` has no field called `name`.
    """
    if not is_dataclass(cls):
        raise TypeError("The given class is not a dataclass.")

    try:
        source_field = next(f for f in fields(cls) if f.name == name)
    except StopIteration as e:
        raise ValueError(f"Field '{name}' not found in {cls.__name__}.") from e

    # Values to carry over to the new field.
    default_value = source_field.default
    factory = source_field.default_factory
    in_init = source_field.init

    # pydantic.Field keeps the real default / factory / init on FieldInfo.
    if isinstance(default_value, FieldInfo):
        field_info = default_value
        if field_info.init is not None:
            in_init = field_info.init
        if field_info.default_factory is None:
            default_value = field_info.default
        else:
            factory = cast(Callable[[], Any], field_info.default_factory)
            default_value = MISSING

    if default_value is MISSING and factory is MISSING:
        logger.warning_once(
            "%s.%s has no default or default factory.", cls.__name__, name
        )
    return field(default=default_value, default_factory=factory, init=in_init)
|
||||
|
||||
|
||||
def is_init_field(cls: ConfigType, name: str) -> bool:
    """Return the `init` flag of the field `get_field` builds for `name`."""
    resolved_field = get_field(cls, name)
    return resolved_field.init
|
||||
|
||||
|
||||
def replace(dataclass_instance: ConfigT, /, **kwargs) -> ConfigT:
    """Create a copy of a config instance with some fields replaced.

    Works like [`dataclasses.replace`](https://docs.python.org/3/library/dataclasses.html#dataclasses.replace),
    but is compatible with Pydantic dataclasses, which use
    `pydantic.fields.Field` instead of `dataclasses.field`.
    """
    cls = type(dataclass_instance)

    # Keep only the attributes that the constructor accepts.
    init_values = {}
    for attr_name, attr_value in dataclass_instance.__dict__.items():
        if is_init_field(cls, attr_name):
            init_values[attr_name] = attr_value

    init_values.update(kwargs)
    return cls(**init_values)
|
||||
|
||||
|
||||
def getattr_iter(
|
||||
object: object,
|
||||
names: Sequence[str],
|
||||
default: Any | None = None,
|
||||
default_factory: Callable[[], Any] | None = None,
|
||||
warn: bool = False,
|
||||
) -> Any:
|
||||
"""
|
||||
A helper function that retrieves an attribute from an object which may
|
||||
have multiple possible names. This is useful when fetching attributes from
|
||||
arbitrary `transformers.PretrainedConfig` instances.
|
||||
|
||||
In the case where the first name in `names` is the preferred name, and
|
||||
any other names are deprecated aliases, setting `warn=True` will log a
|
||||
warning when a deprecated name is used.
|
||||
"""
|
||||
for i, name in enumerate(names):
|
||||
if hasattr(object, name):
|
||||
if warn and i > 0:
|
||||
logger.warning_once(
|
||||
"%s contains a deprecated attribute name '%s'. "
|
||||
"Please use the preferred attribute name '%s' instead.",
|
||||
type(object).__name__,
|
||||
name,
|
||||
names[0],
|
||||
)
|
||||
return getattr(object, name)
|
||||
return default_factory() if default_factory is not None else default
|
||||
|
||||
|
||||
def get_attr_docs(cls: type[Any]) -> dict[str, str]:
|
||||
"""
|
||||
Get any docstrings placed after attribute assignments in a class body.
|
||||
|
||||
https://davidism.com/mit-license/
|
||||
"""
|
||||
|
||||
cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0]
|
||||
|
||||
if not isinstance(cls_node, ast.ClassDef):
|
||||
raise TypeError("Given object was not a class.")
|
||||
|
||||
out = {}
|
||||
|
||||
# Consider each pair of nodes.
|
||||
for a, b in pairwise(cls_node.body):
|
||||
# Must be an assignment then a constant string.
|
||||
if (
|
||||
not isinstance(a, (ast.Assign, ast.AnnAssign))
|
||||
or not isinstance(b, ast.Expr)
|
||||
or not isinstance(b.value, ast.Constant)
|
||||
or not isinstance(b.value.value, str)
|
||||
):
|
||||
continue
|
||||
|
||||
doc = inspect.cleandoc(b.value.value)
|
||||
|
||||
# An assignment can have multiple targets (a = b = v), but an
|
||||
# annotated assignment only has one target.
|
||||
targets = a.targets if isinstance(a, ast.Assign) else [a.target]
|
||||
|
||||
for target in targets:
|
||||
# Must be assigning to a plain name.
|
||||
if not isinstance(target, ast.Name):
|
||||
continue
|
||||
|
||||
out[target.id] = doc
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@runtime_checkable
class SupportsHash(Protocol):
    """Protocol for objects that expose a `compute_hash()` method."""

    def compute_hash(self) -> str:
        """Return the object's hash as a string."""
|
||||
|
||||
|
||||
class SupportsMetricsInfo(Protocol):
    """Protocol for objects that expose a `metrics_info()` method."""

    def metrics_info(self) -> dict[str, str]:
        """Return a mapping of string keys to string values."""
|
||||
|
||||
|
||||
def update_config(config: ConfigT, overrides: dict[str, Any]) -> ConfigT:
    """Return a copy of `config` with `overrides` applied.

    Each key of `overrides` must be an existing attribute of `config`. When
    the current value of a field is a dataclass and the override is not,
    the override must be a dict and is applied recursively to that nested
    config. The final copy is produced with `replace`.

    Raises:
        AssertionError: if a key is not an attribute of `config`, or a
            nested config is overridden with something that is neither a
            dict nor a dataclass.
    """
    processed_overrides = {}
    for field_name, value in overrides.items():
        assert hasattr(config, field_name), (
            f"{type(config)} has no field `{field_name}`"
        )
        current_value = getattr(config, field_name)

        # A plain value overriding a nested config is treated as a set of
        # overrides for that nested config.
        overrides_nested_config = is_dataclass(current_value) and not is_dataclass(
            value
        )
        if overrides_nested_config:
            assert isinstance(value, dict), (
                f"Overrides to {type(config)}.{field_name} must be a dict"
                f" or {type(current_value)}, but got {type(value)}"
            )
            value = update_config(
                current_value,  # type: ignore[type-var]
                value,
            )

        processed_overrides[field_name] = value

    return replace(config, **processed_overrides)
|
||||
|
||||
|
||||
def normalize_value(x):
    """Convert `x` into a deterministic canonical form suitable for hashing.

    The checks below run in a fixed order and the first match wins:

    1. `None`, bool, int, float and str are returned unchanged.
    2. Enum members become `(fully-qualified enum type, normalized value)`.
    3. Classes become their "module.qualname" string.
    4. Objects with a callable `uuid` attribute become `x.uuid()`.
    5. Any other callable is rejected with `TypeError`.
    6. `torch.dtype` becomes `str(x)`.
    7. bytes / bytearray become a hex string.
    8. `pathlib.Path` becomes the resolved path string (or `str(x)` if
       resolving fails).
    9. Dataclass instances become `(fqn, ((field, value), ...))` with the
       fields sorted by name.
    10. Mapping / Set / Sequence become tuples, recursing into elements.
    11. Objects with a callable `to_json_string` become its return value.

    Raises:
        TypeError: For callables without `uuid()` and for any value that
            matches none of the cases above.
    """
    # Fast path: primitives are already canonical.
    if x is None or isinstance(x, (bool, int, float, str)):
        return x

    # Enums are tagged with their fully-qualified type name so that
    # Enum(1) and int(1) cannot collide.
    if isinstance(x, enum.Enum):
        enum_cls = x.__class__
        enum_type = f"{enum_cls.__module__}.{enum_cls.__qualname__}"
        return (enum_type, normalize_value(x.value))

    # Classes (types) are identified by module.qualname. This is the only
    # kind of callable accepted without a uuid(); it is used by
    # LogitsProcessor types and ModelConfig.hf_overrides.
    if isinstance(x, type):
        module = getattr(x, "__module__", "")
        qual = getattr(x, "__qualname__", getattr(x, "__name__", ""))
        dotted = ".".join(part for part in (module, qual) if part)
        return dotted or repr(x)

    # Objects exposing uuid() use it as their stable identifier, even when
    # they are callable instances (e.g., InductorPass wrappers).
    uuid_attr = getattr(x, "uuid", None)
    if callable(uuid_attr):
        return uuid_attr()

    # Remaining callables (functions, lambdas, methods, callable instances)
    # are rejected to avoid under-hashing object state.
    if callable(x):
        raise TypeError("normalize_value: function or callable instance unsupported")

    # Torch dtype: rely on the string form (torch.float64 -> "torch.float64").
    # Fields needing further disambiguation should encode it at the config
    # layer.
    if isinstance(x, torch.dtype):
        return str(x)

    # Raw byte buffers are rendered as hex.
    if isinstance(x, (bytes, bytearray)):
        return x.hex()

    # Paths are canonicalized; fall back to the plain string on any failure.
    if isinstance(x, pathlib.Path):
        try:
            canonical = x.expanduser().resolve()
        except Exception:
            return str(x)
        return str(canonical)

    # Dataclass instances: (FQN, name-sorted (field, value) pairs).
    if is_dataclass(x):
        dc_cls = x.__class__
        type_fqn = f"{dc_cls.__module__}.{dc_cls.__qualname__}"
        field_names = sorted(f.name for f in fields(x))
        items = tuple(
            (name, normalize_value(getattr(x, name))) for name in field_names
        )
        return (type_fqn, items)

    # Generic containers.
    if isinstance(x, Mapping):
        pairs = [(str(key), normalize_value(val)) for key, val in x.items()]
        pairs.sort()
        return tuple(pairs)
    if isinstance(x, Set):
        rendered = [repr(normalize_value(member)) for member in x]
        rendered.sort()
        return tuple(rendered)
    if isinstance(x, Sequence) and not isinstance(x, (str, bytes, bytearray)):
        return tuple(normalize_value(element) for element in x)

    # PretrainedConfig-like objects.
    to_json = getattr(x, "to_json_string", None)
    if callable(to_json):
        return to_json()

    # Unsupported type: e.g., modules, generators, open files, or objects
    # without a stable JSON/UUID representation. Hard-error to avoid
    # under-hashing. To fix, reshape the config to use supported primitives
    # and containers, or extend normalize_value with a stable encoding
    # (e.g., via uuid() or to_json_string()) for the type.
    raise TypeError(
        f"normalize_value: unsupported type '{type(x).__name__}'. "
        "Ensure config values use supported primitives/containers or add a "
        "stable representation for this type."
    )
|
||||
|
||||
|
||||
def get_hash_factors(config: ConfigT, ignored_factors: set[str]) -> dict[str, object]:
    """Collect the normalized field values used to hash a config object.

    Every dataclass field of `config` whose name is not in `ignored_factors`
    is read with `getattr` (defaulting to `None`) and passed through
    `normalize_value`.

    Args:
        config: Dataclass instance to inspect.
        ignored_factors: Field names to leave out of the result.

    Returns:
        Mapping of field name to its normalized value.

    Raises:
        TypeError: If a field value cannot be normalized; the message names
            the offending field and its type.
    """
    factors: dict[str, object] = {}
    considered = (
        dc_field.name
        for dc_field in fields(config)
        if dc_field.name not in ignored_factors
    )
    for factor in considered:
        value = getattr(config, factor, None)
        try:
            normalized = normalize_value(value)
        except TypeError as e:
            # Re-raise with the field name so the failure is easy to locate.
            raise TypeError(
                f"get_hash_factors: unsupported type for key '{factor}' "
                f"({type(value).__name__})"
            ) from e
        factors[factor] = normalized
    return factors
|
||||
|
||||
|
||||
def hash_factors(items: dict[str, object]) -> str:
    """Hash `items` by SHA-256 over its key-sorted JSON encoding.

    Args:
        items: Mapping to hash; it must be serializable by `json.dumps`.

    Returns:
        The hexadecimal SHA-256 digest.
    """
    # sort_keys makes the encoding independent of dict insertion order.
    canonical_json = json.dumps(items, sort_keys=True)
    digest = hashlib.sha256(canonical_json.encode())
    return digest.hexdigest()
|
||||
|
||||
|
||||
@dataclass
class Range:
    """
    A closed interval of numbers: both `start` and `end` belong to it.
    """

    start: int
    end: int

    def is_single_size(self) -> bool:
        """Return True when the range contains exactly one value."""
        return self.start == self.end

    def __contains__(self, size: int) -> bool:
        """Return True when `size` lies within the range (bounds included)."""
        return self.start <= size and size <= self.end

    def __eq__(self, other: object) -> bool:
        """Ranges are equal when both bounds match; other types never are."""
        if isinstance(other, Range):
            return self.start == other.start and self.end == other.end
        return False

    def __hash__(self) -> int:
        """Hash on the `(start, end)` pair."""
        bounds = (self.start, self.end)
        return hash(bounds)

    def __str__(self) -> str:
        """Render as "(start, end)"."""
        return f"({self.start}, {self.end})"

    def __repr__(self) -> str:
        """Same text as `__str__`."""
        return self.__str__()
|
||||
|
||||
|
||||
def handle_deprecated(
    config: ConfigT,
    old_name: str,
    new_name_or_names: str | list[str],
    removal_version: str,
) -> None:
    """Forward the value of a deprecated attribute to its replacement(s).

    Does nothing when `config.<old_name>` is `None`. Otherwise a deprecation
    warning is logged and the old value is assigned to every replacement
    attribute.

    Args:
        config: Object that holds both the old and the new attributes.
        old_name: Name of the deprecated attribute.
        new_name_or_names: A single replacement name or a list of them.
        removal_version: Version quoted in the warning message.
    """
    deprecated_value = getattr(config, old_name)
    if deprecated_value is None:
        return

    # Accept a lone name as shorthand for a one-element list.
    new_names = (
        [new_name_or_names]
        if isinstance(new_name_or_names, str)
        else new_name_or_names
    )

    msg = (
        f"{old_name} is deprecated and will be removed in {removal_version}. "
        f"Use {', '.join(new_names)} instead."
    )
    logger.warning(msg)

    for replacement in new_names:
        setattr(config, replacement, deprecated_value)
|
||||
|
||||
|
||||
def get_from_deprecated_env_if_set(
    env_name: str,
    removal_version: str,
    field_name: str | None = None,
) -> str | None:
    """
    Read a deprecated environment variable, warning when it is in use.

    Args:
        env_name: Name of the deprecated environment variable
        removal_version: Version quoted in the warning as the removal point
        field_name: Optional name of the field suggested as the replacement

    Returns:
        The value from `os.environ` when `envs.is_set(env_name)` is true,
        otherwise None
    """
    if not envs.is_set(env_name):
        return None

    value = os.environ.get(env_name)
    # Only mention an alternative when the caller supplied one.
    alt_msg = f" Please use {field_name} instead." if field_name else ""
    logger.warning_once(
        "Using %s environment variable is deprecated and will be removed in %s.%s",
        env_name,
        removal_version,
        alt_msg,
    )
    return value
|
||||
|
||||
|
||||
def set_from_deprecated_env_if_set(
    config: ConfigT,
    env_name: str,
    removal_version: str,
    field_name: str,
    to_bool: bool = False,
    to_int: bool = False,
) -> None:
    """
    Copy a deprecated environment variable onto a config field, with warning.

    The value is fetched through `get_from_deprecated_env_if_set`; when that
    returns None the config is left unchanged.

    Args:
        config: Config object to set the field on
        env_name: Name of the deprecated environment variable
        removal_version: Version when the env var will be removed
        field_name: Name of the field to set
        to_bool: Store True when the lower-cased value is "1" or "true",
            otherwise False
        to_int: Store `int(value)` instead of the raw string

    Returns:
        None

    Raises:
        ValueError: If both `to_bool` and `to_int` are requested.
    """
    if to_bool and to_int:
        raise ValueError("Cannot convert to both boolean and integer.")

    env_value = get_from_deprecated_env_if_set(env_name, removal_version, field_name)
    if env_value is None:
        return

    field_value: str | bool | int
    if to_bool:
        field_value = env_value.lower() in ("1", "true")
    elif to_int:
        field_value = int(env_value)
    else:
        field_value = env_value
    setattr(config, field_name, field_value)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user