Upgrade to vllm 0.17.0 corex v4.1 overlay
Dockerfile (11 changed lines)
@@ -1,11 +1,12 @@
FROM registry.iluvatar.com.cn:10443/customer/sz/vllm0.11.2-4.4.0-x86:v8
ARG BASE_IMAGE=registry.iluvatar.com.cn:10443/customer/sz/vllm0.17.0-4.4.0-x86:v4.1
FROM ${BASE_IMAGE}

# Keep the runtime stack from the known-good v8 image, but replace the
# installed Python package with the repository's patched 0.16.1rc0 sources.
WORKDIR /tmp
WORKDIR /home

RUN rm -rf /usr/local/lib/python3.12/dist-packages/vllm \
    /usr/local/lib/python3.12/dist-packages/vllm-*.dist-info

COPY vllm /usr/local/lib/python3.12/dist-packages/vllm
COPY vllm-0.16.1rc0+corex.4.4.0.dist-info /usr/local/lib/python3.12/dist-packages/vllm-0.16.1rc0+corex.4.4.0.dist-info
COPY vllm-0.17.0+corex.20260420090923.dist-info /usr/local/lib/python3.12/dist-packages/vllm-0.17.0+corex.20260420090923.dist-info

ENTRYPOINT ["/bin/bash"]

README.md (69 changed lines)
@@ -1,62 +1,37 @@
# bi_150-vllm

A `vLLM 0.16.1rc0` build repository based on
`registry.iluvatar.com.cn:10443/customer/sz/vllm0.11.2-4.4.0-x86:v8`, used to produce a directly runnable image for the BI-V150 virtual-machine environment.
This repository contains the extracted `vLLM 0.17.0+corex.20260420090923`
Python package used to overlay the vendor image
`registry.iluvatar.com.cn:10443/customer/sz/vllm0.17.0-4.4.0-x86:v4.1`.

## Change notes

This repository keeps only the minimal content needed to build the image:
## Included files

- `vllm/`
  The current runtime code
- `vllm-0.16.1rc0+corex.4.4.0.dist-info/`
  The corresponding package metadata
  The Python package code copied from the image payload.
- `vllm-0.17.0+corex.20260420090923.dist-info/`
  The package metadata extracted from the image.
- `Dockerfile`
  Builds the final image
  Builds a new image by replacing the installed `vllm` package in the vendor base image.

Compared with the base image, the key code changes kept in this repository are:
## Build

- Fix the CUDA platform detection logic in `vllm/platforms/__init__.py`
  - When NVML is unavailable and errors such as `NVML Shared Library Not Found` occur,
    the platform is no longer immediately classified as non-CUDA
  - Instead, fall back to `torch.cuda.is_available()` and
    `torch.cuda.device_count()` to decide whether CUDA is usable (sketched below)
- Adjust the CLI initialization logic so that missing optional benchmark dependencies
  do not block `vllm serve ...` from starting

This fix addresses the following startup failure:

```text
RuntimeError: Failed to infer device type
```
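
A minimal sketch of that fallback, for illustration only (the helper name and structure below are hypothetical; the actual patch lives in `vllm/platforms/__init__.py`):

```python
# Illustrative sketch only: shows the intent of the fix, not the exact vLLM code.
import torch


def cuda_seems_available() -> bool:
    """Return True if CUDA should be treated as the platform, even when NVML is missing."""
    try:
        import pynvml  # NVML bindings; unavailable on some vendor runtimes

        pynvml.nvmlInit()
        pynvml.nvmlShutdown()
        return True
    except Exception:
        # NVML failed (e.g. "NVML Shared Library Not Found"). Instead of
        # declaring the platform non-CUDA, fall back to torch's own probes.
        return torch.cuda.is_available() and torch.cuda.device_count() > 0
```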

## Build the image

Run from the repository root:
Run the following command from the repository root:

```bash
docker build -t bi_150_vllm:0.16.1 .
docker build --pull=false \
  --build-arg BASE_IMAGE=registry.iluvatar.com.cn:10443/customer/sz/vllm0.17.0-4.4.0-x86:v4.1 \
  -t bi_150_vllm:0.17.0 \
  .
```

## Launch the image
## Verify

```bash
docker run -dit \
  --name iluvatar_test \
  -p 38047:8000 \
  --privileged \
  -v /lib/modules:/lib/modules \
  -v /dev:/dev \
  -v /usr/src:/usr/src \
  -v /mnt/gpfs/leaderboard/modelHubXC/Amu/t1-1.5B:/model \
  -e CUDA_VISIBLE_DEVICES=0 \
  --entrypoint vllm \
  bi_150_vllm:0.16.1 \
  serve /model \
  --port 8000 \
  --served-model-name llm \
  --max-model-len 2048 \
  --enforce-eager \
  --trust-remote-code \
  -tp 1
docker run --rm -it bi_150_vllm:0.17.0 \
  python3 -c "import vllm; print(vllm.__file__); print(vllm.__version__)"
```
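
For a fuller smoke test, once a container is serving (for example with the `vllm serve` command shown above, retargeted at `bi_150_vllm:0.17.0`), a quick request against the mapped port confirms the overlaid package responds. The port and served model name below follow the removed 0.16.1 instructions and are illustrative only:

```bash
# Assumes the container maps port 8000 to host port 38047 and serves the model as "llm".
curl -s http://localhost:38047/v1/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "llm", "prompt": "Hello", "max_tokens": 16}'
```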

## Notes

- This is an overlay-style repository, not the original upstream git source tree.
- The Docker image keeps the vendor runtime stack and only replaces the Python package files.

@@ -1 +0,0 @@
{"archive_info": {"hash": "sha256=f19c1c4880bc4199fc95de8d77590fcdb4912b1fad9bf939883005812f2338c1", "hashes": {"sha256": "f19c1c4880bc4199fc95de8d77590fcdb4912b1fad9bf939883005812f2338c1"}}, "url": "file:///workspace/vllm-0.16.1rc0%2Bcorex.4.4.0-py3-none-any.whl"}
@@ -1,9 +1,9 @@
Metadata-Version: 2.4
Name: vllm
Version: 0.16.1rc0+corex.4.4.0
Version: 0.17.0+corex.20260420090923
Summary: A high-throughput and memory-efficient inference and serving engine for LLMs
Author: vLLM Team
License-Expression: Apache-2.0
License: Apache-2.0
Project-URL: Homepage, https://github.com/vllm-project/vllm
Project-URL: Documentation, https://docs.vllm.ai/en/latest/
Project-URL: Slack, https://slack.vllm.ai/
@@ -23,7 +23,7 @@ Requires-Dist: regex
Requires-Dist: cachetools
Requires-Dist: psutil
Requires-Dist: sentencepiece
Requires-Dist: numpy==1.26.4
Requires-Dist: numpy
Requires-Dist: requests>=2.26.0
Requires-Dist: tqdm
Requires-Dist: blake3
@@ -33,7 +33,7 @@ Requires-Dist: tokenizers>=0.21.1
Requires-Dist: protobuf!=6.30.*,!=6.31.*,!=6.32.*,!=6.33.0.*,!=6.33.1.*,!=6.33.2.*,!=6.33.3.*,!=6.33.4.*,>=5.29.6
Requires-Dist: fastapi[standard]>=0.115.0
Requires-Dist: aiohttp>=3.13.3
Requires-Dist: openai>=1.99.1
Requires-Dist: openai<2.25.0,>=1.99.1
Requires-Dist: pydantic>=2.12.0
Requires-Dist: prometheus_client>=0.18.0
Requires-Dist: pillow
@@ -52,6 +52,7 @@ Requires-Dist: pyzmq>=25.0.0
Requires-Dist: msgspec
Requires-Dist: gguf>=0.17.0
Requires-Dist: mistral_common[image]>=1.9.1
Requires-Dist: opencv-python-headless>=4.13.0
Requires-Dist: pyyaml
Requires-Dist: six>=1.16.0; python_version > "3.11"
Requires-Dist: setuptools<81.0.0,>=77.0.3; python_version > "3.11"
@@ -76,6 +77,7 @@ Requires-Dist: opentelemetry-sdk>=1.27.0
Requires-Dist: opentelemetry-api>=1.27.0
Requires-Dist: opentelemetry-exporter-otlp>=1.27.0
Requires-Dist: opentelemetry-semantic-conventions-ai>=0.4.1
Requires-Dist: kaldi-native-fbank>=1.18.7
Requires-Dist: numba==0.61.2
Requires-Dist: ray[cgraph]>=2.48.0
Provides-Extra: bench
@@ -84,6 +86,7 @@ Requires-Dist: matplotlib; extra == "bench"
Requires-Dist: seaborn; extra == "bench"
Requires-Dist: datasets; extra == "bench"
Requires-Dist: scipy; extra == "bench"
Requires-Dist: plotly; extra == "bench"
Provides-Extra: tensorizer
Requires-Dist: tensorizer==2.10.1; extra == "tensorizer"
Provides-Extra: fastsafetensors
@@ -97,6 +100,10 @@ Requires-Dist: soundfile; extra == "audio"
Requires-Dist: mistral_common[audio]; extra == "audio"
Provides-Extra: video
Provides-Extra: flashinfer
Provides-Extra: petit-kernel
Requires-Dist: petit-kernel; extra == "petit-kernel"
Provides-Extra: helion
Requires-Dist: helion; extra == "helion"
Provides-Extra: otel
Requires-Dist: opentelemetry-sdk>=1.26.0; extra == "otel"
Requires-Dist: opentelemetry-api>=1.26.0; extra == "otel"
File diff suppressed because it is too large
@@ -1,5 +1,5 @@
Wheel-Version: 1.0
Generator: setuptools (82.0.0)
Generator: setuptools (80.10.2)
Root-Is-Purelib: true
Tag: py3-none-any

@@ -0,0 +1 @@
{"archive_info": {"hash": "sha256=844cb01bfec51cf2ec37322ff74c77b31cced2cd9253312cff724ebd06b7f740", "hashes": {"sha256": "844cb01bfec51cf2ec37322ff74c77b31cced2cd9253312cff724ebd06b7f740"}}, "url": "file:///home/poweruser/zrl/code/vllm_hub/vllm/build_pip/vllm-0.17.0%2Bcorex.20260420090923-py3-none-any.whl"}
vllm/.gitignore, vendored (244 lines removed)
@@ -1,244 +0,0 @@
|
||||
# version file generated by setuptools-scm
|
||||
/vllm/_version.py
|
||||
|
||||
# vllm-flash-attn built from source
|
||||
vllm/vllm_flash_attn/*
|
||||
!vllm/vllm_flash_attn/__init__.py
|
||||
!vllm/vllm_flash_attn/flash_attn_interface.py
|
||||
|
||||
# OpenAI triton kernels copied from source
|
||||
vllm/third_party/triton_kernels/*
|
||||
|
||||
# FlashMLA interface copied from source
|
||||
vllm/third_party/flashmla/flash_mla_interface.py
|
||||
|
||||
# triton jit
|
||||
.triton
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
cmake-build-*/
|
||||
CMakeUserPresets.json
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
/.deps/
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# generated files
|
||||
**/generated/**
|
||||
|
||||
# uv
|
||||
uv.lock
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/#use-with-ide
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
docs/argparse
|
||||
docs/examples/*
|
||||
!docs/examples/README.md
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
.idea/
|
||||
|
||||
# VSCode
|
||||
.vscode/
|
||||
|
||||
# Claude
|
||||
.claude/
|
||||
|
||||
# Codex
|
||||
.codex/
|
||||
|
||||
# Cursor
|
||||
.cursor/
|
||||
|
||||
# DS Store
|
||||
.DS_Store
|
||||
|
||||
# Results
|
||||
*.csv
|
||||
|
||||
# Python pickle files
|
||||
*.pkl
|
||||
|
||||
# Sphinx documentation
|
||||
_build/
|
||||
|
||||
# vim swap files
|
||||
*.swo
|
||||
*.swp
|
||||
|
||||
# hip files generated by PyTorch
|
||||
*.hip
|
||||
*_hip*
|
||||
hip_compat.h
|
||||
|
||||
# Benchmark dataset
|
||||
benchmarks/**/*.json
|
||||
|
||||
# Linting
|
||||
actionlint
|
||||
shellcheck*/
|
||||
|
||||
# Ignore moe/marlin_moe gen code
|
||||
csrc/moe/marlin_moe_wna16/kernel_*
|
||||
|
||||
# Ignore ep_kernels_workspace folder
|
||||
ep_kernels_workspace/
|
||||
|
||||
# Allow tracked library source folders under submodules (e.g., benchmarks/lib)
|
||||
!vllm/benchmarks/lib/
|
||||
|
||||
# Generated gRPC protobuf files (compiled at build time from vllm_engine.proto)
|
||||
vllm/grpc/vllm_engine_pb2.py
|
||||
vllm/grpc/vllm_engine_pb2_grpc.py
|
||||
vllm/grpc/vllm_engine_pb2.pyi
|
||||
|
||||
# Ignore generated cpu headers
|
||||
csrc/cpu/cpu_attn_dispatch_generated.h
|
||||
|
||||
@@ -14,8 +14,6 @@ import typing
|
||||
import vllm.env_override # noqa: F401
|
||||
|
||||
MODULE_ATTRS = {
|
||||
"bc_linter_skip": "._bc_linter:bc_linter_skip",
|
||||
"bc_linter_include": "._bc_linter:bc_linter_include",
|
||||
"AsyncEngineArgs": ".engine.arg_utils:AsyncEngineArgs",
|
||||
"EngineArgs": ".engine.arg_utils:EngineArgs",
|
||||
"AsyncLLMEngine": ".engine.async_llm_engine:AsyncLLMEngine",
|
||||
@@ -62,8 +60,6 @@ if typing.TYPE_CHECKING:
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.executor.ray_utils import initialize_ray_cluster
|
||||
|
||||
from ._bc_linter import bc_linter_include, bc_linter_skip
|
||||
else:
|
||||
|
||||
def __getattr__(name: str) -> typing.Any:
|
||||
@@ -79,8 +75,6 @@ else:
|
||||
|
||||
__all__ = [
|
||||
"__version__",
|
||||
"bc_linter_skip",
|
||||
"bc_linter_include",
|
||||
"__version_tuple__",
|
||||
"LLM",
|
||||
"ModelRegistry",
|
||||
|
||||
@@ -87,6 +87,10 @@ def _rocm_aiter_fused_moe_impl(
|
||||
a2_scale: torch.Tensor | None = None,
|
||||
num_local_tokens: torch.Tensor | None = None,
|
||||
output_dtype: torch.dtype | None = None,
|
||||
hidden_pad: int = 0,
|
||||
intermediate_pad: int = 0,
|
||||
bias1: torch.Tensor | None = None,
|
||||
bias2: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
from aiter import ActivationType, QuantType
|
||||
from aiter.fused_moe import fused_moe
|
||||
@@ -110,6 +114,10 @@ def _rocm_aiter_fused_moe_impl(
|
||||
a2_scale,
|
||||
num_local_tokens=num_local_tokens,
|
||||
dtype=output_dtype,
|
||||
hidden_pad=hidden_pad,
|
||||
intermediate_pad=intermediate_pad,
|
||||
bias1=bias1,
|
||||
bias2=bias2,
|
||||
)
|
||||
|
||||
|
||||
@@ -307,6 +315,28 @@ def _rocm_aiter_grouped_topk_fake(
|
||||
pass
|
||||
|
||||
|
||||
def _rocm_aiter_fused_topk_impl(
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
gate_up: bool,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
from aiter.fused_moe import fused_topk
|
||||
|
||||
# fused_topk returns (topk_weights, topk_indices)
|
||||
return fused_topk(x, router_logits, top_k, gate_up)
|
||||
|
||||
|
||||
def _rocm_aiter_fused_topk_fake(
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
gate_up: bool,
|
||||
) -> None:
|
||||
# tuple[torch.Tensor, torch.Tensor]:
|
||||
pass
|
||||
|
||||
|
||||
# Cache whether aiter supports FP8 MLA parameters
|
||||
_AITER_MLA_SUPPORTS_FP8: bool | None = None
|
||||
|
||||
@@ -994,6 +1024,70 @@ class rocm_aiter_ops:
|
||||
cls._MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS
|
||||
cls._TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM
|
||||
|
||||
@staticmethod
|
||||
def get_aiter_activation_type(activation_str: str):
|
||||
"""
|
||||
Given an activation type as a string, returns the corresponding aiter ActivationType enum.
|
||||
Supported activation types: "no", "none", "silu", "gelu", "swiglu".
|
||||
Returns None if the mapping fails.
|
||||
|
||||
Args:
|
||||
activation_str (str): Activation type as string.
|
||||
|
||||
Returns:
|
||||
Aiter ActivationType enum value, or None if not found.
|
||||
"""
|
||||
# Import only locally, since aiter may not always be available.
|
||||
try:
|
||||
from aiter import ActivationType
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
if not isinstance(activation_str, str):
|
||||
return None
|
||||
|
||||
name = activation_str.strip().lower()
|
||||
mapping = {
|
||||
"none": ActivationType.No,
|
||||
"no": ActivationType.No,
|
||||
"silu": ActivationType.Silu,
|
||||
"gelu": ActivationType.Gelu,
|
||||
"swiglu": ActivationType.Swiglu,
|
||||
}
|
||||
return mapping.get(name)
|
||||
|
||||
@staticmethod
|
||||
def get_aiter_quant_type(quant_type_str: str):
|
||||
"""
|
||||
Given a quantization type as a string, returns the corresponding aiter QuantType enum.
|
||||
Supported quantization types: "no", "per_tensor", "per_token", "per_1x32", "per_1x128", "per_128x128".
|
||||
Returns None if the mapping fails.
|
||||
|
||||
Args:
|
||||
quant_type_str (str): Quantization type as string.
|
||||
|
||||
Returns:
|
||||
Aiter QuantType enum value, or None if not found.
|
||||
"""
|
||||
try:
|
||||
from aiter import QuantType
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
if not isinstance(quant_type_str, str):
|
||||
return None
|
||||
|
||||
name = quant_type_str.strip().lower()
|
||||
mapping = {
|
||||
"no": QuantType.No,
|
||||
"per_tensor": QuantType.per_Tensor,
|
||||
"per_token": QuantType.per_Token,
|
||||
"per_1x32": QuantType.per_1x32,
|
||||
"per_1x128": QuantType.per_1x128,
|
||||
"per_128x128": QuantType.per_128x128,
|
||||
}
|
||||
return mapping.get(name)
|
||||
|
||||
@classmethod
|
||||
@if_aiter_supported
|
||||
def is_enabled(cls) -> bool:
|
||||
@@ -1127,6 +1221,14 @@ class rocm_aiter_ops:
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_fused_topk",
|
||||
op_func=_rocm_aiter_fused_topk_impl,
|
||||
mutates_args=[],
|
||||
fake_impl=_rocm_aiter_fused_topk_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_mla_decode_fwd",
|
||||
op_func=_rocm_aiter_mla_decode_fwd_impl,
|
||||
@@ -1360,6 +1462,10 @@ class rocm_aiter_ops:
|
||||
a2_scale: torch.Tensor | None = None,
|
||||
num_local_tokens: torch.Tensor | None = None,
|
||||
output_dtype: torch.dtype | None = None,
|
||||
hidden_pad: int = 0,
|
||||
intermediate_pad: int = 0,
|
||||
bias1: torch.Tensor | None = None,
|
||||
bias2: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return torch.ops.vllm.rocm_aiter_fused_moe(
|
||||
hidden_states,
|
||||
@@ -1377,6 +1483,10 @@ class rocm_aiter_ops:
|
||||
a2_scale,
|
||||
num_local_tokens,
|
||||
output_dtype,
|
||||
hidden_pad,
|
||||
intermediate_pad,
|
||||
bias1,
|
||||
bias2,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -1481,6 +1591,15 @@ class rocm_aiter_ops:
|
||||
routed_scaling_factor,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def fused_topk(
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
gate_up: bool,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return torch.ops.vllm.rocm_aiter_fused_topk(x, router_logits, top_k, gate_up)
|
||||
|
||||
@staticmethod
|
||||
def mla_decode_fwd(
|
||||
q: torch.Tensor,
|
||||
@@ -1701,6 +1820,47 @@ class rocm_aiter_ops:
|
||||
|
||||
return shuffle_weight(tensor, layout=layout)
|
||||
|
||||
@staticmethod
|
||||
def shuffle_weight_a16w4(
|
||||
tensor: "torch.Tensor",
|
||||
nLane: int,
|
||||
gate_up: bool,
|
||||
) -> "torch.Tensor":
|
||||
"""
|
||||
Shuffles the weight tensor into (A16W4) layout for AITER kernels.
|
||||
|
||||
Args:
|
||||
tensor: The input weight tensor to be shuffled.
|
||||
nLane: Lane count used by the A16W4 shuffled layout.
gate_up: Whether the weight is for w13 (True) or w2 (False).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The shuffled tensor.
|
||||
"""
|
||||
from aiter.ops.shuffle import shuffle_weight_a16w4
|
||||
|
||||
return shuffle_weight_a16w4(tensor, nLane, gate_up)
|
||||
|
||||
@staticmethod
|
||||
def shuffle_scale_a16w4(
|
||||
tensor: "torch.Tensor",
|
||||
num_experts: int,
|
||||
gate_up: bool,
|
||||
) -> "torch.Tensor":
|
||||
"""
|
||||
Shuffles the scale tensor into (A16W4) layout for AITER kernels.
|
||||
|
||||
Args:
|
||||
tensor: The input scale tensor to be shuffled.
|
||||
num_experts: Number of experts, needed for reshaping logic.
|
||||
gate_up: Whether the scale is for w13 (True) or w2 (False).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The shuffled scale tensor.
|
||||
"""
|
||||
from aiter.ops.shuffle import shuffle_scale_a16w4
|
||||
|
||||
return shuffle_scale_a16w4(tensor, num_experts, gate_up)
|
||||
|
||||
@staticmethod
|
||||
def shuffle_weights(
|
||||
*tensors: torch.Tensor, layout: tuple[int, int] = (16, 16)
|
||||
|
||||
@@ -1,54 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# vllm/_bc_linter.py
|
||||
from collections.abc import Callable
|
||||
from typing import Any, TypeVar, overload
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@overload
|
||||
def bc_linter_skip(obj: T) -> T: ...
|
||||
|
||||
|
||||
@overload
|
||||
def bc_linter_skip(*, reason: str | None = ...) -> Callable[[T], T]: ...
|
||||
|
||||
|
||||
def bc_linter_skip(obj: Any = None, *, reason: str | None = None):
|
||||
"""
|
||||
No-op decorator to mark symbols/files for BC-linter suppression.
|
||||
|
||||
Usage:
|
||||
@bc_linter_skip
|
||||
def legacy_api(...): ...
|
||||
"""
|
||||
|
||||
def _wrap(x: T) -> T:
|
||||
return x
|
||||
|
||||
return _wrap if obj is None else obj
|
||||
|
||||
|
||||
@overload
|
||||
def bc_linter_include(obj: T) -> T: ...
|
||||
|
||||
|
||||
@overload
|
||||
def bc_linter_include(*, reason: str | None = ...) -> Callable[[T], T]: ...
|
||||
|
||||
|
||||
def bc_linter_include(obj: Any = None, *, reason: str | None = None):
|
||||
"""
|
||||
Usage:
|
||||
@bc_linter_include
|
||||
def public_api(...): ...
|
||||
"""
|
||||
|
||||
def _wrap(x: T) -> T:
|
||||
return x
|
||||
|
||||
return _wrap if obj is None else obj
|
||||
|
||||
|
||||
__all__ = ["bc_linter_skip", "bc_linter_include"]
|
||||
vllm/_custom_ops.py (1683 changed lines; diff too large to show)
@@ -31,6 +31,7 @@ from tempfile import NamedTemporaryFile
|
||||
from typing import Any, cast
|
||||
|
||||
import numpy as np
|
||||
from huggingface_hub import snapshot_download
|
||||
from PIL import Image
|
||||
from typing_extensions import deprecated
|
||||
|
||||
@@ -60,6 +61,8 @@ except ImportError:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_NUM_PROMPTS = 1000
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Data Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -303,9 +306,11 @@ def process_image(image: Any) -> Mapping[str, Any]:
|
||||
a JPEG in memory. - Encodes the JPEG data as a base64 string. - Returns
|
||||
a dictionary with the image as a base64 data URL.
|
||||
|
||||
3. String input: - Treats the string as a URL or local file path. -
|
||||
Prepends "file://" if the string doesn't start with "http://" or
|
||||
"file://". - Returns a dictionary with the image URL.
|
||||
3. String input: - Treats the string as a URL, local file path, or base64
|
||||
encoded data. - If string starts with "data:image/", treats as base64.
|
||||
- If string starts with "http://", "https://", or "file://", treats as URL.
|
||||
- Otherwise treats as local file path and prepends "file://".
|
||||
- Returns a dictionary with the image URL or base64 data.
|
||||
|
||||
Raises:
|
||||
ValueError: If the input is not a supported type.
|
||||
@@ -325,14 +330,14 @@ def process_image(image: Any) -> Mapping[str, Any]:
|
||||
if isinstance(image, str):
|
||||
image_url = (
|
||||
image
|
||||
if image.startswith(("http://", "https://", "file://"))
|
||||
if image.startswith(("http://", "https://", "file://", "data:image/"))
|
||||
else f"file://{image}"
|
||||
)
|
||||
return {"type": "image_url", "image_url": {"url": image_url}}
|
||||
|
||||
raise ValueError(
|
||||
f"Invalid image input {image}. Must be a PIL.Image.Image"
|
||||
" or str or dictionary with raw image bytes."
|
||||
f"Invalid image input {image}. Must be a PIL.Image.Image, "
|
||||
"str (URL, file path, or base64 data URL), or dictionary with raw image bytes."
|
||||
)
|
||||
|
||||
|
||||
@@ -1338,7 +1343,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
|
||||
parser.add_argument(
|
||||
"--num-prompts",
|
||||
type=int,
|
||||
default=1000,
|
||||
default=DEFAULT_NUM_PROMPTS,
|
||||
help="Number of prompts to process.",
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -2676,6 +2681,14 @@ class MMVUDataset(HuggingFaceDataset):
|
||||
+ (" ".join(f"{k}.{v}" for k, v in x["choices"].items())),
|
||||
}
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._remote_path_root = (
|
||||
f"https://huggingface.co/datasets/{self.hf_name}/resolve/main"
|
||||
)
|
||||
self._local_path_root = snapshot_download(self.hf_name, repo_type="dataset")
|
||||
|
||||
def sample(
|
||||
self,
|
||||
tokenizer: TokenizerLike,
|
||||
@@ -2698,7 +2711,9 @@ class MMVUDataset(HuggingFaceDataset):
|
||||
break
|
||||
|
||||
prompt = parser_fn(item)
|
||||
mm_content = process_video(item["video"])
|
||||
mm_content = process_video(
|
||||
item["video"].replace(self._remote_path_root, self._local_path_root)
|
||||
)
|
||||
prompt_len = len(tokenizer.encode(prompt))
|
||||
if enable_multimodal_chat:
|
||||
# Note: when chat is enabled the request prompt_len is no longer
|
||||
|
||||
vllm/benchmarks/lib/__init__.py, new file (3 lines)
@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Benchmark library utilities."""
vllm/benchmarks/lib/endpoint_request_func.py, new file (802 lines)
@@ -0,0 +1,802 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""The request function for API endpoints."""
|
||||
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from collections.abc import Awaitable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal, Protocol
|
||||
|
||||
import aiohttp
|
||||
import regex as re
|
||||
from tqdm.asyncio import tqdm
|
||||
|
||||
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
|
||||
|
||||
|
||||
class StreamedResponseHandler:
|
||||
"""Handles streaming HTTP responses by accumulating chunks until complete
|
||||
messages are available."""
|
||||
|
||||
def __init__(self):
|
||||
self.buffer = ""
|
||||
|
||||
def add_chunk(self, chunk_bytes: bytes) -> list[str]:
|
||||
"""Add a chunk of bytes to the buffer and return any complete
|
||||
messages."""
|
||||
chunk_str = chunk_bytes.decode("utf-8")
|
||||
self.buffer += chunk_str
|
||||
|
||||
messages = []
|
||||
|
||||
# Split by double newlines (SSE message separator)
|
||||
while "\n\n" in self.buffer:
|
||||
message, self.buffer = self.buffer.split("\n\n", 1)
|
||||
message = message.strip()
|
||||
if message:
|
||||
messages.append(message)
|
||||
|
||||
# if self.buffer is not empty, check if it is a complete message
|
||||
# by removing data: prefix and check if it is a valid JSON
|
||||
if self.buffer.startswith("data: "):
|
||||
message_content = self.buffer.removeprefix("data: ").strip()
|
||||
if message_content == "[DONE]":
|
||||
messages.append(self.buffer.strip())
|
||||
self.buffer = ""
|
||||
elif message_content:
|
||||
try:
|
||||
json.loads(message_content)
|
||||
messages.append(self.buffer.strip())
|
||||
self.buffer = ""
|
||||
except json.JSONDecodeError:
|
||||
# Incomplete JSON, wait for more chunks.
|
||||
pass
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestFuncInput:
|
||||
"""The input for the request function."""
|
||||
|
||||
prompt: str | list[str]
|
||||
api_url: str
|
||||
prompt_len: int
|
||||
output_len: int
|
||||
model: str
|
||||
model_name: str | None = None
|
||||
logprobs: int | None = None
|
||||
extra_headers: dict | None = None
|
||||
extra_body: dict | None = None
|
||||
multi_modal_content: dict | list[dict] | None = None
|
||||
ignore_eos: bool = False
|
||||
language: str | None = None
|
||||
request_id: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestFuncOutput:
|
||||
"""The output of the request function including metrics."""
|
||||
|
||||
generated_text: str = ""
|
||||
success: bool = False
|
||||
latency: float = 0.0
|
||||
output_tokens: int = 0
|
||||
ttft: float = 0.0 # Time to first token
|
||||
itl: list[float] = field(default_factory=list) # list of inter-token latencies
|
||||
tpot: float = 0.0 # avg next-token latencies
|
||||
prompt_len: int = 0
|
||||
error: str = ""
|
||||
start_time: float = 0.0
|
||||
input_audio_duration: float = 0.0 # in seconds
|
||||
|
||||
|
||||
class RequestFunc(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
request_func_input: RequestFuncInput,
|
||||
session: aiohttp.ClientSession,
|
||||
pbar: tqdm | None = None,
|
||||
) -> Awaitable[RequestFuncOutput]: ...
|
||||
|
||||
|
||||
def _validate_api_url(
|
||||
api_url: str,
|
||||
api_name: str,
|
||||
expected_suffixes: str | set[str],
|
||||
) -> None:
|
||||
if isinstance(expected_suffixes, str):
|
||||
expected_suffixes = {expected_suffixes}
|
||||
|
||||
expected_suffixes = {*expected_suffixes, "profile"}
|
||||
|
||||
if not api_url.endswith(tuple(expected_suffixes)):
|
||||
raise ValueError(f"{api_name} URL must end with one of: {expected_suffixes}.")
|
||||
|
||||
|
||||
def _update_payload_common(
|
||||
payload: dict[str, Any],
|
||||
request_func_input: RequestFuncInput,
|
||||
) -> None:
|
||||
if request_func_input.ignore_eos:
|
||||
payload["ignore_eos"] = request_func_input.ignore_eos
|
||||
if request_func_input.extra_body:
|
||||
payload.update(request_func_input.extra_body)
|
||||
|
||||
|
||||
def _update_headers_common(
|
||||
headers: dict[str, Any],
|
||||
request_func_input: RequestFuncInput,
|
||||
) -> None:
|
||||
if request_func_input.extra_headers:
|
||||
headers |= request_func_input.extra_headers
|
||||
if request_func_input.request_id:
|
||||
headers["x-request-id"] = request_func_input.request_id
|
||||
|
||||
|
||||
def _get_headers(content_type: str | None = None) -> dict[str, str]:
|
||||
headers = {}
|
||||
if content_type:
|
||||
headers["Content-Type"] = content_type
|
||||
api_key = os.environ.get("OPENAI_API_KEY")
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
return headers
|
||||
|
||||
|
||||
async def async_request_openai_completions(
|
||||
request_func_input: RequestFuncInput,
|
||||
session: aiohttp.ClientSession,
|
||||
pbar: tqdm | None = None,
|
||||
) -> RequestFuncOutput:
|
||||
"""The async request function for the OpenAI Completions API.
|
||||
|
||||
Args:
|
||||
request_func_input: The input for the request function.
|
||||
pbar: The progress bar to display the progress.
|
||||
|
||||
Returns:
|
||||
The output of the request function.
|
||||
"""
|
||||
api_url = request_func_input.api_url
|
||||
_validate_api_url(api_url, "OpenAI Completions API", "completions")
|
||||
|
||||
payload = {
|
||||
"model": request_func_input.model_name
|
||||
if request_func_input.model_name
|
||||
else request_func_input.model,
|
||||
"prompt": request_func_input.prompt,
|
||||
"repetition_penalty": 1.0,
|
||||
"max_tokens": request_func_input.output_len,
|
||||
"logprobs": request_func_input.logprobs,
|
||||
"stream": True,
|
||||
"stream_options": {
|
||||
"include_usage": True,
|
||||
},
|
||||
}
|
||||
_update_payload_common(payload, request_func_input)
|
||||
|
||||
headers = _get_headers()
|
||||
_update_headers_common(headers, request_func_input)
|
||||
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
generated_text = ""
|
||||
st = time.perf_counter()
|
||||
output.start_time = st
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(url=api_url, json=payload, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
first_chunk_received = False
|
||||
handler = StreamedResponseHandler()
|
||||
|
||||
async for chunk_bytes in response.content.iter_any():
|
||||
chunk_bytes = chunk_bytes.strip()
|
||||
if not chunk_bytes:
|
||||
continue
|
||||
|
||||
messages = handler.add_chunk(chunk_bytes)
|
||||
for message in messages:
|
||||
# NOTE: SSE comments (often used as pings) start with
|
||||
# a colon. These are not JSON data payload and should
|
||||
# be skipped.
|
||||
if message.startswith(":"):
|
||||
continue
|
||||
|
||||
chunk = message.removeprefix("data: ")
|
||||
|
||||
if chunk != "[DONE]":
|
||||
data = json.loads(chunk)
|
||||
|
||||
# NOTE: Some completion API might have a last
|
||||
# usage summary response without a token so we
|
||||
# want to check a token was generated
|
||||
if choices := data.get("choices"):
|
||||
# Note that text could be empty here
|
||||
# e.g. for special tokens
|
||||
text = choices[0].get("text")
|
||||
timestamp = time.perf_counter()
|
||||
# First token
|
||||
if not first_chunk_received:
|
||||
first_chunk_received = True
|
||||
ttft = time.perf_counter() - st
|
||||
output.ttft = ttft
|
||||
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(timestamp - most_recent_timestamp)
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
generated_text += text or ""
|
||||
elif usage := data.get("usage"):
|
||||
output.output_tokens = usage.get("completion_tokens")
|
||||
if first_chunk_received:
|
||||
output.success = True
|
||||
else:
|
||||
output.success = False
|
||||
output.error = (
|
||||
"Never received a valid chunk to calculate TTFT."
|
||||
"This response will be marked as failed!"
|
||||
)
|
||||
output.generated_text = generated_text
|
||||
output.latency = most_recent_timestamp - st
|
||||
else:
|
||||
output.error = response.reason or ""
|
||||
output.success = False
|
||||
except Exception:
|
||||
output.success = False
|
||||
exc_info = sys.exc_info()
|
||||
output.error = "".join(traceback.format_exception(*exc_info))
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
return output
|
||||
|
||||
|
||||
def _get_chat_content(
|
||||
request_func_input: RequestFuncInput,
|
||||
mm_position: Literal["first", "last"] = "last",
|
||||
) -> list[dict[str, Any]]:
|
||||
text_contents = [{"type": "text", "text": request_func_input.prompt}]
|
||||
|
||||
mm_contents = []
|
||||
if request_func_input.multi_modal_content:
|
||||
mm_content = request_func_input.multi_modal_content
|
||||
if isinstance(mm_content, list):
|
||||
mm_contents.extend(request_func_input.multi_modal_content)
|
||||
elif isinstance(mm_content, dict):
|
||||
mm_contents.append(request_func_input.multi_modal_content)
|
||||
else:
|
||||
raise TypeError(
|
||||
"multi_modal_content must be a dict or list[dict] for openai-chat"
|
||||
)
|
||||
|
||||
if mm_position == "first":
|
||||
return mm_contents + text_contents
|
||||
|
||||
return text_contents + mm_contents
|
||||
|
||||
|
||||
async def async_request_openai_chat_completions(
|
||||
request_func_input: RequestFuncInput,
|
||||
session: aiohttp.ClientSession,
|
||||
pbar: tqdm | None = None,
|
||||
mm_position: Literal["first", "last"] = "last",
|
||||
) -> RequestFuncOutput:
|
||||
api_url = request_func_input.api_url
|
||||
_validate_api_url(api_url, "OpenAI Chat Completions API", "chat/completions")
|
||||
|
||||
content = _get_chat_content(request_func_input, mm_position=mm_position)
|
||||
|
||||
payload = {
|
||||
"model": request_func_input.model_name
|
||||
if request_func_input.model_name
|
||||
else request_func_input.model,
|
||||
"messages": [
|
||||
{"role": "user", "content": content},
|
||||
],
|
||||
"max_completion_tokens": request_func_input.output_len,
|
||||
"stream": True,
|
||||
"stream_options": {
|
||||
"include_usage": True,
|
||||
},
|
||||
}
|
||||
_update_payload_common(payload, request_func_input)
|
||||
|
||||
headers = _get_headers("application/json")
|
||||
_update_headers_common(headers, request_func_input)
|
||||
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
generated_text = ""
|
||||
ttft = 0.0
|
||||
st = time.perf_counter()
|
||||
output.start_time = st
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(url=api_url, json=payload, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
handler = StreamedResponseHandler()
|
||||
async for chunk_bytes in response.content.iter_any():
|
||||
chunk_bytes = chunk_bytes.strip()
|
||||
if not chunk_bytes:
|
||||
continue
|
||||
|
||||
messages = handler.add_chunk(chunk_bytes)
|
||||
for message in messages:
|
||||
# NOTE: SSE comments (often used as pings) start with
|
||||
# a colon. These are not JSON data payload and should
|
||||
# be skipped.
|
||||
if message.startswith(":"):
|
||||
continue
|
||||
|
||||
chunk = message.removeprefix("data: ")
|
||||
|
||||
if chunk != "[DONE]":
|
||||
timestamp = time.perf_counter()
|
||||
data = json.loads(chunk)
|
||||
|
||||
if choices := data.get("choices"):
|
||||
content = choices[0]["delta"].get("content")
|
||||
# First token
|
||||
if ttft == 0.0:
|
||||
ttft = timestamp - st
|
||||
output.ttft = ttft
|
||||
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(timestamp - most_recent_timestamp)
|
||||
|
||||
generated_text += content or ""
|
||||
elif usage := data.get("usage"):
|
||||
output.output_tokens = usage.get("completion_tokens")
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
|
||||
output.generated_text = generated_text
|
||||
output.success = True
|
||||
output.latency = most_recent_timestamp - st
|
||||
else:
|
||||
output.error = response.reason or ""
|
||||
output.success = False
|
||||
except Exception:
|
||||
output.success = False
|
||||
exc_info = sys.exc_info()
|
||||
output.error = "".join(traceback.format_exception(*exc_info))
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
return output
|
||||
|
||||
|
||||
async def async_request_openai_audio(
|
||||
request_func_input: RequestFuncInput,
|
||||
session: aiohttp.ClientSession,
|
||||
pbar: tqdm | None = None,
|
||||
) -> RequestFuncOutput:
|
||||
# Lazy import without PlaceholderModule to avoid vllm dep.
|
||||
import soundfile
|
||||
|
||||
api_url = request_func_input.api_url
|
||||
_validate_api_url(api_url, "OpenAI Audio API", {"transcriptions", "translations"})
|
||||
|
||||
content = [{"type": "text", "text": request_func_input.prompt}]
|
||||
payload = {
|
||||
"model": request_func_input.model_name
|
||||
if request_func_input.model_name
|
||||
else request_func_input.model,
|
||||
"max_completion_tokens": request_func_input.output_len,
|
||||
"stream": True,
|
||||
"language": "en",
|
||||
# Flattened due to multipart/form-data
|
||||
"stream_include_usage": True,
|
||||
"stream_continuous_usage_stats": True,
|
||||
}
|
||||
_update_payload_common(payload, request_func_input)
|
||||
|
||||
headers = _get_headers()
|
||||
_update_headers_common(headers, request_func_input)
|
||||
|
||||
# Send audio file
|
||||
def to_bytes(y, sr):
|
||||
buffer = io.BytesIO()
|
||||
soundfile.write(buffer, y, sr, format="WAV")
|
||||
buffer.seek(0)
|
||||
return buffer
|
||||
|
||||
mm_audio = request_func_input.multi_modal_content
|
||||
if not isinstance(mm_audio, dict) or "audio" not in mm_audio:
|
||||
raise TypeError("multi_modal_content must be a dict containing 'audio'")
|
||||
with to_bytes(*mm_audio["audio"]) as f:
|
||||
form = aiohttp.FormData()
|
||||
form.add_field("file", f, content_type="audio/wav")
|
||||
for key, value in payload.items():
|
||||
form.add_field(key, str(value))
|
||||
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
output.input_audio_duration = soundfile.info(f).duration
|
||||
f.seek(0)
|
||||
|
||||
generated_text = ""
|
||||
ttft = 0.0
|
||||
st = time.perf_counter()
|
||||
output.start_time = st
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(
|
||||
url=api_url, data=form, headers=headers
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
handler = StreamedResponseHandler()
|
||||
|
||||
async for chunk_bytes in response.content.iter_any():
|
||||
chunk_bytes = chunk_bytes.strip()
|
||||
if not chunk_bytes:
|
||||
continue
|
||||
|
||||
messages = handler.add_chunk(chunk_bytes)
|
||||
for message in messages:
|
||||
if type(message) is bytes:
|
||||
message = message.decode("utf-8")
|
||||
chunk = message.removeprefix("data: ")
|
||||
if chunk != "[DONE]":
|
||||
timestamp = time.perf_counter()
|
||||
data = json.loads(chunk)
|
||||
|
||||
if choices := data.get("choices"):
|
||||
content = choices[0]["delta"].get("content")
|
||||
# First token
|
||||
if ttft == 0.0:
|
||||
ttft = timestamp - st
|
||||
output.ttft = ttft
|
||||
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(
|
||||
timestamp - most_recent_timestamp
|
||||
)
|
||||
|
||||
generated_text += content or ""
|
||||
elif usage := data.get("usage"):
|
||||
output.output_tokens = usage.get(
|
||||
"completion_tokens"
|
||||
)
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
|
||||
output.generated_text = generated_text
|
||||
output.success = True
|
||||
output.latency = most_recent_timestamp - st
|
||||
else:
|
||||
output.error = response.reason or ""
|
||||
output.success = False
|
||||
except Exception:
|
||||
output.success = False
|
||||
exc_info = sys.exc_info()
|
||||
output.error = "".join(traceback.format_exception(*exc_info))
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
return output
|
||||
|
||||
|
||||
async def _run_pooling_request(
|
||||
session: aiohttp.ClientSession,
|
||||
api_url: str,
|
||||
payload: dict[str, Any],
|
||||
headers: dict[str, Any],
|
||||
pbar: tqdm | None = None,
|
||||
) -> RequestFuncOutput:
|
||||
output = RequestFuncOutput()
|
||||
st = time.perf_counter()
|
||||
output.start_time = st
|
||||
try:
|
||||
async with session.post(url=api_url, headers=headers, json=payload) as response:
|
||||
if response.status == 200:
|
||||
output.ttft = output.latency = time.perf_counter() - st
|
||||
|
||||
if payload.get("encoding_format", "float") == "bytes":
|
||||
metadata = json.loads(response.headers["metadata"])
|
||||
usage = metadata.get("usage", {})
|
||||
else:
|
||||
data = await response.json()
|
||||
usage = data.get("usage", {})
|
||||
|
||||
output.success = True
|
||||
output.generated_text = ""
|
||||
output.prompt_len = usage.get("prompt_tokens", 0)
|
||||
else:
|
||||
output.success = False
|
||||
output.error = response.reason or ""
|
||||
except Exception as e:
|
||||
output.success = False
|
||||
output.error = str(e)
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
return output
|
||||
|
||||
|
||||
async def async_request_openai_embeddings(
|
||||
request_func_input: RequestFuncInput,
|
||||
session: aiohttp.ClientSession,
|
||||
pbar: tqdm | None = None,
|
||||
) -> RequestFuncOutput:
|
||||
api_url = request_func_input.api_url
|
||||
_validate_api_url(api_url, "OpenAI Embeddings API", "embeddings")
|
||||
|
||||
payload = {
|
||||
"model": request_func_input.model_name
|
||||
if request_func_input.model_name
|
||||
else request_func_input.model,
|
||||
"input": request_func_input.prompt,
|
||||
# Many embedding models have short context length,
|
||||
# this is to avoid dropping some of the requests.
|
||||
"truncate_prompt_tokens": -1,
|
||||
}
|
||||
_update_payload_common(payload, request_func_input)
|
||||
|
||||
headers = _get_headers("application/json")
|
||||
_update_headers_common(headers, request_func_input)
|
||||
|
||||
return await _run_pooling_request(
|
||||
session,
|
||||
api_url,
|
||||
payload=payload,
|
||||
headers=headers,
|
||||
pbar=pbar,
|
||||
)
|
||||
|
||||
|
||||
async def async_request_vllm_rerank(
|
||||
request_func_input: RequestFuncInput,
|
||||
session: aiohttp.ClientSession,
|
||||
pbar: tqdm | None = None,
|
||||
) -> RequestFuncOutput:
|
||||
api_url = request_func_input.api_url
|
||||
_validate_api_url(api_url, "vLLM score API", "rerank")
|
||||
|
||||
assert (
|
||||
isinstance(request_func_input.prompt, list)
|
||||
and len(request_func_input.prompt) > 1
|
||||
)
|
||||
|
||||
payload = {
|
||||
"model": request_func_input.model_name
|
||||
if request_func_input.model_name
|
||||
else request_func_input.model,
|
||||
"query": request_func_input.prompt[0],
|
||||
"documents": request_func_input.prompt[1:],
|
||||
# Many reranker models have short context length,
|
||||
# this is to avoid dropping some of the requests.
|
||||
"truncate_prompt_tokens": -1,
|
||||
}
|
||||
|
||||
headers = _get_headers("application/json")
|
||||
_update_headers_common(headers, request_func_input)
|
||||
|
||||
return await _run_pooling_request(
|
||||
session,
|
||||
api_url,
|
||||
payload=payload,
|
||||
headers=headers,
|
||||
pbar=pbar,
|
||||
)
|
||||
|
||||
|
||||
async def async_request_openai_embeddings_chat(
|
||||
request_func_input: RequestFuncInput,
|
||||
session: aiohttp.ClientSession,
|
||||
pbar: tqdm | None = None,
|
||||
mm_position: Literal["first", "last"] = "last",
|
||||
) -> RequestFuncOutput:
|
||||
api_url = request_func_input.api_url
|
||||
_validate_api_url(api_url, "OpenAI Embeddings API", "embeddings")
|
||||
|
||||
content = _get_chat_content(request_func_input, mm_position=mm_position)
|
||||
|
||||
payload = {
|
||||
"model": request_func_input.model_name
|
||||
if request_func_input.model_name
|
||||
else request_func_input.model,
|
||||
"messages": [
|
||||
{"role": "user", "content": content},
|
||||
],
|
||||
# Many embedding models have short context length,
|
||||
# this is to avoid dropping some of the requests.
|
||||
"truncate_prompt_tokens": -1,
|
||||
}
|
||||
_update_payload_common(payload, request_func_input)
|
||||
|
||||
headers = _get_headers("application/json")
|
||||
_update_headers_common(headers, request_func_input)
|
||||
|
||||
return await _run_pooling_request(
|
||||
session,
|
||||
api_url,
|
||||
payload=payload,
|
||||
headers=headers,
|
||||
pbar=pbar,
|
||||
)
|
||||
|
||||
|
||||
def _try_extract_request_idx(request_func_input: RequestFuncInput):
|
||||
if request_func_input.request_id:
|
||||
match = re.search(r"(\d+)$", request_func_input.request_id)
|
||||
if match:
|
||||
try:
|
||||
return int(match.group(1))
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _preprocess_clip(request_func_input: RequestFuncInput):
|
||||
if request_func_input.multi_modal_content:
|
||||
# Image input
|
||||
request_func_input.prompt = ""
|
||||
|
||||
|
||||
def _preprocess_vlm2vec(request_func_input: RequestFuncInput):
|
||||
if request_func_input.multi_modal_content:
|
||||
request_idx = _try_extract_request_idx(request_func_input)
|
||||
|
||||
# Adjust the ratio manually if needed.
|
||||
use_image_only_prompt = request_idx is None or request_idx % 2 == 0
|
||||
|
||||
if use_image_only_prompt:
|
||||
# Image input
|
||||
request_func_input.prompt = "Represent the given image."
|
||||
else:
|
||||
# Text+Image input
|
||||
request_func_input.prompt = (
|
||||
f"Represent the given image with the following question: "
|
||||
f"{request_func_input.prompt}"
|
||||
)
|
||||
|
||||
|
||||
async def async_request_openai_embeddings_clip(
|
||||
request_func_input: RequestFuncInput,
|
||||
session: aiohttp.ClientSession,
|
||||
pbar: tqdm | None = None,
|
||||
) -> RequestFuncOutput:
|
||||
_preprocess_clip(request_func_input)
|
||||
|
||||
return await async_request_openai_embeddings_chat(
|
||||
request_func_input,
|
||||
session,
|
||||
pbar=pbar,
|
||||
)
|
||||
|
||||
|
||||
async def async_request_openai_embeddings_vlm2vec(
|
||||
request_func_input: RequestFuncInput,
|
||||
session: aiohttp.ClientSession,
|
||||
pbar: tqdm | None = None,
|
||||
) -> RequestFuncOutput:
|
||||
_preprocess_vlm2vec(request_func_input)
|
||||
|
||||
return await async_request_openai_embeddings_chat(
|
||||
request_func_input,
|
||||
session,
|
||||
pbar=pbar,
|
||||
mm_position="first",
|
||||
)
|
||||
|
||||
|
||||
async def async_request_infinity_embeddings(
|
||||
request_func_input: RequestFuncInput,
|
||||
session: aiohttp.ClientSession,
|
||||
pbar: tqdm | None = None,
|
||||
) -> RequestFuncOutput:
|
||||
api_url = request_func_input.api_url
|
||||
_validate_api_url(api_url, "Infinity Embeddings API", "embeddings")
|
||||
|
||||
payload = {
|
||||
"model": request_func_input.model_name
|
||||
if request_func_input.model_name
|
||||
else request_func_input.model,
|
||||
}
|
||||
|
||||
if request_func_input.prompt:
|
||||
payload["input"] = request_func_input.prompt
|
||||
else:
|
||||
mm_content = request_func_input.multi_modal_content
|
||||
assert isinstance(mm_content, dict)
|
||||
|
||||
mm_type = mm_content["type"]
|
||||
payload["input"] = mm_content[mm_type]["url"]
|
||||
payload["modality"] = mm_type.split("_", 1)[0]
|
||||
|
||||
_update_payload_common(payload, request_func_input)
|
||||
|
||||
headers = _get_headers("application/json")
|
||||
_update_headers_common(headers, request_func_input)
|
||||
|
||||
return await _run_pooling_request(
|
||||
session,
|
||||
api_url,
|
||||
payload=payload,
|
||||
headers=headers,
|
||||
pbar=pbar,
|
||||
)
|
||||
|
||||
|
||||
async def async_request_infinity_embeddings_clip(
|
||||
request_func_input: RequestFuncInput,
|
||||
session: aiohttp.ClientSession,
|
||||
pbar: tqdm | None = None,
|
||||
) -> RequestFuncOutput:
|
||||
_preprocess_clip(request_func_input)
|
||||
|
||||
return await async_request_infinity_embeddings(
|
||||
request_func_input,
|
||||
session,
|
||||
pbar=pbar,
|
||||
)
|
||||
|
||||
|
||||
async def async_request_vllm_pooling(
|
||||
request_func_input: RequestFuncInput,
|
||||
session: aiohttp.ClientSession,
|
||||
pbar: tqdm | None = None,
|
||||
) -> RequestFuncOutput:
|
||||
api_url = request_func_input.api_url
|
||||
_validate_api_url(api_url, "vLLM Pooling API", "pooling")
|
||||
|
||||
payload = {
|
||||
"model": request_func_input.model_name
|
||||
if request_func_input.model_name
|
||||
else request_func_input.model,
|
||||
"truncate_prompt_tokens": -1,
|
||||
}
|
||||
|
||||
payload = payload | request_func_input.prompt
|
||||
|
||||
_update_payload_common(payload, request_func_input)
|
||||
|
||||
headers = _get_headers("application/json")
|
||||
_update_headers_common(headers, request_func_input)
|
||||
|
||||
return await _run_pooling_request(
|
||||
session,
|
||||
api_url,
|
||||
payload=payload,
|
||||
headers=headers,
|
||||
pbar=pbar,
|
||||
)
|
||||
|
||||
|
||||
# TODO: Add more request functions for different API protocols.
|
||||
ASYNC_REQUEST_FUNCS: dict[str, RequestFunc] = {
|
||||
"vllm": async_request_openai_completions,
|
||||
"openai": async_request_openai_completions,
|
||||
"openai-chat": async_request_openai_chat_completions,
|
||||
"openai-audio": async_request_openai_audio,
|
||||
"openai-embeddings": async_request_openai_embeddings,
|
||||
"openai-embeddings-chat": async_request_openai_embeddings_chat,
|
||||
"openai-embeddings-clip": async_request_openai_embeddings_clip,
|
||||
"openai-embeddings-vlm2vec": async_request_openai_embeddings_vlm2vec,
|
||||
# Infinity embedding server: https://github.com/michaelfeil/infinity
|
||||
"infinity-embeddings": async_request_infinity_embeddings,
|
||||
"infinity-embeddings-clip": async_request_infinity_embeddings_clip,
|
||||
# (Infinity embedding server does not support vlm2vec)
|
||||
"vllm-pooling": async_request_vllm_pooling,
|
||||
"vllm-rerank": async_request_vllm_rerank,
|
||||
}
|
||||
|
||||
OPENAI_COMPATIBLE_BACKENDS = [
|
||||
k
|
||||
for k, v in ASYNC_REQUEST_FUNCS.items()
|
||||
if v in (async_request_openai_completions, async_request_openai_chat_completions)
|
||||
]
|
||||
vllm/benchmarks/lib/ready_checker.py, new file (79 lines)
@@ -0,0 +1,79 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Utilities for checking endpoint readiness."""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import aiohttp
|
||||
from tqdm.asyncio import tqdm
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .endpoint_request_func import RequestFunc, RequestFuncInput, RequestFuncOutput
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
async def wait_for_endpoint(
|
||||
request_func: RequestFunc,
|
||||
test_input: RequestFuncInput,
|
||||
session: aiohttp.ClientSession,
|
||||
timeout_seconds: int = 600,
|
||||
retry_interval: int = 5,
|
||||
) -> RequestFuncOutput:
|
||||
"""
|
||||
Wait for an endpoint to become available before starting benchmarks.
|
||||
|
||||
Args:
|
||||
request_func: The async request function to call
|
||||
test_input: The RequestFuncInput to test with
|
||||
timeout_seconds: Maximum time to wait in seconds (default: 10 minutes)
|
||||
retry_interval: Time between retries in seconds (default: 5 seconds)
|
||||
|
||||
Returns:
|
||||
RequestFuncOutput: The successful response
|
||||
|
||||
Raises:
|
||||
ValueError: If the endpoint doesn't become available within the timeout
|
||||
"""
|
||||
deadline = time.perf_counter() + timeout_seconds
|
||||
output = RequestFuncOutput(success=False)
|
||||
print(f"Waiting up to {timeout_seconds} seconds for the endpoint to become available")
|
||||
|
||||
with tqdm(
|
||||
total=timeout_seconds,
|
||||
bar_format="{desc} |{bar}| {elapsed} elapsed, {remaining} remaining",
|
||||
unit="s",
|
||||
) as pbar:
|
||||
while True:
|
||||
# update progress bar
|
||||
remaining = deadline - time.perf_counter()
|
||||
elapsed = timeout_seconds - remaining
|
||||
update_amount = min(elapsed - pbar.n, timeout_seconds - pbar.n)
|
||||
pbar.update(update_amount)
|
||||
pbar.refresh()
|
||||
if remaining <= 0:
|
||||
pbar.close()
|
||||
break
|
||||
|
||||
# ping the endpoint using request_func
|
||||
try:
|
||||
output = await request_func(
|
||||
request_func_input=test_input, session=session
|
||||
)
|
||||
if output.success:
|
||||
pbar.close()
|
||||
return output
|
||||
else:
|
||||
err_last_line = str(output.error).rstrip().rsplit("\n", 1)[-1]
|
||||
logger.warning("Endpoint is not ready. Error='%s'", err_last_line)
|
||||
except aiohttp.ClientConnectorError:
|
||||
pass
|
||||
|
||||
# retry after a delay
|
||||
sleep_duration = min(retry_interval, remaining)
|
||||
if sleep_duration > 0:
|
||||
await asyncio.sleep(sleep_duration)
|
||||
|
||||
return output
|
||||
vllm/benchmarks/lib/utils.py, new file (131 lines)
@@ -0,0 +1,131 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
|
||||
def extract_field(
|
||||
args: argparse.Namespace, extra_info: dict[str, Any], field_name: str
|
||||
) -> str:
|
||||
if field_name in extra_info:
|
||||
return extra_info[field_name]
|
||||
|
||||
v = args
|
||||
# For example, args.compilation_config.mode
|
||||
for nested_field in field_name.split("."):
|
||||
if not hasattr(v, nested_field):
|
||||
return ""
|
||||
v = getattr(v, nested_field)
|
||||
return v
|
||||
|
||||
|
||||
def use_compile(args: argparse.Namespace, extra_info: dict[str, Any]) -> bool:
|
||||
"""
|
||||
Check if the benchmark is run with torch.compile
|
||||
"""
|
||||
return not (
|
||||
extract_field(args, extra_info, "compilation_config.mode") == "0"
|
||||
or "eager" in getattr(args, "output_json", "")
|
||||
or "eager" in getattr(args, "result_filename", "")
|
||||
)
|
||||
|
||||
|
||||
def convert_to_pytorch_benchmark_format(
|
||||
args: argparse.Namespace, metrics: dict[str, list], extra_info: dict[str, Any]
|
||||
) -> list:
|
||||
"""
|
||||
Save the benchmark results in the format used by PyTorch OSS benchmark with
|
||||
one metric per record
|
||||
https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database
|
||||
"""
|
||||
records = []
|
||||
if not os.environ.get("SAVE_TO_PYTORCH_BENCHMARK_FORMAT", False):
|
||||
return records
|
||||
|
||||
for name, benchmark_values in metrics.items():
|
||||
if not isinstance(benchmark_values, list):
|
||||
raise TypeError(
|
||||
f"benchmark_values for metric '{name}' must be a list, "
|
||||
f"but got {type(benchmark_values).__name__}"
|
||||
)
|
||||
|
||||
record = {
|
||||
"benchmark": {
|
||||
"name": "vLLM benchmark",
|
||||
"extra_info": {
|
||||
"args": vars(args),
|
||||
"compilation_config.mode": extract_field(
|
||||
args, extra_info, "compilation_config.mode"
|
||||
),
|
||||
"optimization_level": extract_field(
|
||||
args, extra_info, "optimization_level"
|
||||
),
|
||||
# A boolean field used by vLLM benchmark HUD dashboard
|
||||
"use_compile": use_compile(args, extra_info),
|
||||
},
|
||||
},
|
||||
"model": {
|
||||
"name": args.model,
|
||||
},
|
||||
"metric": {
|
||||
"name": name,
|
||||
"benchmark_values": benchmark_values,
|
||||
"extra_info": extra_info,
|
||||
},
|
||||
}
|
||||
|
||||
tp = record["benchmark"]["extra_info"]["args"].get("tensor_parallel_size")
|
||||
# Save tensor_parallel_size parameter if it's part of the metadata
|
||||
if not tp and "tensor_parallel_size" in extra_info:
|
||||
record["benchmark"]["extra_info"]["args"]["tensor_parallel_size"] = (
|
||||
extra_info["tensor_parallel_size"]
|
||||
)
|
||||
|
||||
records.append(record)
|
||||
|
||||
return records
|
||||
|
||||
|
||||
class InfEncoder(json.JSONEncoder):
|
||||
def clear_inf(self, o: Any):
|
||||
if isinstance(o, dict):
|
||||
return {
|
||||
str(k)
|
||||
if not isinstance(k, (str, int, float, bool, type(None)))
|
||||
else k: self.clear_inf(v)
|
||||
for k, v in o.items()
|
||||
}
|
||||
elif isinstance(o, list):
|
||||
return [self.clear_inf(v) for v in o]
|
||||
elif isinstance(o, float) and math.isinf(o):
|
||||
return "inf"
|
||||
return o
|
||||
|
||||
def iterencode(self, o: Any, *args, **kwargs) -> Any:
|
||||
return super().iterencode(self.clear_inf(o), *args, **kwargs)
|
||||
|
||||
|
||||
def write_to_json(filename: str, records: list) -> None:
|
||||
with open(filename, "w") as f:
|
||||
json.dump(
|
||||
records,
|
||||
f,
|
||||
cls=InfEncoder,
|
||||
default=lambda o: f"<{type(o).__name__} is not JSON serializable>",
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def default_vllm_config():
|
||||
"""Set a default VllmConfig for cases that directly test CustomOps or pathways
|
||||
that use get_current_vllm_config() outside of a full engine context.
|
||||
"""
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
|
||||
with set_current_vllm_config(VllmConfig()):
|
||||
yield
|
||||
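As a quick illustration (not part of the diff), this is how write_to_json with InfEncoder behaves for records containing infinite or otherwise non-serializable values; only the helper names come from vllm/benchmarks/lib/utils.py above, and the record contents are made up.

from vllm.benchmarks.lib.utils import write_to_json

records = [
    {
        "metric": {"name": "p99_ttft_ms", "benchmark_values": [float("inf"), 12.5]},
        "extra_info": {"note": "made-up record for illustration"},
    }
]

# InfEncoder rewrites float("inf") as the string "inf", stringifies dict keys
# that are not JSON-native, and the `default=` fallback replaces anything else
# that cannot be serialized with a "<TypeName is not JSON serializable>" marker.
write_to_json("example_results.json", records)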
316 vllm/benchmarks/plot.py (new file)
@@ -0,0 +1,316 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Generate plots for benchmark results."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
try:
|
||||
import plotly.express as px
|
||||
import plotly.io as pio
|
||||
except ImportError:
|
||||
_plotly = PlaceholderModule("plotly")
|
||||
px = _plotly.placeholder_attr("express")
|
||||
pio = _plotly.placeholder_attr("io")
|
||||
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
except ImportError:
|
||||
_matplotlib = PlaceholderModule("matplotlib")
|
||||
plt = _matplotlib.placeholder_attr("pyplot")
|
||||
|
||||
|
||||
def generate_timeline_plot(
|
||||
results: list[dict[str, Any]],
|
||||
output_path: Path,
|
||||
colors: list[str] | None = None,
|
||||
itl_thresholds: list[float] | None = None,
|
||||
labels: list[str] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Generate an HTML timeline plot from benchmark results.
|
||||
|
||||
Args:
|
||||
results: List of per-request result dictionaries containing:
|
||||
- start_time: Request start time (seconds)
|
||||
- ttft: Time to first token (seconds)
|
||||
- itl: List of inter-token latencies (seconds)
|
||||
- latency: Total request latency (seconds)
|
||||
- prompt_len: Number of prompt tokens
|
||||
- output_tokens: Number of output tokens
|
||||
output_path: Path where the HTML file will be saved
|
||||
colors: List of colors for ITL categories (default: green, orange, red)
itl_thresholds: ITL thresholds in seconds (default: [0.025, 0.050])
|
||||
labels: Labels for ITL categories (default based on thresholds)
|
||||
"""
|
||||
|
||||
# Set defaults
|
||||
if colors is None:
|
||||
colors = ["#109618", "#FF7F0E", "#D62728"]
|
||||
if itl_thresholds is None:
|
||||
itl_thresholds = [0.025, 0.050]
|
||||
if labels is None:
|
||||
labels = [
|
||||
f"ITL < {itl_thresholds[0] * 1000:.0f}ms",
|
||||
f"{itl_thresholds[0] * 1000:.0f}ms ≤ ITL < {itl_thresholds[1] * 1000:.0f}ms", # noqa
|
||||
f"ITL ≥ {itl_thresholds[1] * 1000:.0f}ms",
|
||||
]
|
||||
|
||||
labels_colors = {"TTFT": "#636EFA", **dict(zip(labels, colors))}
|
||||
labels_order = ["TTFT"] + labels
|
||||
|
||||
timeline_data = construct_timeline_data(results, itl_thresholds, labels)
|
||||
|
||||
if not timeline_data:
|
||||
print("No timeline data to plot")
|
||||
return
|
||||
|
||||
# Create the plot
|
||||
fig = px.timeline(
|
||||
timeline_data,
|
||||
x_start="start",
|
||||
x_end="end",
|
||||
y="request_id",
|
||||
color="type",
|
||||
color_discrete_map=labels_colors,
|
||||
category_orders={"type": labels_order},
|
||||
hover_data=[
|
||||
"prompt_tokens",
|
||||
"output_tokens",
|
||||
"req_start_time",
|
||||
"req_finish_time",
|
||||
"segment_start",
|
||||
"segment_end",
|
||||
"duration",
|
||||
],
|
||||
)
|
||||
|
||||
# Customize hover template to show only time without date
|
||||
fig.update_traces(
|
||||
hovertemplate="<b>%{y}</b><br>"
|
||||
"Type: %{fullData.name}<br>"
|
||||
"Start: %{customdata[4]}<br>"
|
||||
"End: %{customdata[5]}<br>"
|
||||
"Duration: %{customdata[6]}<br>"
|
||||
"Prompt Tokens: %{customdata[0]}<br>"
|
||||
"Output Tokens: %{customdata[1]}<br>"
|
||||
"Request Start Time: %{customdata[2]}<br>"
|
||||
"Request End Time: %{customdata[3]}<br>"
|
||||
"<extra></extra>"
|
||||
)
|
||||
|
||||
fig.update_yaxes(autorange="reversed")
|
||||
fig.update_layout(
|
||||
xaxis_title="Time",
|
||||
yaxis_title="Request ID",
|
||||
showlegend=True,
|
||||
)
|
||||
|
||||
# Save to HTML
|
||||
pio.write_html(fig, str(output_path))
|
||||
print(f"Timeline plot saved to: {output_path}")
|
||||
|
||||
|
||||
def construct_timeline_data(
|
||||
requests_data: list[dict[str, Any]],
|
||||
itl_thresholds: list[float],
|
||||
labels: list[str],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Construct timeline data from request results.
|
||||
|
||||
Args:
|
||||
requests_data: List of per-request result dictionaries
|
||||
itl_thresholds: ITL thresholds in seconds
|
||||
labels: Labels for ITL categories
|
||||
|
||||
Returns:
|
||||
List of timeline segments for plotting
|
||||
"""
|
||||
|
||||
def tostr(sec_time: float) -> str:
|
||||
"""Convert seconds to HH:MM:SS.mmm format."""
|
||||
h = int(sec_time // 3600)
|
||||
assert h < 100, "time seems to last more than 100 hours"
|
||||
m = int((sec_time % 3600) // 60)
|
||||
s = sec_time % 60
|
||||
return f"{h:02d}:{m:02d}:{s:06.3f}"
|
||||
|
||||
def itl_type(itl: float) -> str:
|
||||
"""Categorize ITL based on thresholds."""
|
||||
if itl < itl_thresholds[0]:
|
||||
return labels[0]
|
||||
elif itl < itl_thresholds[1]:
|
||||
return labels[1]
|
||||
else:
|
||||
return labels[2]
|
||||
|
||||
# Find the earliest start time to use as t0
|
||||
t0 = None
|
||||
for request in requests_data:
|
||||
start_time = request.get("start_time")
|
||||
if start_time is not None and (t0 is None or start_time < t0):
|
||||
t0 = start_time
|
||||
|
||||
if t0 is None:
|
||||
return []
|
||||
|
||||
timeline_data = []
|
||||
|
||||
for i, request in enumerate(requests_data):
|
||||
start_time = request.get("start_time")
|
||||
ttft = request.get("ttft")
|
||||
itl = request.get("itl", [])
|
||||
latency = request.get("latency")
|
||||
prompt_len = request.get("prompt_len", 0)
|
||||
output_tokens = request.get("output_tokens", 0)
|
||||
|
||||
# Skip requests without required data
|
||||
if start_time is None or ttft is None or latency is None:
|
||||
continue
|
||||
|
||||
# Normalize start time
|
||||
start_time = start_time - t0
|
||||
start_time_str = tostr(start_time)
|
||||
|
||||
# TTFT segment
|
||||
ttft_end = start_time + ttft
|
||||
ttft_end_str = tostr(ttft_end)
|
||||
|
||||
timeline_data.append(
|
||||
{
|
||||
"request_id": f"Req {i}",
|
||||
"start": start_time_str,
|
||||
"end": ttft_end_str,
|
||||
"type": "TTFT",
|
||||
"prompt_tokens": prompt_len,
|
||||
"output_tokens": output_tokens,
|
||||
"req_start_time": tostr(start_time),
|
||||
"req_finish_time": tostr(start_time + latency),
|
||||
"segment_start": start_time_str,
|
||||
"segment_end": ttft_end_str,
|
||||
"duration": f"{ttft:.3f}s",
|
||||
}
|
||||
)
|
||||
|
||||
# ITL segments
|
||||
prev_time = ttft_end
|
||||
prev_time_str = ttft_end_str
|
||||
|
||||
for itl_value in itl:
|
||||
itl_end = prev_time + itl_value
|
||||
itl_end_str = tostr(itl_end)
|
||||
|
||||
timeline_data.append(
|
||||
{
|
||||
"request_id": f"Req {i}",
|
||||
"start": prev_time_str,
|
||||
"end": itl_end_str,
|
||||
"type": itl_type(itl_value),
|
||||
"prompt_tokens": prompt_len,
|
||||
"output_tokens": output_tokens,
|
||||
"req_start_time": tostr(start_time),
|
||||
"req_finish_time": tostr(start_time + latency),
|
||||
"segment_start": prev_time_str,
|
||||
"segment_end": itl_end_str,
|
||||
"duration": f"{itl_value:.3f}s",
|
||||
}
|
||||
)
|
||||
|
||||
prev_time = itl_end
|
||||
prev_time_str = itl_end_str
|
||||
|
||||
return timeline_data
|
||||
|
||||
|
||||
def generate_dataset_stats_plot(
|
||||
results: list[dict[str, Any]],
|
||||
output_path: Path,
|
||||
) -> None:
|
||||
"""
|
||||
Generate a matplotlib figure with dataset statistics.
|
||||
|
||||
Creates a figure with 4 subplots:
|
||||
- Top-left: Prompt tokens distribution (histogram)
|
||||
- Top-right: Output tokens distribution (histogram)
|
||||
- Bottom-left: Prompt+output tokens distribution (histogram)
|
||||
- Bottom-right: Stacked bar chart (request_id vs tokens)
|
||||
|
||||
Args:
|
||||
results: List of per-request result dictionaries containing:
|
||||
- prompt_len: Number of prompt tokens
|
||||
- output_tokens: Number of output tokens
|
||||
output_path: Path where the figure will be saved
|
||||
"""
|
||||
# Extract data
|
||||
prompt_tokens = []
|
||||
output_tokens = []
|
||||
total_tokens = []
|
||||
|
||||
for request in results:
|
||||
prompt_len = request.get("prompt_len", 0)
|
||||
output_len = request.get("output_tokens", 0)
|
||||
|
||||
prompt_tokens.append(prompt_len)
|
||||
output_tokens.append(output_len)
|
||||
total_tokens.append(prompt_len + output_len)
|
||||
|
||||
if not prompt_tokens:
|
||||
print("No data available for dataset statistics plot")
|
||||
return
|
||||
|
||||
# Create figure with 4 subplots
|
||||
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(14, 10))
|
||||
|
||||
# Top-left: Prompt tokens distribution
|
||||
ax1.hist(prompt_tokens, bins=30, color="steelblue", edgecolor="black", alpha=0.7)
|
||||
ax1.set_xlabel("Prompt Tokens")
|
||||
ax1.set_ylabel("Frequency")
|
||||
ax1.set_title("Prompt Tokens Distribution")
|
||||
ax1.grid(True, alpha=0.3)
|
||||
|
||||
# Top-right: Output tokens distribution
|
||||
ax2.hist(output_tokens, bins=30, color="coral", edgecolor="black", alpha=0.7)
|
||||
ax2.set_xlabel("Output Tokens")
|
||||
ax2.set_ylabel("Frequency")
|
||||
ax2.set_title("Output Tokens Distribution")
|
||||
ax2.grid(True, alpha=0.3)
|
||||
|
||||
# Bottom-left: Prompt+output tokens distribution
|
||||
ax3.hist(
|
||||
total_tokens, bins=30, color="mediumseagreen", edgecolor="black", alpha=0.7
|
||||
)
|
||||
ax3.set_xlabel("Total Tokens (Prompt + Output)")
|
||||
ax3.set_ylabel("Frequency")
|
||||
ax3.set_title("Total Tokens Distribution")
|
||||
ax3.grid(True, alpha=0.3)
|
||||
|
||||
# Bottom-right: Stacked bar chart
|
||||
request_ids = list(range(len(prompt_tokens)))
|
||||
ax4.bar(
|
||||
request_ids, prompt_tokens, label="Prompt Tokens", color="steelblue", alpha=0.7
|
||||
)
|
||||
ax4.bar(
|
||||
request_ids,
|
||||
output_tokens,
|
||||
bottom=prompt_tokens,
|
||||
label="Output Tokens",
|
||||
color="coral",
|
||||
alpha=0.7,
|
||||
)
|
||||
ax4.set_xlabel("Request ID")
|
||||
ax4.set_ylabel("Tokens")
|
||||
ax4.set_title("Tokens per Request (Stacked)")
|
||||
ax4.legend()
|
||||
ax4.grid(True, alpha=0.3, axis="y")
|
||||
|
||||
# Adjust layout to prevent overlap
|
||||
plt.tight_layout()
|
||||
|
||||
# Save figure
|
||||
plt.savefig(str(output_path), dpi=150, bbox_inches="tight")
|
||||
plt.close(fig)
|
||||
|
||||
print(f"Dataset statistics plot saved to: {output_path}")
|
||||
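A hedged usage sketch for the plotting helpers above (requires plotly to be installed); the two synthetic request records use only the keys documented in generate_timeline_plot, and all values are invented.

from pathlib import Path

from vllm.benchmarks.plot import generate_timeline_plot

results = [
    {
        "start_time": 0.00, "ttft": 0.12, "itl": [0.020, 0.030, 0.060],
        "latency": 0.23, "prompt_len": 32, "output_tokens": 4,
    },
    {
        "start_time": 0.05, "ttft": 0.20, "itl": [0.015, 0.045],
        "latency": 0.26, "prompt_len": 16, "output_tokens": 3,
    },
]

# With the default thresholds (25 ms and 50 ms), the 15/20 ms gaps are drawn in
# green, 30/45 ms in orange, and 60 ms in red, after the blue TTFT segment.
generate_timeline_plot(results, Path("timeline.html"))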
@@ -34,6 +34,7 @@ from collections.abc import AsyncGenerator, Iterable
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
import aiohttp
|
||||
@@ -1183,6 +1184,49 @@ def save_to_pytorch_benchmark_format(
|
||||
write_to_json(pt_file, pt_records)
|
||||
|
||||
|
||||
def compute_result_filename(
|
||||
args: argparse.Namespace,
|
||||
model_id: str,
|
||||
label: str,
|
||||
current_dt: str,
|
||||
) -> str | None:
|
||||
"""Compute the result filename based on benchmark configuration.
|
||||
|
||||
Args:
|
||||
args: Command line arguments containing result configuration
|
||||
model_id: The model identifier
|
||||
label: The benchmark label
|
||||
current_dt: Current datetime string
|
||||
|
||||
Returns:
|
||||
The computed filename path or None if no result saving is requested
|
||||
"""
|
||||
if not (args.plot_timeline or args.save_result or args.append_result):
|
||||
return None
|
||||
|
||||
base_model_id = model_id.split("/")[-1]
|
||||
max_concurrency_str = (
|
||||
f"-concurrency{args.max_concurrency}"
|
||||
if args.max_concurrency is not None
|
||||
else ""
|
||||
)
|
||||
label = label or args.backend
|
||||
|
||||
if args.ramp_up_strategy is not None:
|
||||
file_name = f"{label}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
|
||||
else:
|
||||
file_name = f"{label}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
|
||||
|
||||
if args.result_filename:
|
||||
file_name = args.result_filename
|
||||
|
||||
if args.result_dir:
|
||||
os.makedirs(args.result_dir, exist_ok=True)
|
||||
file_name = os.path.join(args.result_dir, file_name)
|
||||
|
||||
return file_name
|
||||
|
||||
|
||||
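To make the naming scheme concrete, here is an illustrative (hypothetical) evaluation of the default pattern used by compute_result_filename; every value below is made up.

# Hypothetical values, for illustration only.
label, request_rate, max_concurrency = "openai-chat", 4, 8
base_model_id, current_dt = "Qwen2-7B", "20260420-090923"

file_name = (
    f"{label}-{request_rate}qps-concurrency{max_concurrency}"
    f"-{base_model_id}-{current_dt}.json"
)
print(file_name)  # openai-chat-4qps-concurrency8-Qwen2-7B-20260420-090923.json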
def add_cli_args(parser: argparse.ArgumentParser):
|
||||
add_dataset_parser(parser)
|
||||
parser.add_argument(
|
||||
@@ -1277,6 +1321,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
|
||||
- "slow" will always use the slow tokenizer.\n
|
||||
- "mistral" will always use the tokenizer from `mistral_common`.\n
|
||||
- "deepseek_v32" will always use the tokenizer from `deepseek_v32`.\n
|
||||
- "qwen_vl" will always use the tokenizer from `qwen_vl`.\n
|
||||
- Other custom values can be supported via plugins.""",
|
||||
)
|
||||
parser.add_argument("--use-beam-search", action="store_true")
|
||||
@@ -1535,6 +1580,30 @@ def add_cli_args(parser: argparse.ArgumentParser):
|
||||
"connecting to servers with self-signed certificates.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--plot-timeline",
|
||||
action="store_true",
|
||||
help="Generate an HTML timeline plot showing request execution. "
|
||||
"The plot will be saved alongside the results JSON file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--timeline-itl-thresholds",
|
||||
type=float,
|
||||
nargs=2,
|
||||
default=[25.0, 50.0],
|
||||
metavar=("THRESHOLD1", "THRESHOLD2"),
|
||||
help="ITL thresholds in milliseconds for timeline plot coloring. "
|
||||
"Specify two values to categorize inter-token latencies into three groups: "
|
||||
"below first threshold (green), between thresholds (orange), "
|
||||
"and above second threshold (red). Default: 25 50 (milliseconds).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--plot-dataset-stats",
|
||||
action="store_true",
|
||||
help="Generate a matplotlib figure with dataset statistics showing "
|
||||
"prompt tokens, output tokens, and combined token distributions.",
|
||||
)
|
||||
|
||||
|
||||
def main(args: argparse.Namespace) -> dict[str, Any]:
|
||||
return asyncio.run(main_async(args))
|
||||
@@ -1770,6 +1839,86 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
|
||||
# Merge with benchmark result
|
||||
result_json = {**result_json, **benchmark_result}
|
||||
|
||||
# Compute file_name once before using it for plots or saving results
|
||||
file_name = compute_result_filename(args, model_id, label, current_dt)
|
||||
|
||||
# Generate timeline plot if requested
|
||||
if args.plot_timeline:
|
||||
try:
|
||||
from vllm.benchmarks.plot import generate_timeline_plot
|
||||
|
||||
# Prepare per-request data for timeline
|
||||
per_request_data = []
|
||||
start_times = benchmark_result.get("start_times", [])
|
||||
ttfts = benchmark_result.get("ttfts", [])
|
||||
itls = benchmark_result.get("itls", [])
|
||||
input_lens = benchmark_result.get("input_lens", [])
|
||||
output_lens = benchmark_result.get("output_lens", [])
|
||||
|
||||
if start_times and ttfts and itls:
|
||||
for i in range(len(start_times)):
|
||||
# Calculate latency as ttft + sum of all itls
|
||||
latency = ttfts[i] + sum(itls[i]) if itls[i] else ttfts[i]
|
||||
|
||||
per_request_data.append(
|
||||
{
|
||||
"start_time": start_times[i],
|
||||
"ttft": ttfts[i],
|
||||
"itl": itls[i],
|
||||
"latency": latency,
|
||||
"prompt_len": input_lens[i],
|
||||
"output_tokens": output_lens[i],
|
||||
}
|
||||
)
|
||||
|
||||
timeline_path = Path(file_name).with_suffix(".timeline.html")
|
||||
# Convert thresholds from milliseconds to seconds
|
||||
itl_thresholds_sec = [t / 1000.0 for t in args.timeline_itl_thresholds]
|
||||
generate_timeline_plot(
|
||||
per_request_data, timeline_path, itl_thresholds=itl_thresholds_sec
|
||||
)
|
||||
else:
|
||||
warnings.warn(
|
||||
"Timeline plot requires detailed metrics. "
|
||||
"Ensure the benchmark completed successfully.",
|
||||
stacklevel=2,
|
||||
)
|
||||
except Exception as e:
|
||||
warnings.warn(f"Failed to generate timeline plot: {e}", stacklevel=2)
|
||||
|
||||
# Generate dataset statistics plot if requested
|
||||
if args.plot_dataset_stats:
|
||||
try:
|
||||
from vllm.benchmarks.plot import generate_dataset_stats_plot
|
||||
|
||||
# Prepare per-request data for dataset stats
|
||||
per_request_data = []
|
||||
input_lens = benchmark_result.get("input_lens", [])
|
||||
output_lens = benchmark_result.get("output_lens", [])
|
||||
|
||||
if input_lens and output_lens:
|
||||
for req_input_len, req_output_len in zip(input_lens, output_lens):
|
||||
per_request_data.append(
|
||||
{
|
||||
"prompt_len": req_input_len,
|
||||
"output_tokens": req_output_len,
|
||||
}
|
||||
)
|
||||
|
||||
stats_path = Path(file_name).with_suffix(".dataset_stats.png")
|
||||
generate_dataset_stats_plot(per_request_data, stats_path)
|
||||
else:
|
||||
warnings.warn(
|
||||
"Dataset statistics plot requires input and "
|
||||
"output length data. Ensure the benchmark completed "
|
||||
"successfully.",
|
||||
stacklevel=2,
|
||||
)
|
||||
except Exception as e:
|
||||
warnings.warn(
|
||||
f"Failed to generate dataset statistics plot: {e}", stacklevel=2
|
||||
)
|
||||
|
||||
if not args.save_detailed:
|
||||
# Remove fields with too many data points
|
||||
for field in [
|
||||
@@ -1788,22 +1937,6 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
|
||||
|
||||
# Save to file
|
||||
if args.save_result or args.append_result:
|
||||
base_model_id = model_id.split("/")[-1]
|
||||
max_concurrency_str = (
|
||||
f"-concurrency{args.max_concurrency}"
|
||||
if args.max_concurrency is not None
|
||||
else ""
|
||||
)
|
||||
label = label or args.backend
|
||||
if args.ramp_up_strategy is not None:
|
||||
file_name = f"{label}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
|
||||
else:
|
||||
file_name = f"{label}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
|
||||
if args.result_filename:
|
||||
file_name = args.result_filename
|
||||
if args.result_dir:
|
||||
os.makedirs(args.result_dir, exist_ok=True)
|
||||
file_name = os.path.join(args.result_dir, file_name)
|
||||
with open(
|
||||
file_name, mode="a+" if args.append_result else "w", encoding="utf-8"
|
||||
) as outfile:
|
||||
|
||||
@@ -10,14 +10,14 @@ from .plot_pareto import SweepPlotParetoArgs
|
||||
from .plot_pareto import main as plot_pareto_main
|
||||
from .serve import SweepServeArgs
|
||||
from .serve import main as serve_main
|
||||
from .serve_sla import SweepServeSLAArgs
|
||||
from .serve_sla import main as serve_sla_main
|
||||
from .serve_workload import SweepServeWorkloadArgs
|
||||
from .serve_workload import main as serve_workload_main
|
||||
from .startup import SweepStartupArgs
|
||||
from .startup import main as startup_main
|
||||
|
||||
SUBCOMMANDS = (
|
||||
(SweepServeArgs, serve_main),
|
||||
(SweepServeSLAArgs, serve_sla_main),
|
||||
(SweepServeWorkloadArgs, serve_workload_main),
|
||||
(SweepStartupArgs, startup_main),
|
||||
(SweepPlotArgs, plot_main),
|
||||
(SweepPlotParetoArgs, plot_pareto_main),
|
||||
|
||||
@@ -324,6 +324,11 @@ def _plot_fig(
|
||||
df = filter_by.apply(df)
|
||||
df = bin_by.apply(df)
|
||||
|
||||
if len(df) == 0:
|
||||
print(f"No data to plot. Filters: {filter_by}")
|
||||
print("[END FIGURE]")
|
||||
return
|
||||
|
||||
# Sort by curve_by columns alphabetically for consistent legend ordering
|
||||
if curve_by:
|
||||
df = df.sort_values(by=curve_by)
|
||||
@@ -494,7 +499,7 @@ class SweepPlotArgs:
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
output_dir = Path(args.OUTPUT_DIR)
|
||||
output_dir = Path(args.EXPERIMENT_DIR)
|
||||
if not output_dir.exists():
|
||||
raise ValueError(f"No parameter sweep results under {output_dir}")
|
||||
|
||||
@@ -526,11 +531,9 @@ class SweepPlotArgs:
|
||||
@classmethod
|
||||
def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
|
||||
parser.add_argument(
|
||||
"OUTPUT_DIR",
|
||||
"EXPERIMENT_DIR",
|
||||
type=str,
|
||||
default="results",
|
||||
help="The directory containing the results to plot, "
|
||||
"i.e., the `--output-dir` argument to the parameter sweep script.",
|
||||
help="The directory containing the sweep results to plot.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fig-dir",
|
||||
@@ -570,13 +573,13 @@ class SweepPlotArgs:
|
||||
parser.add_argument(
|
||||
"--var-x",
|
||||
type=str,
|
||||
default="request_throughput",
|
||||
default="total_token_throughput",
|
||||
help="The variable for the x-axis.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--var-y",
|
||||
type=str,
|
||||
default="p99_ttft_ms",
|
||||
default="median_ttft_ms",
|
||||
help="The variable for the y-axis",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
||||
@@ -325,7 +325,7 @@ class SweepPlotParetoArgs:
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
output_dir = Path(args.OUTPUT_DIR)
|
||||
output_dir = Path(args.EXPERIMENT_DIR)
|
||||
if not output_dir.exists():
|
||||
raise ValueError(f"No parameter sweep results under {output_dir}")
|
||||
|
||||
@@ -342,9 +342,8 @@ class SweepPlotParetoArgs:
|
||||
@classmethod
|
||||
def add_cli_args(cls, parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"OUTPUT_DIR",
|
||||
"EXPERIMENT_DIR",
|
||||
type=str,
|
||||
default="results",
|
||||
help="The directory containing the sweep results to plot.",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
||||
@@ -4,6 +4,7 @@ import argparse
|
||||
import contextlib
|
||||
import json
|
||||
import shlex
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
@@ -135,17 +136,21 @@ def run_benchmark(
|
||||
|
||||
|
||||
def _get_comb_base_path(
|
||||
output_dir: Path,
|
||||
experiment_dir: Path,
|
||||
serve_comb: ParameterSweepItem,
|
||||
bench_comb: ParameterSweepItem,
|
||||
*,
|
||||
extra_parts: tuple[str, ...] = (),
|
||||
):
|
||||
parts = list[str]()
|
||||
if serve_comb:
|
||||
parts.extend(("SERVE-", serve_comb.name))
|
||||
if bench_comb:
|
||||
parts.extend(("BENCH-", bench_comb.name))
|
||||
if extra_parts:
|
||||
parts.extend(extra_parts)
|
||||
|
||||
return output_dir / sanitize_filename("-".join(parts))
|
||||
return experiment_dir / sanitize_filename("-".join(parts))
|
||||
|
||||
|
||||
def _get_comb_run_path(base_path: Path, run_number: int | None):
|
||||
@@ -158,10 +163,10 @@ def _get_comb_run_path(base_path: Path, run_number: int | None):
|
||||
def _comb_needs_server(
|
||||
serve_comb: ParameterSweepItem,
|
||||
bench_combs: ParameterSweep,
|
||||
output_dir: Path,
|
||||
experiment_dir: Path,
|
||||
):
|
||||
for bench_comb in bench_combs:
|
||||
base_path = _get_comb_base_path(output_dir, serve_comb, bench_comb)
|
||||
base_path = _get_comb_base_path(experiment_dir, serve_comb, bench_comb)
|
||||
if not _get_comb_run_path(base_path, run_number=None).exists():
|
||||
return True
|
||||
|
||||
@@ -175,11 +180,11 @@ def server_ctx(
|
||||
show_stdout: bool,
|
||||
serve_comb: ParameterSweepItem,
|
||||
bench_params: ParameterSweep,
|
||||
output_dir: Path,
|
||||
experiment_dir: Path,
|
||||
dry_run: bool,
|
||||
server_ready_timeout: int = 300,
|
||||
):
|
||||
if not _comb_needs_server(serve_comb, bench_params, output_dir):
|
||||
if not _comb_needs_server(serve_comb, bench_params, experiment_dir):
|
||||
return contextlib.nullcontext()
|
||||
|
||||
return run_server(
|
||||
@@ -211,10 +216,10 @@ def run_comb(
|
||||
*,
|
||||
serve_comb: ParameterSweepItem,
|
||||
bench_comb: ParameterSweepItem,
|
||||
link_vars: list[tuple[str, str]],
|
||||
base_path: Path,
|
||||
num_runs: int,
|
||||
dry_run: bool,
|
||||
link_vars: list[tuple[str, str]],
|
||||
):
|
||||
if not _comb_is_valid(serve_comb, bench_comb, link_vars):
|
||||
return None
|
||||
@@ -253,10 +258,10 @@ def run_combs(
|
||||
server_ready_timeout: int,
|
||||
serve_params: ParameterSweep,
|
||||
bench_params: ParameterSweep,
|
||||
output_dir: Path,
|
||||
link_vars: list[tuple[str, str]],
|
||||
experiment_dir: Path,
|
||||
num_runs: int,
|
||||
dry_run: bool,
|
||||
link_vars: list[tuple[str, str]],
|
||||
):
|
||||
all_data = list[dict[str, object]]()
|
||||
for serve_comb in serve_params:
|
||||
@@ -266,22 +271,22 @@ def run_combs(
|
||||
show_stdout=show_stdout,
|
||||
serve_comb=serve_comb,
|
||||
bench_params=bench_params,
|
||||
output_dir=output_dir,
|
||||
experiment_dir=experiment_dir,
|
||||
dry_run=dry_run,
|
||||
server_ready_timeout=server_ready_timeout,
|
||||
) as server:
|
||||
for bench_comb in bench_params:
|
||||
base_path = _get_comb_base_path(output_dir, serve_comb, bench_comb)
|
||||
base_path = _get_comb_base_path(experiment_dir, serve_comb, bench_comb)
|
||||
|
||||
comb_data = run_comb(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb,
|
||||
link_vars=link_vars,
|
||||
base_path=base_path,
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
link_vars=link_vars,
|
||||
)
|
||||
|
||||
if comb_data is not None:
|
||||
@@ -291,7 +296,7 @@ def run_combs(
|
||||
return None
|
||||
|
||||
combined_df = pd.DataFrame.from_records(all_data)
|
||||
combined_df.to_csv(output_dir / "summary.csv")
|
||||
combined_df.to_csv(experiment_dir / "summary.csv")
|
||||
|
||||
return combined_df
|
||||
|
||||
@@ -305,11 +310,12 @@ class SweepServeArgs:
|
||||
server_ready_timeout: int
|
||||
serve_params: ParameterSweep
|
||||
bench_params: ParameterSweep
|
||||
link_vars: list[tuple[str, str]]
|
||||
output_dir: Path
|
||||
experiment_name: str
|
||||
num_runs: int
|
||||
dry_run: bool
|
||||
resume: str | None
|
||||
link_vars: list[tuple[str, str]]
|
||||
resume: bool
|
||||
|
||||
parser_name: ClassVar[str] = "serve"
|
||||
parser_help: ClassVar[str] = "Run vLLM server benchmark under multiple settings."
|
||||
@@ -336,6 +342,11 @@ class SweepServeArgs:
|
||||
|
||||
link_vars = cls.parse_link_vars(args.link_vars)
|
||||
|
||||
if args.experiment_name:
|
||||
experiment_name = args.experiment_name
|
||||
else:
|
||||
experiment_name = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
num_runs = args.num_runs
|
||||
if num_runs < 1:
|
||||
raise ValueError("`num_runs` should be at least 1.")
|
||||
@@ -347,11 +358,12 @@ class SweepServeArgs:
|
||||
show_stdout=args.show_stdout,
|
||||
serve_params=serve_params,
|
||||
bench_params=bench_params,
|
||||
link_vars=link_vars,
|
||||
output_dir=Path(args.output_dir),
|
||||
experiment_name=experiment_name,
|
||||
num_runs=num_runs,
|
||||
dry_run=args.dry_run,
|
||||
resume=args.resume,
|
||||
link_vars=link_vars,
|
||||
server_ready_timeout=args.server_ready_timeout,
|
||||
)
|
||||
|
||||
@@ -388,6 +400,7 @@ class SweepServeArgs:
|
||||
default=300,
|
||||
help="Timeout in seconds to wait for the server to become ready.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--serve-params",
|
||||
type=str,
|
||||
@@ -398,6 +411,16 @@ class SweepServeArgs:
|
||||
"If both `serve_params` and `bench_params` are given, "
|
||||
"this script will iterate over their Cartesian product.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--link-vars",
|
||||
type=str,
|
||||
default="",
|
||||
help=(
|
||||
"Comma-separated list of linked variables between serve and bench, "
|
||||
"e.g. max_num_seqs=max_concurrency,max_model_len=random_input_len"
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bench-params",
|
||||
type=str,
|
||||
@@ -413,7 +436,15 @@ class SweepServeArgs:
|
||||
"--output-dir",
|
||||
type=str,
|
||||
default="results",
|
||||
help="The directory to which results are written.",
|
||||
help="The main directory to which results are written.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-e",
|
||||
"--experiment-name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The name of this experiment (defaults to current timestamp). "
|
||||
"Results will be stored under `output_dir/experiment_name`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-runs",
|
||||
@@ -429,21 +460,10 @@ class SweepServeArgs:
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Set this to the name of a directory under `output_dir` (which is a "
|
||||
"timestamp) to resume a previous execution of this script, i.e., only run "
|
||||
"parameter combinations for which there are still no output files.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--link-vars",
|
||||
type=str,
|
||||
default="",
|
||||
help=(
|
||||
"Comma-separated list of linked variables between serve and bench, "
|
||||
"e.g. max_num_seqs=max_concurrency,max_model_len=random_input_len"
|
||||
),
|
||||
action="store_true",
|
||||
help="Resume a previous execution of this script, i.e., only run "
|
||||
"parameter combinations for which there are still no output files "
|
||||
"under `output_dir/experiment_name`.",
|
||||
)
|
||||
|
||||
return parser
|
||||
@@ -458,33 +478,52 @@ class SweepServeArgs:
|
||||
pairs.append((a.strip(), b.strip()))
|
||||
return pairs
|
||||
|
||||
def resolve_experiment_dir(self) -> Path:
|
||||
experiment_dir = self.output_dir / self.experiment_name
|
||||
|
||||
def run_main(args: SweepServeArgs):
|
||||
timestamp = args.resume or datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
output_dir = args.output_dir / timestamp
|
||||
if self.resume:
|
||||
if not experiment_dir.exists():
|
||||
raise ValueError(f"Cannot resume from non-existent {experiment_dir=}")
|
||||
else:
|
||||
if experiment_dir.exists():
|
||||
raise ValueError(f"Cannot overwrite existing {experiment_dir=}")
|
||||
|
||||
if args.resume and not output_dir.exists():
|
||||
raise ValueError(f"Cannot resume from non-existent directory ({output_dir})")
|
||||
return experiment_dir
|
||||
|
||||
@contextmanager
|
||||
def run_ctx(self, experiment_dir: Path):
|
||||
if self.dry_run:
|
||||
yield
|
||||
print(f"Experiment will be saved at: {experiment_dir}")
|
||||
return
|
||||
|
||||
try:
|
||||
yield
|
||||
print(f"Experiment has been saved at: {experiment_dir}")
|
||||
except BaseException as exc:
|
||||
raise RuntimeError(
|
||||
"The script was terminated early. Use `--resume` "
|
||||
"to continue the script from its last checkpoint."
|
||||
) from exc
|
||||
|
||||
|
||||
def run_main(args: SweepServeArgs):
|
||||
experiment_dir = args.resolve_experiment_dir()
|
||||
|
||||
with args.run_ctx(experiment_dir):
|
||||
return run_combs(
|
||||
serve_cmd=args.serve_cmd,
|
||||
bench_cmd=args.bench_cmd,
|
||||
link_vars=args.link_vars,
|
||||
after_bench_cmd=args.after_bench_cmd,
|
||||
show_stdout=args.show_stdout,
|
||||
server_ready_timeout=args.server_ready_timeout,
|
||||
serve_params=args.serve_params,
|
||||
bench_params=args.bench_params,
|
||||
output_dir=output_dir,
|
||||
experiment_dir=experiment_dir,
|
||||
num_runs=args.num_runs,
|
||||
dry_run=args.dry_run,
|
||||
link_vars=args.link_vars,
|
||||
)
|
||||
except BaseException as exc:
|
||||
raise RuntimeError(
|
||||
f"The script was terminated early. Use `--resume {timestamp}` "
|
||||
f"to continue the script from its last checkpoint."
|
||||
) from exc
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
|
||||
@@ -1,305 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
import math
|
||||
from dataclasses import asdict, dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import ClassVar, Literal, get_args
|
||||
|
||||
import numpy as np
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
from .param_sweep import ParameterSweep, ParameterSweepItem
|
||||
from .serve import (
|
||||
SweepServeArgs,
|
||||
_get_comb_base_path,
|
||||
run_comb,
|
||||
server_ctx,
|
||||
)
|
||||
from .server import ServerProcess
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
pd = PlaceholderModule("pandas")
|
||||
|
||||
|
||||
SLAVariable = Literal["request_rate", "max_concurrency"]
|
||||
|
||||
|
||||
def _estimate_sla_value(run_data: dict[str, object], sla_variable: SLAVariable):
|
||||
request_throughput = float(run_data["request_throughput"]) # type: ignore
|
||||
if sla_variable == "request_rate":
|
||||
return request_throughput
|
||||
if sla_variable == "max_concurrency":
|
||||
mean_latency_ms = float(run_data["mean_e2el_ms"]) # type: ignore
|
||||
return request_throughput * mean_latency_ms / 1000
|
||||
|
||||
assert_never(sla_variable)
|
||||
|
||||
|
||||
def _estimate_sla_avg(runs: list[dict[str, object]], sla_variable: SLAVariable):
|
||||
return sum(_estimate_sla_value(run, sla_variable) for run in runs) / len(runs)
|
||||
|
||||
|
||||
def run_comb_sla(
|
||||
server: ServerProcess | None,
|
||||
bench_cmd: list[str],
|
||||
*,
|
||||
serve_comb: ParameterSweepItem,
|
||||
bench_comb: ParameterSweepItem,
|
||||
output_dir: Path,
|
||||
num_runs: int,
|
||||
dry_run: bool,
|
||||
link_vars: list[tuple[str, str]],
|
||||
sla_variable: SLAVariable,
|
||||
sla_value: int,
|
||||
) -> list[dict[str, object]] | None:
|
||||
bench_comb_sla = bench_comb | {sla_variable: sla_value}
|
||||
|
||||
return run_comb(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb_sla,
|
||||
base_path=_get_comb_base_path(output_dir, serve_comb, bench_comb_sla),
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
link_vars=link_vars,
|
||||
)
|
||||
|
||||
|
||||
def explore_sla(
|
||||
server: ServerProcess | None,
|
||||
bench_cmd: list[str],
|
||||
*,
|
||||
serve_comb: ParameterSweepItem,
|
||||
bench_comb: ParameterSweepItem,
|
||||
sla_variable: SLAVariable,
|
||||
sla_iters: int,
|
||||
output_dir: Path,
|
||||
num_runs: int,
|
||||
dry_run: bool,
|
||||
link_vars: list[tuple[str, str]],
|
||||
):
|
||||
print("[SLA START]")
|
||||
print(f"Serve parameters: {serve_comb.as_text() or '(None)'}")
|
||||
print(f"Bench parameters: {bench_comb.as_text() or '(None)'}")
|
||||
print(f"Number of SLA iterations: {sla_iters}")
|
||||
|
||||
if sla_iters < 2:
|
||||
raise ValueError("`sla_iters` should be at least 2")
|
||||
|
||||
serial_comb_data = run_comb_sla(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb,
|
||||
output_dir=output_dir,
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
link_vars=link_vars,
|
||||
sla_variable=sla_variable,
|
||||
sla_value=1,
|
||||
)
|
||||
batch_comb_data = run_comb_sla(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb,
|
||||
output_dir=output_dir,
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
link_vars=link_vars,
|
||||
sla_variable=sla_variable,
|
||||
sla_value=int(bench_comb.get("num_prompts", 1000)), # type: ignore
|
||||
)
|
||||
|
||||
if serial_comb_data is None or batch_comb_data is None:
|
||||
if dry_run:
|
||||
print("Omitting intermediate SLA iterations.")
|
||||
print("[SLA END]")
|
||||
|
||||
return
|
||||
|
||||
serial_sla_value = math.ceil(_estimate_sla_avg(serial_comb_data, sla_variable))
|
||||
print(f"Serial inference: {sla_variable}={serial_sla_value}")
|
||||
|
||||
batch_sla_value = math.floor(_estimate_sla_avg(batch_comb_data, sla_variable))
|
||||
print(f"Batch inference: {sla_variable}={batch_sla_value}")
|
||||
|
||||
# Avoid duplicated runs for intermediate values if the range between
|
||||
# `serial_sla_value` and `batch_sla_value` is small
|
||||
inter_sla_values = np.linspace(serial_sla_value, batch_sla_value, sla_iters)[1:-1]
|
||||
inter_sla_values = sorted(set(map(round, inter_sla_values)))
|
||||
|
||||
inter_combs_data: list[dict[str, object]] = []
|
||||
for inter_sla_value in inter_sla_values:
|
||||
print(f"Exploring: {sla_variable}={inter_sla_value}")
|
||||
inter_comb_data = run_comb_sla(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb,
|
||||
output_dir=output_dir,
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
link_vars=link_vars,
|
||||
sla_variable=sla_variable,
|
||||
sla_value=inter_sla_value,
|
||||
)
|
||||
if inter_comb_data is not None:
|
||||
inter_combs_data.extend(inter_comb_data)
|
||||
|
||||
print("[SLA END]")
|
||||
|
||||
return serial_comb_data + inter_combs_data + batch_comb_data
|
||||
|
||||
|
||||
def run_slas(
|
||||
serve_cmd: list[str],
|
||||
bench_cmd: list[str],
|
||||
after_bench_cmd: list[str],
|
||||
*,
|
||||
show_stdout: bool,
|
||||
server_ready_timeout: int,
|
||||
serve_params: ParameterSweep,
|
||||
bench_params: ParameterSweep,
|
||||
sla_variable: SLAVariable,
|
||||
sla_iters: int,
|
||||
output_dir: Path,
|
||||
num_runs: int,
|
||||
dry_run: bool,
|
||||
link_vars: list[tuple[str, str]],
|
||||
):
|
||||
if any(bench_comb.has_param(sla_variable) for bench_comb in bench_params):
|
||||
raise ValueError(
|
||||
f"You should not override `{sla_variable}` in `bench_params` in SLA mode, "
|
||||
"since it is supposed to be determined automatically."
|
||||
)
|
||||
|
||||
all_data = list[dict[str, object]]()
|
||||
for serve_comb in serve_params:
|
||||
with server_ctx(
|
||||
serve_cmd,
|
||||
after_bench_cmd,
|
||||
show_stdout=show_stdout,
|
||||
server_ready_timeout=server_ready_timeout,
|
||||
serve_comb=serve_comb,
|
||||
bench_params=bench_params,
|
||||
output_dir=output_dir,
|
||||
dry_run=dry_run,
|
||||
) as server:
|
||||
for bench_comb in bench_params:
|
||||
comb_data = explore_sla(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb,
|
||||
sla_variable=sla_variable,
|
||||
sla_iters=sla_iters,
|
||||
output_dir=output_dir,
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
link_vars=link_vars,
|
||||
)
|
||||
|
||||
if comb_data is not None:
|
||||
all_data.extend(comb_data)
|
||||
|
||||
if dry_run:
|
||||
return None
|
||||
|
||||
combined_df = pd.DataFrame.from_records(all_data)
|
||||
combined_df.to_csv(output_dir / "summary.csv")
|
||||
|
||||
return combined_df
|
||||
|
||||
|
||||
@dataclass
|
||||
class SweepServeSLAArgs(SweepServeArgs):
|
||||
sla_variable: SLAVariable
|
||||
sla_iters: int
|
||||
|
||||
parser_name: ClassVar[str] = "serve_sla"
|
||||
parser_help: ClassVar[str] = (
|
||||
"Explore the latency-throughput space for determining SLAs."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
# NOTE: Don't use super() as `from_cli_args` calls `cls()`
|
||||
base_args = SweepServeArgs.from_cli_args(args)
|
||||
|
||||
return cls(
|
||||
**asdict(base_args),
|
||||
sla_variable=args.sla_variable,
|
||||
sla_iters=args.sla_iters,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
|
||||
parser = super().add_cli_args(parser)
|
||||
|
||||
sla_group = parser.add_argument_group("sla options")
|
||||
sla_group.add_argument(
|
||||
"--sla-variable",
|
||||
type=str,
|
||||
choices=get_args(SLAVariable),
|
||||
default="request_rate",
|
||||
help="The variable to adjust in each iteration.",
|
||||
)
|
||||
sla_group.add_argument(
|
||||
"--sla-iters",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of iterations used to explore the latency-throughput space. "
|
||||
"This includes the first two iterations used to interpolate the value of "
|
||||
"`sla_variable` for remaining iterations.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def run_main(args: SweepServeSLAArgs):
|
||||
timestamp = args.resume or datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
output_dir = args.output_dir / timestamp
|
||||
|
||||
if args.resume and not output_dir.exists():
|
||||
raise ValueError(f"Cannot resume from non-existent directory ({output_dir})")
|
||||
|
||||
try:
|
||||
return run_slas(
|
||||
serve_cmd=args.serve_cmd,
|
||||
bench_cmd=args.bench_cmd,
|
||||
after_bench_cmd=args.after_bench_cmd,
|
||||
show_stdout=args.show_stdout,
|
||||
server_ready_timeout=args.server_ready_timeout,
|
||||
serve_params=args.serve_params,
|
||||
bench_params=args.bench_params,
|
||||
sla_variable=args.sla_variable,
|
||||
sla_iters=args.sla_iters,
|
||||
output_dir=output_dir,
|
||||
num_runs=args.num_runs,
|
||||
dry_run=args.dry_run,
|
||||
link_vars=args.link_vars,
|
||||
)
|
||||
except BaseException as exc:
|
||||
raise RuntimeError(
|
||||
f"The script was terminated early. Use `--resume {timestamp}` "
|
||||
f"to continue the script from its last checkpoint."
|
||||
) from exc
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
run_main(SweepServeSLAArgs.from_cli_args(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description=SweepServeSLAArgs.parser_help)
|
||||
SweepServeSLAArgs.add_cli_args(parser)
|
||||
|
||||
main(parser.parse_args())
|
||||
328 vllm/benchmarks/sweep/serve_workload.py (new file)
@@ -0,0 +1,328 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
import math
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
from typing import ClassVar, Literal, get_args
|
||||
|
||||
import numpy as np
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.benchmarks.datasets import DEFAULT_NUM_PROMPTS
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
from .param_sweep import ParameterSweep, ParameterSweepItem
|
||||
from .serve import (
|
||||
SweepServeArgs,
|
||||
_get_comb_base_path,
|
||||
run_comb,
|
||||
server_ctx,
|
||||
)
|
||||
from .server import ServerProcess
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
pd = PlaceholderModule("pandas")
|
||||
|
||||
|
||||
WorkloadVariable = Literal["request_rate", "max_concurrency"]
|
||||
|
||||
|
||||
def _estimate_workload_value(
|
||||
run_data: dict[str, object],
|
||||
workload_var: WorkloadVariable,
|
||||
):
|
||||
request_throughput = float(run_data["request_throughput"]) # type: ignore
|
||||
if workload_var == "request_rate":
|
||||
return request_throughput
|
||||
if workload_var == "max_concurrency":
|
||||
mean_latency_ms = float(run_data["mean_e2el_ms"]) # type: ignore
|
||||
return request_throughput * mean_latency_ms / 1000
|
||||
|
||||
assert_never(workload_var)
|
||||
|
||||
|
||||
def _estimate_workload_avg(
|
||||
runs: list[dict[str, object]],
|
||||
workload_var: WorkloadVariable,
|
||||
):
|
||||
total = sum(_estimate_workload_value(run, workload_var) for run in runs)
|
||||
return total / len(runs)
|
||||
|
||||
|
||||
def run_comb_workload(
|
||||
server: ServerProcess | None,
|
||||
bench_cmd: list[str],
|
||||
*,
|
||||
serve_comb: ParameterSweepItem,
|
||||
bench_comb: ParameterSweepItem,
|
||||
link_vars: list[tuple[str, str]],
|
||||
experiment_dir: Path,
|
||||
num_runs: int,
|
||||
dry_run: bool,
|
||||
workload_var: WorkloadVariable,
|
||||
workload_value: int,
|
||||
) -> list[dict[str, object]] | None:
|
||||
bench_comb_workload = bench_comb | {workload_var: workload_value}
|
||||
|
||||
return run_comb(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb_workload,
|
||||
link_vars=link_vars,
|
||||
base_path=_get_comb_base_path(
|
||||
experiment_dir,
|
||||
serve_comb,
|
||||
bench_comb,
|
||||
extra_parts=("WL-", f"{workload_var}={workload_value}"),
|
||||
),
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
|
||||
|
||||
def explore_comb_workloads(
|
||||
server: ServerProcess | None,
|
||||
bench_cmd: list[str],
|
||||
*,
|
||||
serve_comb: ParameterSweepItem,
|
||||
bench_comb: ParameterSweepItem,
|
||||
link_vars: list[tuple[str, str]],
|
||||
workload_var: WorkloadVariable,
|
||||
workload_iters: int,
|
||||
experiment_dir: Path,
|
||||
num_runs: int,
|
||||
dry_run: bool,
|
||||
):
|
||||
print("[WL START]")
|
||||
print(f"Serve parameters: {serve_comb.as_text() or '(None)'}")
|
||||
print(f"Bench parameters: {bench_comb.as_text() or '(None)'}")
|
||||
print(f"Number of workload iterations: {workload_iters}")
|
||||
|
||||
if workload_iters < 2:
|
||||
raise ValueError("`workload_iters` should be at least 2")
|
||||
|
||||
dataset_size = DEFAULT_NUM_PROMPTS
|
||||
if "num_prompts" in bench_comb:
|
||||
dataset_size = int(bench_comb["num_prompts"]) # type: ignore
|
||||
else:
|
||||
for i, arg in enumerate(bench_cmd):
|
||||
if arg == "--num-prompts" and i + 1 < len(bench_cmd):
|
||||
dataset_size = int(bench_cmd[i + 1])
|
||||
break
|
||||
elif arg.startswith("--num-prompts="):
|
||||
dataset_size = int(arg.split("=", 1)[1])
|
||||
break
|
||||
|
||||
print(f"Dataset size: {dataset_size}")
|
||||
|
||||
serial_workload_data = run_comb_workload(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb | {"max_concurrency": 1},
|
||||
link_vars=link_vars,
|
||||
experiment_dir=experiment_dir,
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
workload_var=workload_var,
|
||||
workload_value=1,
|
||||
)
|
||||
batch_workload_data = run_comb_workload(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb | {"max_concurrency": dataset_size},
|
||||
link_vars=link_vars,
|
||||
experiment_dir=experiment_dir,
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
workload_var=workload_var,
|
||||
workload_value=dataset_size,
|
||||
)
|
||||
|
||||
if serial_workload_data is None or batch_workload_data is None:
|
||||
if dry_run:
|
||||
print("Omitting intermediate Workload iterations.")
|
||||
print("[WL END]")
|
||||
|
||||
return
|
||||
|
||||
serial_workload_value = math.ceil(
|
||||
_estimate_workload_avg(serial_workload_data, workload_var)
|
||||
)
|
||||
print(f"Serial inference: {workload_var}={serial_workload_value}")
|
||||
|
||||
batch_workload_value = math.floor(
|
||||
_estimate_workload_avg(batch_workload_data, workload_var)
|
||||
)
|
||||
print(f"Batch inference: {workload_var}={batch_workload_value}")
|
||||
|
||||
# Avoid duplicated runs for intermediate values if the range between
|
||||
# `serial_workload_value` and `batch_workload_value` is small
|
||||
inter_workload_values = np.linspace(
|
||||
serial_workload_value, batch_workload_value, workload_iters
|
||||
)[1:-1]
|
||||
inter_workload_values = sorted(set(map(round, inter_workload_values)))
|
||||
|
||||
inter_workloads_data: list[dict[str, object]] = []
|
||||
for inter_workload_value in inter_workload_values:
|
||||
print(f"Exploring: {workload_var}={inter_workload_value}")
|
||||
inter_workload_data = run_comb_workload(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb,
|
||||
link_vars=link_vars,
|
||||
experiment_dir=experiment_dir,
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
workload_var=workload_var,
|
||||
workload_value=inter_workload_value,
|
||||
)
|
||||
if inter_workload_data is not None:
|
||||
inter_workloads_data.extend(inter_workload_data)
|
||||
|
||||
print("[WL END]")
|
||||
|
||||
return serial_workload_data + inter_workloads_data + batch_workload_data
|
||||
|
||||
|
||||
def explore_combs_workloads(
|
||||
serve_cmd: list[str],
|
||||
bench_cmd: list[str],
|
||||
after_bench_cmd: list[str],
|
||||
*,
|
||||
show_stdout: bool,
|
||||
server_ready_timeout: int,
|
||||
serve_params: ParameterSweep,
|
||||
bench_params: ParameterSweep,
|
||||
link_vars: list[tuple[str, str]],
|
||||
workload_var: WorkloadVariable,
|
||||
workload_iters: int,
|
||||
experiment_dir: Path,
|
||||
num_runs: int,
|
||||
dry_run: bool,
|
||||
):
|
||||
if any(bench_comb.has_param(workload_var) for bench_comb in bench_params):
|
||||
raise ValueError(
|
||||
f"You should not override `{workload_var}` in `bench_params` "
|
||||
"since it is supposed to be explored automatically."
|
||||
)
|
||||
|
||||
all_data = list[dict[str, object]]()
|
||||
for serve_comb in serve_params:
|
||||
with server_ctx(
|
||||
serve_cmd,
|
||||
after_bench_cmd,
|
||||
show_stdout=show_stdout,
|
||||
server_ready_timeout=server_ready_timeout,
|
||||
serve_comb=serve_comb,
|
||||
bench_params=bench_params,
|
||||
experiment_dir=experiment_dir,
|
||||
dry_run=dry_run,
|
||||
) as server:
|
||||
for bench_comb in bench_params:
|
||||
comb_data = explore_comb_workloads(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb,
|
||||
link_vars=link_vars,
|
||||
workload_var=workload_var,
|
||||
workload_iters=workload_iters,
|
||||
experiment_dir=experiment_dir,
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
|
||||
if comb_data is not None:
|
||||
all_data.extend(comb_data)
|
||||
|
||||
if dry_run:
|
||||
return None
|
||||
|
||||
combined_df = pd.DataFrame.from_records(all_data)
|
||||
combined_df.to_csv(experiment_dir / "summary.csv")
|
||||
|
||||
return combined_df
|
||||
|
||||
|
||||
@dataclass
|
||||
class SweepServeWorkloadArgs(SweepServeArgs):
|
||||
workload_var: WorkloadVariable
|
||||
workload_iters: int
|
||||
|
||||
parser_name: ClassVar[str] = "serve_workload"
|
||||
parser_help: ClassVar[str] = (
|
||||
"Explore the latency-throughput tradeoff for different workload levels."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
# NOTE: Don't use super() as `from_cli_args` calls `cls()`
|
||||
base_args = SweepServeArgs.from_cli_args(args)
|
||||
|
||||
return cls(
|
||||
**asdict(base_args),
|
||||
workload_var=args.workload_var,
|
||||
workload_iters=args.workload_iters,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
|
||||
parser = super().add_cli_args(parser)
|
||||
|
||||
workload_group = parser.add_argument_group("workload options")
|
||||
workload_group.add_argument(
|
||||
"--workload-var",
|
||||
type=str,
|
||||
choices=get_args(WorkloadVariable),
|
||||
default="request_rate",
|
||||
help="The variable to adjust in each iteration.",
|
||||
)
|
||||
workload_group.add_argument(
|
||||
"--workload-iters",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of workload levels to explore. "
|
||||
"This includes the first two iterations used to interpolate the value of "
|
||||
"`workload_var` for remaining iterations.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def run_main(args: SweepServeWorkloadArgs):
|
||||
experiment_dir = args.resolve_experiment_dir()
|
||||
|
||||
with args.run_ctx(experiment_dir):
|
||||
return explore_combs_workloads(
|
||||
serve_cmd=args.serve_cmd,
|
||||
bench_cmd=args.bench_cmd,
|
||||
after_bench_cmd=args.after_bench_cmd,
|
||||
show_stdout=args.show_stdout,
|
||||
server_ready_timeout=args.server_ready_timeout,
|
||||
serve_params=args.serve_params,
|
||||
bench_params=args.bench_params,
|
||||
link_vars=args.link_vars,
|
||||
workload_var=args.workload_var,
|
||||
workload_iters=args.workload_iters,
|
||||
experiment_dir=experiment_dir,
|
||||
num_runs=args.num_runs,
|
||||
dry_run=args.dry_run,
|
||||
)
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
run_main(SweepServeWorkloadArgs.from_cli_args(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description=SweepServeWorkloadArgs.parser_help)
|
||||
SweepServeWorkloadArgs.add_cli_args(parser)
|
||||
|
||||
main(parser.parse_args())
|
||||
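For reference (not part of the diff), a small sketch of the interpolation step that explore_comb_workloads uses to pick intermediate workload values between the serial and batch measurements; the bounds below are hypothetical.

import numpy as np

# Hypothetical bounds: ceil/floor of the averaged estimates from the
# max_concurrency=1 run and the full-concurrency run.
serial_value, batch_value, workload_iters = 3, 40, 6

inter_values = np.linspace(serial_value, batch_value, workload_iters)[1:-1]
inter_values = sorted(set(map(round, inter_values)))
print(inter_values)  # e.g. [10, 18, 25, 33] -> intermediate request_rate values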
@@ -4,6 +4,7 @@ import argparse
|
||||
import json
|
||||
import shlex
|
||||
import subprocess
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from functools import lru_cache
|
||||
@@ -111,7 +112,7 @@ def _apply_output_json(cmd: list[str], output_path: Path) -> list[str]:
|
||||
|
||||
|
||||
def _get_comb_base_path(
|
||||
output_dir: Path,
|
||||
experiment_dir: Path,
|
||||
serve_comb: ParameterSweepItem,
|
||||
startup_comb: ParameterSweepItem,
|
||||
) -> Path:
|
||||
@@ -120,7 +121,8 @@ def _get_comb_base_path(
|
||||
parts.extend(("SERVE-", serve_comb.name))
|
||||
if startup_comb:
|
||||
parts.extend(("STARTUP-", startup_comb.name))
|
||||
return output_dir / sanitize_filename("-".join(parts))
|
||||
|
||||
return experiment_dir / sanitize_filename("-".join(parts))
|
||||
|
||||
|
||||
def _get_comb_run_path(base_path: Path, run_number: int | None) -> Path:
|
||||
@@ -225,7 +227,7 @@ def run_combs(
|
||||
*,
|
||||
serve_params: ParameterSweep,
|
||||
startup_params: ParameterSweep,
|
||||
output_dir: Path,
|
||||
experiment_dir: Path,
|
||||
num_runs: int,
|
||||
show_stdout: bool,
|
||||
dry_run: bool,
|
||||
@@ -233,7 +235,7 @@ def run_combs(
|
||||
all_data = list[dict[str, object]]()
|
||||
for serve_comb in serve_params:
|
||||
for startup_comb in startup_params:
|
||||
base_path = _get_comb_base_path(output_dir, serve_comb, startup_comb)
|
||||
base_path = _get_comb_base_path(experiment_dir, serve_comb, startup_comb)
|
||||
comb_data = run_comb(
|
||||
startup_cmd,
|
||||
serve_comb=serve_comb,
|
||||
@@ -250,7 +252,7 @@ def run_combs(
|
||||
return None
|
||||
|
||||
combined_df = pd.DataFrame.from_records(all_data)
|
||||
combined_df.to_csv(output_dir / "summary.csv")
|
||||
combined_df.to_csv(experiment_dir / "summary.csv")
|
||||
return combined_df
|
||||
|
||||
|
||||
@@ -260,11 +262,11 @@ class SweepStartupArgs:
|
||||
serve_params: ParameterSweep
|
||||
startup_params: ParameterSweep
|
||||
output_dir: Path
|
||||
experiment_name: str
|
||||
num_runs: int
|
||||
show_stdout: bool
|
||||
dry_run: bool
|
||||
resume: str | None
|
||||
strict_params: bool
|
||||
resume: bool
|
||||
|
||||
parser_name: ClassVar[str] = "startup"
|
||||
parser_help: ClassVar[str] = (
|
||||
@@ -286,13 +288,19 @@ class SweepStartupArgs:
|
||||
startup_params = ParameterSweep.from_records([{}])
|
||||
|
||||
supported = _get_supported_startup_keys()
|
||||
strict_params = args.strict_params
|
||||
serve_params = _filter_params(
|
||||
serve_params, supported=supported, strict=args.strict_params
|
||||
serve_params, supported=supported, strict=strict_params
|
||||
)
|
||||
startup_params = _filter_params(
|
||||
startup_params, supported=supported, strict=args.strict_params
|
||||
startup_params, supported=supported, strict=strict_params
|
||||
)
|
||||
|
||||
if args.experiment_name:
|
||||
experiment_name = args.experiment_name
|
||||
else:
|
||||
experiment_name = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
if args.num_runs < 1:
|
||||
raise ValueError("`num_runs` should be at least 1.")
|
||||
|
||||
@@ -301,11 +309,11 @@ class SweepStartupArgs:
|
||||
serve_params=serve_params,
|
||||
startup_params=startup_params,
|
||||
output_dir=Path(args.output_dir),
|
||||
experiment_name=experiment_name,
|
||||
num_runs=args.num_runs,
|
||||
show_stdout=args.show_stdout,
|
||||
dry_run=args.dry_run,
|
||||
resume=args.resume,
|
||||
strict_params=args.strict_params,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -316,6 +324,7 @@ class SweepStartupArgs:
|
||||
default="vllm bench startup",
|
||||
help="The command used to run the startup benchmark.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--serve-params",
|
||||
type=str,
|
||||
@@ -331,12 +340,27 @@ class SweepStartupArgs:
|
||||
help="Path to JSON file containing parameter combinations "
|
||||
"for the `vllm bench startup` command.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--strict-params",
|
||||
action="store_true",
|
||||
help="If set, unknown parameters in sweep files raise an error "
|
||||
"instead of being ignored.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--output-dir",
|
||||
type=str,
|
||||
default="results",
|
||||
help="The directory to which results are written.",
|
||||
help="The main directory to which results are written.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-e",
|
||||
"--experiment-name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The name of this experiment (defaults to current timestamp). "
|
||||
"Results will be stored under `output_dir/experiment_name`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-runs",
|
||||
@@ -357,43 +381,56 @@ class SweepStartupArgs:
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Set this to the name of a directory under `output_dir` (which is a "
|
||||
"timestamp) to resume a previous execution of this script, i.e., only run "
|
||||
"parameter combinations for which there are still no output files.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--strict-params",
|
||||
action="store_true",
|
||||
help="If set, unknown parameters in sweep files raise an error "
|
||||
"instead of being ignored.",
|
||||
help="Resume a previous execution of this script, i.e., only run "
|
||||
"parameter combinations for which there are still no output files "
|
||||
"under `output_dir/experiment_name`.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
def resolve_experiment_dir(self) -> Path:
|
||||
experiment_dir = self.output_dir / self.experiment_name
|
||||
|
||||
if self.resume:
|
||||
if not experiment_dir.exists():
|
||||
raise ValueError(f"Cannot resume from non-existent {experiment_dir=}")
|
||||
else:
|
||||
if experiment_dir.exists():
|
||||
raise ValueError(f"Cannot overwrite existing {experiment_dir=}")
|
||||
|
||||
return experiment_dir
|
||||
|
||||
@contextmanager
|
||||
def run_ctx(self, experiment_dir: Path):
|
||||
if self.dry_run:
|
||||
yield
|
||||
print(f"Experiment will be saved at: {experiment_dir}")
|
||||
return
|
||||
|
||||
try:
|
||||
yield
|
||||
print(f"Experiment has been saved at: {experiment_dir}")
|
||||
except BaseException as exc:
|
||||
raise RuntimeError(
|
||||
"The script was terminated early. Use `--resume` "
|
||||
"to continue the script from its last checkpoint."
|
||||
) from exc
|
||||
|
||||
|
||||
def run_main(args: SweepStartupArgs):
|
||||
timestamp = args.resume or datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
output_dir = args.output_dir / timestamp
|
||||
experiment_dir = args.resolve_experiment_dir()
|
||||
|
||||
if args.resume and not output_dir.exists():
|
||||
raise ValueError(f"Cannot resume from non-existent directory ({output_dir})")
|
||||
|
||||
try:
|
||||
with args.run_ctx(experiment_dir):
|
||||
return run_combs(
|
||||
startup_cmd=args.startup_cmd,
|
||||
serve_params=args.serve_params,
|
||||
startup_params=args.startup_params,
|
||||
output_dir=output_dir,
|
||||
experiment_dir=experiment_dir,
|
||||
num_runs=args.num_runs,
|
||||
show_stdout=args.show_stdout,
|
||||
dry_run=args.dry_run,
|
||||
)
|
||||
except BaseException as exc:
|
||||
raise RuntimeError(
|
||||
f"The script was terminated early. Use `--resume {timestamp}` "
|
||||
f"to continue the script from its last checkpoint."
|
||||
) from exc
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
|
||||
@@ -282,49 +282,6 @@ class CompilerManager:
|
||||
maybe_key += f"{compile_range.start}_{compile_range.end}"
|
||||
maybe_key += f"_subgraph_{graph_index}"
|
||||
with self.compile_context(compile_range):
|
||||
# There is a compilation time optimization here.
|
||||
#
|
||||
# If the (input metadata, graph, compiler config) are the same, then
|
||||
# we want to avoid compiling the same artifact again. If we didn't
|
||||
# do this optimization, the backend compilation (InductorAdaptor or
|
||||
# InductorStandaloneAdaptor)
|
||||
# is able to cache hit and produce an artifact faster if it was
|
||||
# already created, but it is still a duplicate artifact that
|
||||
# requires unnecessary things e.g. disk IO.
|
||||
#
|
||||
# The optimization is: If the backend compilation cache hits,
|
||||
# then do an early return from the backend compilation and look up
|
||||
# which of the previous in-memory artifacts we created to reuse.
|
||||
#
|
||||
# We implemented this by monkey-patching torch (torch does not
|
||||
# easily expose the cache_key function), but in the future torch
|
||||
# should expose the cache_key function that we can just call
|
||||
# directly before invoking backend compilation.
|
||||
cache_key = None
|
||||
orig = torch._functorch._aot_autograd.autograd_cache.autograd_cache_key
|
||||
|
||||
def autograd_cache_key(*args, **kwargs):
|
||||
result = orig(*args, **kwargs)
|
||||
if result is None:
|
||||
return None
|
||||
nonlocal cache_key
|
||||
cache_key = result[0]
|
||||
if cache_key in self.loaded_artifacts:
|
||||
raise StopCompiling()
|
||||
return result
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
with (
|
||||
# Graphs that are isometric (different node names but same
|
||||
# structure) should be treated as the same.
|
||||
torch._functorch.config.patch(autograd_cache_normalize_inputs=True),
|
||||
patch(
|
||||
"torch._functorch._aot_autograd.autograd_cache.autograd_cache_key",
|
||||
autograd_cache_key,
|
||||
),
|
||||
):
|
||||
try:
|
||||
compiled_graph, handle = self.compiler.compile(
|
||||
graph,
|
||||
example_inputs,
|
||||
@@ -332,11 +289,6 @@ class CompilerManager:
|
||||
compile_range,
|
||||
maybe_key,
|
||||
)
|
||||
except StopCompiling:
|
||||
assert cache_key is not None
|
||||
return self.loaded_artifacts[cache_key]
|
||||
if cache_key is not None and compiled_graph is not None:
|
||||
self.loaded_artifacts[cache_key] = compiled_graph
|
||||
|
||||
assert compiled_graph is not None, "Failed to compile the graph"
|
||||
|
||||
@@ -497,7 +449,7 @@ def wrap_with_cudagraph_if_needed(
|
||||
# it from the FULL cudagraph runtime mode, no matter whether it
# is wrapped on a full or piecewise fx graph.
|
||||
return static_graph_wrapper_class(
|
||||
runnable=piecewise_backend,
|
||||
runnable=piecewise_backend.graph.forward,
|
||||
vllm_config=vllm_config,
|
||||
runtime_mode=CUDAGraphMode.PIECEWISE,
|
||||
cudagraph_options=CUDAGraphOptions(
|
||||
@@ -780,7 +732,7 @@ class VllmBackend:
|
||||
return standalone_compile_artifacts, sym_shape_indices_map, returns_tuple_map
|
||||
|
||||
def configure_post_pass(self) -> None:
|
||||
# self.pass_manager.configure(self.vllm_config)
|
||||
self.pass_manager.configure(self.vllm_config)
|
||||
|
||||
# Post-grad custom passes are run using the post_grad_custom_post_pass
|
||||
# hook. If a pass for that hook exists, add it to the pass manager.
|
||||
@@ -846,7 +798,7 @@ class VllmBackend:
|
||||
),
|
||||
)
|
||||
|
||||
def __call__(self, graph: fx.GraphModule, example_inputs: Sequence[Any]) -> Any:
|
||||
def __call__(self, graph: fx.GraphModule, example_inputs: Sequence[Any], **kwargs) -> Any:
|
||||
from .caching import (
|
||||
VllmSerializableFunction,
|
||||
)
|
||||
@@ -988,7 +940,7 @@ class VllmBackend:
|
||||
assert not self._called, "VllmBackend can only be called once"
|
||||
|
||||
self.graph = graph
|
||||
self.configure_post_pass()
|
||||
# self.configure_post_pass()
|
||||
|
||||
if self.compilation_config.use_inductor_graph_partition:
|
||||
# Let Inductor decide partitioning; avoid FX-level pre-splitting.
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import contextlib
|
||||
import hashlib
|
||||
import inspect
|
||||
import os
|
||||
@@ -144,6 +145,18 @@ class StandaloneCompiledArtifacts:
        self.loaded_submodule_store = {}


@contextlib.contextmanager
def patch_pytree_map_over_slice():
    pytree._private_register_pytree_node(
        slice, lambda x: ([x.start, x.stop, x.step], None), lambda x, c: slice(*x)
    )

    try:
        yield
    finally:
        pytree._deregister_pytree_node(slice)

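For context on `patch_pytree_map_over_slice`: pytree utilities treat unregistered types as opaque leaves, so a `slice` inside the pickled example inputs would not be traversed. Temporarily registering it lets the pickler map over `start`/`stop`/`step`. A minimal sketch of the same idea using torch's public registration API (illustrative only, not part of the diff; the private `_private_register_pytree_node` used above behaves similarly):

```python
import torch.utils._pytree as pytree

# Register `slice` so tree_map recurses into its fields instead of treating
# the whole object as a single leaf.
pytree.register_pytree_node(
    slice,
    lambda s: ([s.start, s.stop, s.step], None),  # flatten -> (children, context)
    lambda children, _ctx: slice(*children),      # unflatten
)

doubled = pytree.tree_map(
    lambda x: x if x is None else x * 2,
    {"rows": slice(1, 4, 1)},
)
print(doubled)  # {'rows': slice(2, 8, 2)}
```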
class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
|
||||
"""
|
||||
A wrapper around a compiled function by vllm. It will forward the tensor
|
||||
@@ -235,7 +248,10 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
|
||||
lambda inp: torch.empty_like(inp, device="meta"),
|
||||
state["example_inputs"],
|
||||
)
|
||||
with patch.object(GraphPickler, "reducer_override", _graph_reducer_override):
|
||||
with (
|
||||
patch.object(GraphPickler, "reducer_override", _graph_reducer_override),
|
||||
patch_pytree_map_over_slice(),
|
||||
):
|
||||
state["graph_module"] = GraphPickler.dumps(
|
||||
state["graph_module"], Options(ops_filter=None)
|
||||
)
|
||||
@@ -261,6 +277,7 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
|
||||
|
||||
state = pickle.loads(data)
|
||||
fake_mode = FakeTensorMode(shape_env=ShapeEnv())
|
||||
with patch_pytree_map_over_slice():
|
||||
state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode)
|
||||
state["graph_module"].recompile()
|
||||
state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode)
|
||||
|
||||
@@ -184,6 +184,47 @@ def is_compile_cache_enabled(
|
||||
)
|
||||
|
||||
|
||||
def _patch_standalone_compile_atomic_save() -> None:
    """Backport of pytorch/pytorch#162432 for torch < 2.10.0.

    Patches CompiledArtifact.save() to use write_atomic for binary format,
    preventing corrupt cache files when multiple processes compile
    concurrently.
    """
    from torch._inductor.codecache import write_atomic
    from torch._inductor.standalone_compile import CompiledArtifact as cls

    if getattr(cls.save, "_vllm_patched", False):
        return

    original_save = cls.save

    def _save(
        self: Any, *, path: str, format: Literal["binary", "unpacked"] = "binary"
    ) -> None:
        if format != "binary":
            return original_save(self, path=path, format=format)
        from torch._dynamo.utils import dynamo_timed
        from torch._inductor.codecache import torch_key
        from torch.utils._appending_byte_serializer import BytesWriter

        with dynamo_timed("CompiledArtifact.save"):
            assert self._artifacts is not None
            artifact_bytes, cache_info = self._artifacts
            assert len(cache_info.aot_autograd_artifacts) == 1, cache_info
            key = cache_info.aot_autograd_artifacts[0]
            assert not os.path.isdir(path)
            writer = BytesWriter()
            writer.write_bytes(torch_key())
            writer.write_str(key)
            writer.write_bytes(artifact_bytes)
            write_atomic(path, writer.to_bytes())

    _save._vllm_patched = True  # type: ignore[attr-defined]
    cls.save = _save  # type: ignore[assignment]
    logger.debug("Patched %s.save for atomic writes (torch < 2.10)", cls.__name__)

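The backport above matters because a plain `open(path, "wb").write(...)` can be observed half-written by a concurrent reader. `write_atomic` follows the usual write-to-a-temporary-file-then-rename pattern; a generic sketch of that pattern (illustrative only, not the torch implementation):

```python
import os
import tempfile


def write_atomic_sketch(path: str, data: bytes) -> None:
    # Write to a unique temp file in the same directory, then atomically
    # replace the destination; readers see either the old or the new file,
    # never a partial write.
    dir_name = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=dir_name)
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(data)
        os.replace(tmp_path, path)
    except BaseException:
        os.unlink(tmp_path)
        raise
```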
class InductorStandaloneAdaptor(CompilerInterface):
|
||||
"""
|
||||
The adaptor for the Inductor compiler.
|
||||
@@ -197,6 +238,8 @@ class InductorStandaloneAdaptor(CompilerInterface):
|
||||
name = "inductor_standalone"
|
||||
|
||||
def __init__(self, save_format: Literal["binary", "unpacked"]) -> None:
|
||||
if not is_torch_equal_or_newer("2.10.0"):
|
||||
_patch_standalone_compile_atomic_save()
|
||||
self.save_format = save_format
|
||||
|
||||
def compute_hash(self, vllm_config: VllmConfig) -> str:
|
||||
@@ -224,7 +267,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
|
||||
if compiler_config is not None:
|
||||
current_config.update(compiler_config)
|
||||
set_inductor_config(current_config, compile_range)
|
||||
set_functorch_config()
|
||||
# set_functorch_config()
|
||||
|
||||
if compile_range.is_single_size():
|
||||
dynamic_shapes = "from_example_inputs"
|
||||
@@ -325,6 +368,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
|
||||
inductor_compiled_graph = torch._inductor.CompiledArtifact.load(
|
||||
path=path, format=self.save_format
|
||||
)
|
||||
compilation_counter.num_compiled_artifacts_loaded += 1
|
||||
from torch._inductor.compile_fx import graph_returns_tuple
|
||||
|
||||
returns_tuple = graph_returns_tuple(graph)
|
||||
@@ -395,7 +439,7 @@ class InductorAdaptor(CompilerInterface):
|
||||
current_config["fx_graph_remote_cache"] = False
|
||||
|
||||
set_inductor_config(current_config, compile_range)
|
||||
set_functorch_config()
|
||||
# set_functorch_config()
|
||||
|
||||
# inductor can inplace modify the graph, so we need to copy it
|
||||
# see https://github.com/pytorch/pytorch/issues/138980
|
||||
|
||||
@@ -29,6 +29,8 @@ class CompilationCounter:
|
||||
num_cache_entries_updated: int = 0
|
||||
# The number of standalone_compile compiled artifacts saved
|
||||
num_compiled_artifacts_saved: int = 0
|
||||
# The number of standalone_compile compiled artifacts loaded from cache
|
||||
num_compiled_artifacts_loaded: int = 0
|
||||
# Number of times a model was loaded with CompilationMode.STOCK_TORCH_COMPILE
|
||||
stock_torch_compile_count: int = 0
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ from vllm.model_executor.offloader.base import get_offloader
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import current_stream, weak_ref_tensors
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@@ -298,12 +299,10 @@ class CUDAGraphWrapper:
# the last graph in piecewise cudagraph mode, because
# the output of the last graph will not be used by
# any other cuda graph.
# output = weak_ref_tensors(output)
output = self.weak_ref_tensors_with_intermediate(output)

# here we always use weak ref for the output
# to save memory
# entry.output = weak_ref_tensors(output)
entry.output = self.weak_ref_tensors_with_intermediate(output)
entry.cudagraph = cudagraph

@@ -53,7 +53,7 @@ class GEMMReduceScatterPattern(BasePattern):
|
||||
gemm_rs = torch.ops.symm_mem.fused_matmul_reduce_scatter(
|
||||
mul,
|
||||
mm_weight,
|
||||
"avg",
|
||||
"sum",
|
||||
scatter_dim=0,
|
||||
group_name=self.tp.device_group.group_name,
|
||||
)
|
||||
@@ -150,7 +150,7 @@ class ScaledMMReduceScatterPattern(BasePattern):
|
||||
mat2,
|
||||
scale_a,
|
||||
scale_b,
|
||||
"avg",
|
||||
"sum",
|
||||
scatter_dim, # orig_scatter_dim
|
||||
scatter_dim, # scatter_dim_after_maybe_reshape
|
||||
self.tp.device_group.group_name,
|
||||
@@ -285,7 +285,7 @@ class CutlassScaledMMReduceScatterPattern(BasePattern):
|
||||
mat2,
|
||||
scale_a,
|
||||
scale_b,
|
||||
"avg",
|
||||
"sum",
|
||||
scatter_dim, # orig_scatter_dim
|
||||
scatter_dim, # scatter_dim_after_maybe_reshape
|
||||
self.tp.device_group.group_name,
|
||||
|
||||
@@ -5,7 +5,6 @@ import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
from torch import fx
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
from torch._ops import OpOverload
|
||||
|
||||
import vllm.model_executor.layers.quantization.utils.fp8_utils # noqa: F401
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
@@ -15,6 +14,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
QuantKey,
|
||||
ScaleDesc,
|
||||
kFp8Dynamic128Sym,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@@ -312,7 +312,9 @@ class RocmAiterRMSNormQuantFusionPass(VllmPatternMatcherPass):
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph) -> None:
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", self.matched_count)
|
||||
logger.debug(
|
||||
"%s Replaced %s patterns", self.__class__.__name__, self.matched_count
|
||||
)
|
||||
|
||||
def uuid(self) -> str:
|
||||
fusion_patterns = [
|
||||
@@ -332,9 +334,11 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
|
||||
|
||||
FUSED_SILU_MUL_QUANT_OP = rocm_aiter_ops.get_act_mul_fused_fp8_group_quant_op()
|
||||
|
||||
def __init__(self, quant_op: OpOverload) -> None:
|
||||
def __init__(self) -> None:
|
||||
self.silu_and_mul_matcher = MatcherSiluAndMul()
|
||||
self.quant_op = quant_op
|
||||
self.quant_matcher = MatcherQuantFP8(
|
||||
quant_key=kFp8Dynamic128Sym, match_rocm_aiter=True
|
||||
)
|
||||
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
return [
|
||||
@@ -346,7 +350,7 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
|
||||
input: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
at1 = self.silu_and_mul_matcher(input)
|
||||
at2 = self.quant_op(at1, 128)
|
||||
at2 = self.quant_matcher(at1)
|
||||
return at2[0], at2[1]
|
||||
|
||||
def replacement(
|
||||
@@ -370,11 +374,6 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
|
||||
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
|
||||
"""
|
||||
|
||||
AITER_GROUP_FP8_QUANT_OP = rocm_aiter_ops.get_group_quant_op()
|
||||
TRITON_GROUP_FP8_QUANT_OP = torch.ops.vllm.triton_per_token_group_quant_fp8.default
|
||||
|
||||
QUANT_OPS = [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]
|
||||
|
||||
@enable_fake_mode
|
||||
def __init__(self, config: VllmConfig) -> None:
|
||||
super().__init__(config)
|
||||
@@ -383,8 +382,7 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
|
||||
pass_name="rocm_aiter_silu_mul_fp8_group_quant_fusion_pass"
|
||||
)
|
||||
|
||||
for quant_op in self.QUANT_OPS:
|
||||
AiterSiluMulFp8GroupQuantPattern(quant_op).register(self.patterns)
|
||||
AiterSiluMulFp8GroupQuantPattern().register(self.patterns)
|
||||
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
|
||||
@@ -18,7 +18,6 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..utility.noop_elimination import NoOpEliminationPass
|
||||
@@ -215,9 +214,6 @@ class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
)
|
||||
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
|
||||
class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -37,6 +37,14 @@ class FixFunctionalizationPass(VllmInductorPass):
|
||||
|
||||
self.nodes_to_remove: list[torch.fx.Node] = []
|
||||
count = 0
|
||||
|
||||
rope_targets = [torch.ops._C.rotary_embedding.default]
|
||||
|
||||
if hasattr(torch.ops.vllm, "rocm_aiter_triton_rotary_embedding"):
|
||||
rope_targets.append(
|
||||
torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default
|
||||
)
|
||||
|
||||
for node in graph.nodes:
|
||||
if not is_func(node, auto_functionalized):
|
||||
continue # Avoid deep if-elif nesting
|
||||
@@ -44,7 +52,7 @@ class FixFunctionalizationPass(VllmInductorPass):
|
||||
kwargs = node.kwargs
|
||||
at_target = node.args[0]
|
||||
|
||||
if at_target == torch.ops._C.rotary_embedding.default:
|
||||
if at_target in rope_targets:
|
||||
query = kwargs["query"]
|
||||
key = kwargs["key"]
|
||||
getitem_nodes = self.getitem_users(node)
|
||||
|
||||
@@ -298,9 +298,9 @@ class PiecewiseBackend:
|
||||
else list(args)
|
||||
)
|
||||
|
||||
with (
|
||||
torch._functorch.config.patch("bundled_autograd_cache", True),
|
||||
):
|
||||
# with (
|
||||
# torch._functorch.config.patch("bundled_autograd_cache", True),
|
||||
# ):
|
||||
range_entry.runnable = self.vllm_backend.compiler_manager.compile(
|
||||
self.graph,
|
||||
args_list,
|
||||
|
||||
@@ -319,3 +319,52 @@ class TorchCompileWithNoGuardsWrapper:
            yield
        finally:
            self.__class__.forward.__code__ = original


def reset_compile_wrapper(model: torch.nn.Module) -> None:
    """
    Clean up compiled model and captured CUDA graphs for elastic EP.
    """
    if not isinstance(model, TorchCompileWithNoGuardsWrapper) and hasattr(
        model, "model"
    ):
        model = model.model
    if not isinstance(model, TorchCompileWithNoGuardsWrapper):
        return
    # model.do_not_compile is set by the @support_torch_compile decorator
    if hasattr(model, "do_not_compile") and model.do_not_compile:
        return
    from vllm.compilation.counter import compilation_counter

    # reset the compilation counter
    compilation_counter.num_models_seen = 0
    compilation_counter.num_graphs_seen = 0
    compilation_counter.num_piecewise_graphs_seen = 0
    compilation_counter.num_piecewise_capturable_graphs_seen = 0
    compilation_counter.num_backend_compilations = 0
    compilation_counter.num_gpu_runner_capture_triggers = 0
    compilation_counter.num_cudagraph_captured = 0
    compilation_counter.num_inductor_compiles = 0
    compilation_counter.num_eager_compiles = 0
    compilation_counter.num_cache_entries_updated = 0
    compilation_counter.num_compiled_artifacts_saved = 0
    compilation_counter.stock_torch_compile_count = 0

    # Clear the AOT compiled function so the model is forced to
    # recompile on the next call. Without this, decorators.py
    # __call__ uses the stale aot_compiled_fn whose torchinductor
    # kernels have old parameters (expert_map size for example)
    # baked in as compile-time constants.
    if hasattr(model, "aot_compiled_fn"):
        model.aot_compiled_fn = None
    if hasattr(model, "was_aot_compile_fn_loaded_from_disk"):
        model.was_aot_compile_fn_loaded_from_disk = False

    # Reset the cache_dir so VllmBackend recomputes the hash
    # (data_parallel_size changed, so the config hash differs).
    compilation_config = model.vllm_config.compilation_config
    compilation_config.cache_dir = ""
    compilation_config.local_cache_dir = ""

    model.__class__.forward.__code__ = model.original_code_object()
    TorchCompileWithNoGuardsWrapper.__init__(model)

@@ -16,8 +16,8 @@ class AttentionConfig:
|
||||
backend: AttentionBackendEnum | None = None
|
||||
"""Attention backend to use. If None, will be selected automatically."""
|
||||
|
||||
flash_attn_version: Literal[2, 3] | None = None
|
||||
"""Force vllm to use a specific flash-attention version (2 or 3).
|
||||
flash_attn_version: Literal[2, 3, 4] | None = None
|
||||
"""Force vllm to use a specific flash-attention version (2, 3, or 4).
|
||||
Only valid when using the flash-attention backend."""
|
||||
|
||||
use_prefill_decode_attention: bool = False
|
||||
|
||||
@@ -87,8 +87,15 @@ class CUDAGraphMode(enum.Enum):
    def separate_routine(self) -> bool:
        return isinstance(self.value, tuple)

    def valid_runtime_modes(self) -> bool:
        return self in [CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL]
    def decode_use_graph(self) -> bool:
        return self.decode_mode() == CUDAGraphMode.FULL

    @classmethod
    def valid_runtime_modes(cls) -> frozenset["CUDAGraphMode"]:
        return frozenset({cls.NONE, cls.PIECEWISE, cls.FULL})

    def is_valid_runtime_mode(self) -> bool:
        return self in CUDAGraphMode.valid_runtime_modes()

    def __str__(self) -> str:
        return self.name
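The hunk above replaces the boolean instance method `valid_runtime_modes()` with a classmethod returning the set of runtime-capable modes, plus an `is_valid_runtime_mode()` helper. An illustrative view of how call sites change (enum and method names come from the diff; the import path and assertions are only an example):

```python
from vllm.config import CUDAGraphMode  # import path assumed

mode = CUDAGraphMode.PIECEWISE

# Before: instance method returning a bool.
#     assert mode.valid_runtime_modes()

# After: a classmethod exposing the set, plus a renamed per-instance check.
assert mode in CUDAGraphMode.valid_runtime_modes()
assert mode.is_valid_runtime_mode()
```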
@@ -385,7 +392,7 @@ class CompilationConfig:
|
||||
Please use mode. Currently all levels are mapped to mode.
|
||||
"""
|
||||
# Top-level Compilation control
|
||||
mode: CompilationMode = Field(default=None)
|
||||
mode: CompilationMode = Field(default=CompilationMode.NONE)
|
||||
"""The compilation approach used for torch.compile-based compilation of the
|
||||
model.
|
||||
|
||||
@@ -503,7 +510,7 @@ class CompilationConfig:
|
||||
constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`."""
|
||||
|
||||
# CudaGraph compilation
|
||||
cudagraph_mode: CUDAGraphMode = Field(default=None)
|
||||
cudagraph_mode: CUDAGraphMode = Field(default=CUDAGraphMode.FULL_DECODE_ONLY)
|
||||
"""
|
||||
The mode of the cudagraph:
|
||||
|
||||
@@ -1003,6 +1010,7 @@ class CompilationConfig:
|
||||
# https://github.com/vllm-project/vllm/issues/33267
|
||||
if not self.use_inductor_graph_partition:
|
||||
self.splitting_ops.append("vllm::unified_kv_cache_update")
|
||||
self.splitting_ops.append("vllm::unified_mla_kv_cache_update")
|
||||
|
||||
elif len(self.splitting_ops) == 0:
|
||||
if (
|
||||
@@ -1045,7 +1053,7 @@ class CompilationConfig:
|
||||
"are optimized for prefill and are incompatible with CUDA Graphs. "
|
||||
"In order to use CUDA Graphs for decode-optimized workloads, "
|
||||
"use --all2all-backend with another option, such as "
|
||||
"deepep_low_latency, pplx, or allgather_reducescatter."
|
||||
"deepep_low_latency or allgather_reducescatter."
|
||||
)
|
||||
self.cudagraph_mode = CUDAGraphMode.NONE
|
||||
|
||||
|
||||
@@ -50,8 +50,6 @@ from vllm.transformers_utils.utils import maybe_model_redirect
|
||||
from vllm.utils.import_utils import LazyLoader
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
import os
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
@@ -128,6 +126,7 @@ class ModelConfig:
|
||||
- "slow" will always use the slow tokenizer.\n
|
||||
- "mistral" will always use the tokenizer from `mistral_common`.\n
|
||||
- "deepseek_v32" will always use the tokenizer from `deepseek_v32`.\n
|
||||
- "qwen_vl" will always use the tokenizer from `qwen_vl`.\n
|
||||
- Other custom values can be supported via plugins."""
|
||||
trust_remote_code: bool = False
|
||||
"""Trust remote code (e.g., from HuggingFace) when downloading the model
|
||||
@@ -463,8 +462,6 @@ class ModelConfig:
|
||||
|
||||
self.maybe_pull_model_tokenizer_for_runai(self.model, self.tokenizer)
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if self.override_attention_dtype is not None and not current_platform.is_rocm():
|
||||
warnings.warn(
|
||||
"override-attention-dtype is set but not using ROCm platform",
|
||||
@@ -474,9 +471,8 @@ class ModelConfig:
|
||||
if self.enable_sleep_mode and not current_platform.is_sleep_mode_available():
|
||||
raise ValueError("Sleep mode is not supported on current platform.")
|
||||
|
||||
temp_hf_config_path = os.environ.get("CUSTOM_QUANT_CONFIG", None)
|
||||
hf_config = get_config(
|
||||
temp_hf_config_path or self.hf_config_path or self.model,
|
||||
self.hf_config_path or self.model,
|
||||
self.trust_remote_code,
|
||||
self.revision,
|
||||
self.code_revision,
|
||||
@@ -622,6 +618,16 @@ class ModelConfig:
        self._try_verify_and_update_model_config()
        self._verify_quantization()
        self._verify_cuda_graph()
        import os
        enforce_cuda_graph = os.environ.get("VLLM_ENFORCE_CUDA_GRAPH", None)
        if enforce_cuda_graph is not None and enforce_cuda_graph in ["1", "y", "Y"]:
            self.enforce_eager = False
        else:
            self.enforce_eager = True
            logger.warning_once(
                "Please export VLLM_ENFORCE_CUDA_GRAPH=1 to enable cuda graph. "
                "For now, cuda graph is not used and --enforce-eager is enabled; "
                "we are trying to make cuda graph the default mode.")
        self._verify_bnb_config()

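This overlay-specific hunk forces `enforce_eager=True` unless `VLLM_ENFORCE_CUDA_GRAPH` is set to `1`, `y`, or `Y`, so CUDA graphs stay off by default on this stack. Opting back in means setting the variable in the container environment (for example `-e VLLM_ENFORCE_CUDA_GRAPH=1` on `docker run`) or in Python before the engine is constructed; a minimal sketch:

```python
import os

# Must be set before ModelConfig is created, i.e. before `vllm serve` / LLM(...)
# initializes the engine in this process.
os.environ["VLLM_ENFORCE_CUDA_GRAPH"] = "1"
```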
def get_model_arch_config(
|
||||
@@ -886,6 +892,7 @@ class ModelConfig:
|
||||
"modelopt",
|
||||
"modelopt_fp4",
|
||||
"modelopt_mxfp8",
|
||||
"modelopt_mixed",
|
||||
"petit_nvfp4",
|
||||
# Ensure heavy backends are probed last to avoid unnecessary
|
||||
# imports during override detection (e.g., MXFP4 imports Triton)
|
||||
@@ -942,8 +949,6 @@ class ModelConfig:
|
||||
f"Unknown quantization method: {self.quantization}. Must "
|
||||
f"be one of {supported_quantization}."
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
current_platform.verify_quantization(self.quantization)
|
||||
|
||||
if self.quantization in me_quant.DEPRECATED_QUANTIZATION_METHODS:
|
||||
@@ -1813,8 +1818,6 @@ def _resolve_auto_dtype(
|
||||
*,
|
||||
is_pooling_model: bool,
|
||||
):
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
supported_dtypes = [
|
||||
dtype
|
||||
for dtype in current_platform.supported_dtypes
|
||||
|
||||
@@ -152,7 +152,6 @@ class ParallelConfig:
|
||||
|
||||
- "naive": Naive all2all implementation using broadcasts\n
|
||||
- "allgather_reducescatter": All2all based on allgather and reducescatter\n
|
||||
- "pplx": Use pplx kernels\n
|
||||
- "deepep_high_throughput": Use deepep high-throughput kernels\n
|
||||
- "deepep_low_latency": Use deepep low-latency kernels\n
|
||||
- "mori": Use mori kernels\n
|
||||
@@ -166,6 +165,9 @@ class ParallelConfig:
|
||||
disable_custom_all_reduce: bool = False
|
||||
"""Disable the custom all-reduce kernel and fall back to NCCL."""
|
||||
|
||||
enable_elastic_ep: bool = False
|
||||
"""Enable elastic expert parallelism with stateless NCCL groups for DP/EP."""
|
||||
|
||||
enable_dbo: bool = False
|
||||
"""Enable dual batch overlap for the model executor."""
|
||||
ubatch_size: int = 0
|
||||
@@ -245,6 +247,34 @@ class ParallelConfig:
|
||||
Set to be private as it's not intended to be configured by users.
|
||||
"""
|
||||
|
||||
_stateless_dp_group_port_list: list[list[int]] = Field(default_factory=list)
|
||||
"""List of open ports for stateless DP groups when enable_elastic_ep is True.
|
||||
Set to be private as it's not intended to be configured by users.
|
||||
It is a list of list[int], with each inner list contains a set of 3 ports
|
||||
to be used for setting up the stateless CPU/device/TCPStore groups
|
||||
in StatelessGroupCoordinator. The number of inner lists is equal to
|
||||
the number of DP groups,
|
||||
i.e., len(self._stateless_dp_group_port_list) == world_size_across_dp // dp_size,
|
||||
and len(self._stateless_dp_group_port_list[i]) == 3 for all i.
|
||||
"""
|
||||
|
||||
_stateless_ep_group_port_list: list[list[int]] = Field(default_factory=list)
|
||||
"""List of open ports for stateless EP groups when enable_elastic_ep is True.
|
||||
Set to be private as it's not intended to be configured by users.
|
||||
len(self._stateless_ep_group_port_list) == world_size_across_dp // ep_size,
|
||||
"""
|
||||
|
||||
_stateless_eplb_group_port_list: list[list[int]] = Field(default_factory=list)
|
||||
"""List of open ports for stateless EPLB groups when enable_elastic_ep is True.
|
||||
Same topology as EP but separate NCCL communicator to avoid deadlocks.
|
||||
"""
|
||||
|
||||
_stateless_world_group_port_list: list[list[int]] = Field(default_factory=list)
|
||||
"""List of open ports for stateless world group when enable_elastic_ep is True.
|
||||
Set to be private as it's not intended to be configured by users.
|
||||
len(self._stateless_world_group_port_list) == 1,
|
||||
"""
|
||||
|
||||
decode_context_parallel_size: int = 1
|
||||
"""Number of decode context parallel groups, because the world size does
|
||||
not change by dcp, it simply reuse the GPUs of TP group, and tp_size
|
||||
@@ -310,6 +340,13 @@ class ParallelConfig:
|
||||
f"but found: {self._api_process_rank}"
|
||||
)
|
||||
|
||||
if self.all2all_backend == "pplx":
|
||||
logger.warning(
|
||||
"The 'pplx' all2all backend has been removed. "
|
||||
"Falling back to 'allgather_reducescatter'."
|
||||
)
|
||||
self.all2all_backend = "allgather_reducescatter"
|
||||
|
||||
if self.data_parallel_size_local > self.data_parallel_size:
|
||||
raise ValueError(
|
||||
f"data_parallel_size_local ({self.data_parallel_size_local}) "
|
||||
@@ -396,7 +433,67 @@ class ParallelConfig:
|
||||
|
||||
return answer
|
||||
|
||||
def stateless_init_dp_group(self) -> ProcessGroup:
|
||||
def allocate_elastic_ep_ports(self) -> None:
|
||||
"""Allocate all ports for elastic EP (stateless groups + DP master).
|
||||
|
||||
Must be called AFTER ray.init() so that ports claimed by Ray's
|
||||
idle worker pool are already in use and won't be returned by
|
||||
get_open_ports_list().
|
||||
"""
|
||||
if not self.enable_elastic_ep:
|
||||
return
|
||||
if self._stateless_world_group_port_list:
|
||||
return
|
||||
|
||||
num_world_groups = 1
|
||||
dp_size = self.data_parallel_size
|
||||
ep_size = self.data_parallel_size * self.world_size_across_dp
|
||||
num_dp_groups = max(1, self.world_size_across_dp // dp_size)
|
||||
num_ep_groups = max(1, self.world_size_across_dp // ep_size)
|
||||
num_eplb_groups = num_ep_groups
|
||||
total_stateless_ports = (
|
||||
num_world_groups + num_dp_groups + num_ep_groups + num_eplb_groups
|
||||
) * 3
|
||||
num_dp_master_ports = 5
|
||||
|
||||
all_ports = get_open_ports_list(total_stateless_ports + num_dp_master_ports)
|
||||
|
||||
self._data_parallel_master_port_list = all_ports[-num_dp_master_ports:]
|
||||
self.data_parallel_master_port = self._data_parallel_master_port_list.pop()
|
||||
all_ports = all_ports[:-num_dp_master_ports]
|
||||
|
||||
self._stateless_world_group_port_list = [
|
||||
all_ports[i : i + 3] for i in range(0, num_world_groups * 3, 3)
|
||||
]
|
||||
start_idx = num_world_groups * 3
|
||||
self._stateless_dp_group_port_list = [
|
||||
all_ports[i : i + 3]
|
||||
for i in range(start_idx, start_idx + num_dp_groups * 3, 3)
|
||||
]
|
||||
start_idx += num_dp_groups * 3
|
||||
self._stateless_ep_group_port_list = [
|
||||
all_ports[i : i + 3]
|
||||
for i in range(start_idx, start_idx + num_ep_groups * 3, 3)
|
||||
]
|
||||
start_idx += num_ep_groups * 3
|
||||
self._stateless_eplb_group_port_list = [
|
||||
all_ports[i : i + 3]
|
||||
for i in range(start_idx, start_idx + num_eplb_groups * 3, 3)
|
||||
]
|
||||
|
||||
def get_next_stateless_world_group_port(self) -> list[int]:
|
||||
return self._stateless_world_group_port_list.pop()
|
||||
|
||||
def get_next_stateless_dp_group_port(self) -> list[int]:
|
||||
return self._stateless_dp_group_port_list.pop()
|
||||
|
||||
def get_next_stateless_ep_group_port(self) -> list[int]:
|
||||
return self._stateless_ep_group_port_list.pop()
|
||||
|
||||
def get_next_stateless_eplb_group_port(self) -> list[int]:
|
||||
return self._stateless_eplb_group_port_list.pop()
|
||||
|
||||
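For a concrete sense of how many ports `allocate_elastic_ep_ports` reserves, here is a trace of its arithmetic with example sizes (the sizes are hypothetical; the formulas follow the code above verbatim, including the `ep_size` expression as written):

```python
# Tracing allocate_elastic_ep_ports() with example sizes (illustrative only).
world_size_across_dp, data_parallel_size = 4, 2

num_world_groups = 1
dp_size = data_parallel_size
ep_size = data_parallel_size * world_size_across_dp      # 8, as written above
num_dp_groups = max(1, world_size_across_dp // dp_size)  # 2
num_ep_groups = max(1, world_size_across_dp // ep_size)  # 1
num_eplb_groups = num_ep_groups                          # 1

total_stateless_ports = (
    num_world_groups + num_dp_groups + num_ep_groups + num_eplb_groups
) * 3                                                     # 15
num_dp_master_ports = 5
print(total_stateless_ports + num_dp_master_ports)       # 20 ports reserved up front
```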
def stateless_init_dp_group(self, return_store: bool = False) -> ProcessGroup:
|
||||
# NOTE: In high-concurrency scenarios multiple processes
|
||||
# can pick the same (currently free) port through a race
|
||||
# condition when calling `get_open_port()`. When the first
|
||||
@@ -420,7 +517,8 @@ class ParallelConfig:
|
||||
self.get_next_dp_init_port(),
|
||||
self.data_parallel_rank,
|
||||
self.data_parallel_size,
|
||||
backend=current_platform.dist_backend,
|
||||
backend="gloo",
|
||||
return_store=return_store,
|
||||
)
|
||||
except DistNetworkError as e:
|
||||
# We only want to retry when the root cause is EADDRINUSE.
|
||||
@@ -442,7 +540,6 @@ class ParallelConfig:
|
||||
# In this case, ensure the input to the experts is sequence parallel
|
||||
# to avoid the excess work.
|
||||
#
|
||||
# Not needed for pplx-kernels as it can handle duplicate input tokens.
|
||||
@property
|
||||
def use_sequence_parallel_moe(self) -> bool:
|
||||
return (
|
||||
@@ -556,6 +653,21 @@ class ParallelConfig:
|
||||
logger.info("Using external launcher for distributed inference.")
|
||||
self.world_size *= self.data_parallel_size
|
||||
|
||||
if self.enable_elastic_ep:
|
||||
if not self.enable_eplb:
|
||||
raise ValueError("Elastic EP is only supported with enable_eplb=True.")
|
||||
if self.pipeline_parallel_size > 1:
|
||||
raise ValueError(
|
||||
"Elastic EP is not supported with pipeline parallelism "
|
||||
f"(pipeline_parallel_size={self.pipeline_parallel_size})."
|
||||
)
|
||||
if self.data_parallel_external_lb or self.data_parallel_hybrid_lb:
|
||||
raise NotImplementedError(
|
||||
"Elastic EP is not compatible with data_parallel_external_lb "
|
||||
"or data_parallel_hybrid_lb. Elastic EP relies on a single API "
|
||||
"server and core client to coordinate scale up/down."
|
||||
)
|
||||
|
||||
if self.data_parallel_size > 1 or self.data_parallel_size_local == 0:
|
||||
# Data parallel was specified in the engine args.
|
||||
if self.distributed_executor_backend == "external_launcher":
|
||||
@@ -568,9 +680,12 @@ class ParallelConfig:
|
||||
"Set data_parallel_rank to %d automatically.",
|
||||
self.data_parallel_rank,
|
||||
)
|
||||
if not self.enable_elastic_ep:
|
||||
if not self._data_parallel_master_port_list:
|
||||
self._data_parallel_master_port_list = get_open_ports_list(5)
|
||||
self.data_parallel_master_port = self._data_parallel_master_port_list.pop()
|
||||
self.data_parallel_master_port = (
|
||||
self._data_parallel_master_port_list.pop()
|
||||
)
|
||||
|
||||
if not (0 <= self.data_parallel_rank < self.data_parallel_size):
|
||||
raise ValueError(
|
||||
@@ -597,7 +712,7 @@ class ParallelConfig:
|
||||
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
|
||||
logger.info("Disabling V1 multiprocessing for external launcher.")
|
||||
|
||||
if self.distributed_executor_backend is None and self.world_size > 1:
|
||||
if self.distributed_executor_backend is None and self.world_size_across_dp > 1:
|
||||
# We use multiprocessing by default if world_size fits on the
|
||||
# current node and we aren't in a ray placement group.
|
||||
|
||||
@@ -659,6 +774,17 @@ class ParallelConfig:
|
||||
"backend is mp, uni or external_launcher."
|
||||
)
|
||||
|
||||
if (
|
||||
self.all2all_backend in ("allgather_reducescatter", "naive")
|
||||
and self.eplb_config.use_async
|
||||
):
|
||||
logger.warning(
|
||||
"Async EPLB causes hangs with the '%s' all2all backend. "
|
||||
"Forcing synchronous EPLB.",
|
||||
self.all2all_backend,
|
||||
)
|
||||
self.eplb_config.use_async = False
|
||||
|
||||
@property
|
||||
def use_ray(self) -> bool:
|
||||
return self.distributed_executor_backend == "ray" or (
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import ast
|
||||
import copy
|
||||
from typing import TYPE_CHECKING, Any, Literal, get_args
|
||||
|
||||
from pydantic import Field, SkipValidation, model_validator
|
||||
@@ -45,7 +46,7 @@ MTPModelTypes = Literal[
|
||||
"pangu_ultra_moe_mtp",
|
||||
"step3p5_mtp",
|
||||
]
|
||||
EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes]
|
||||
EagleModelTypes = Literal["eagle", "eagle3", "extract_hidden_states", MTPModelTypes]
|
||||
SpeculativeMethod = Literal[
|
||||
"ngram",
|
||||
"medusa",
|
||||
@@ -77,12 +78,24 @@ class SpeculativeConfig:
|
||||
|
||||
If using `ngram` method, the related configuration `prompt_lookup_max` and
|
||||
`prompt_lookup_min` should be considered."""
|
||||
enable_multi_layers_mtp: bool = False
|
||||
"""If set to True, the MTP method will run multiple layers of MTP
|
||||
speculator. If set to False, it will run only one layer of MTP speculator.
|
||||
This is only effective when the method is set to `mtp`."""
|
||||
draft_tensor_parallel_size: int | None = Field(default=None, ge=1)
|
||||
"""The degree of the tensor parallelism for the draft model. Can only be 1
|
||||
or the same as the target model's tensor parallel size."""
|
||||
draft_pipeline_parallel_size: int | None = Field(default=None, ge=1)
|
||||
"""The degree of pipeline parallelism for the draft model.
|
||||
|
||||
Defaults to the target model's pipeline parallel size. Set this to 1 to
|
||||
run the drafter locally on the last target PP stage."""
|
||||
tensor_parallel_size: int | None = None
|
||||
"""Users should pass "draft_tensor_parallel_size". This parameter's purpose is to
|
||||
warn users when they mistakenly provide the wrong argument."""
|
||||
pipeline_parallel_size: int | None = None
|
||||
"""Users should pass "draft_pipeline_parallel_size". This parameter's
|
||||
purpose is to warn users when they mistakenly provide the wrong argument."""
|
||||
|
||||
# Draft model configuration
|
||||
quantization: me_quant.QuantizationMethods | None = None
|
||||
@@ -181,9 +194,22 @@ class SpeculativeConfig:
|
||||
the final hidden states.
|
||||
"""
|
||||
factors: list[Any] = []
|
||||
# Eagle3 affects the computation graph because it returns intermediate
|
||||
# hidden states in addition to the final hidden state.
|
||||
factors.append(self.method == "eagle3")
|
||||
# Eagle3 and extract_hidden_states affect the computation graph because
|
||||
# they return intermediate hidden states in addition to the final hidden state.
|
||||
uses_aux_hidden_states = self.method in ("eagle3", "extract_hidden_states")
|
||||
factors.append(uses_aux_hidden_states)
|
||||
|
||||
# The specific layers used also affect the computation graph
|
||||
if uses_aux_hidden_states and self.draft_model_config is not None:
|
||||
layer_ids = getattr(
|
||||
self.draft_model_config.hf_config,
|
||||
"eagle_aux_hidden_state_layer_ids",
|
||||
None,
|
||||
)
|
||||
if layer_ids is not None:
|
||||
# Convert to tuple to make it hashable
|
||||
factors.append(tuple(layer_ids))
|
||||
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
@@ -352,6 +378,8 @@ class SpeculativeConfig:
|
||||
self.model = "ngram"
|
||||
elif self.method == "suffix":
|
||||
self.model = "suffix"
|
||||
elif self.method == "extract_hidden_states":
|
||||
self.model = "extract_hidden_states"
|
||||
else:
|
||||
raise ValueError(
|
||||
"num_speculative_tokens was provided but without speculative model."
|
||||
@@ -394,6 +422,34 @@ class SpeculativeConfig:
|
||||
self.draft_parallel_config = self.target_parallel_config
|
||||
elif self.method == "suffix":
|
||||
self._validate_suffix_decoding()
|
||||
elif self.method == "extract_hidden_states":
|
||||
from vllm.transformers_utils.configs.extract_hidden_states import (
|
||||
ExtractHiddenStatesConfig,
|
||||
)
|
||||
|
||||
# ExtractHiddenStatesModel is instantiated manually in load_model()
|
||||
# We just need to store the target model config for KV cache shape info
|
||||
self.model = "extract_hidden_states"
|
||||
self.prompt_lookup_max = 0
|
||||
self.prompt_lookup_min = 0
|
||||
|
||||
if hasattr(self.draft_model_config, "hf_config"):
|
||||
hf_config = self.draft_model_config.hf_config.to_dict()
|
||||
elif (
|
||||
isinstance(self.draft_model_config, dict)
|
||||
and "hf_config" in self.draft_model_config
|
||||
):
|
||||
hf_config = self.draft_model_config["hf_config"]
|
||||
else:
|
||||
hf_config = {}
|
||||
|
||||
self.draft_model_config = copy.copy(self.target_model_config)
|
||||
self.draft_model_config.hf_config = ExtractHiddenStatesConfig(
|
||||
self.draft_model_config.hf_config, **hf_config
|
||||
)
|
||||
self.update_arch_()
|
||||
self.draft_parallel_config = self.target_parallel_config
|
||||
|
||||
else:
|
||||
self.prompt_lookup_max = 0
|
||||
self.prompt_lookup_min = 0
|
||||
@@ -439,7 +495,10 @@ class SpeculativeConfig:
|
||||
MTPModelTypes
|
||||
):
|
||||
self.method = "mtp"
|
||||
if self.num_speculative_tokens > 1:
|
||||
if (
|
||||
self.enable_multi_layers_mtp is False
|
||||
and self.num_speculative_tokens > 1
|
||||
):
|
||||
logger.warning(
|
||||
"Enabling num_speculative_tokens > 1 will run "
|
||||
"multiple times of forward on same MTP layer"
|
||||
@@ -478,23 +537,8 @@ class SpeculativeConfig:
|
||||
method=self.method,
|
||||
model_type="eagle",
|
||||
)
|
||||
# EAGLEConfig primarily updates architectures, so update
|
||||
# all architectures-related fields in draft_model_config
|
||||
self.draft_model_config.hf_config = eagle_config
|
||||
self.draft_model_config.hf_text_config = get_hf_text_config(
|
||||
self.draft_model_config.hf_config
|
||||
)
|
||||
self.draft_model_config.model_arch_config = (
|
||||
self.draft_model_config.get_model_arch_config()
|
||||
)
|
||||
model_info, arch = (
|
||||
self.draft_model_config.registry.inspect_model_cls(
|
||||
self.draft_model_config.architectures,
|
||||
self.draft_model_config,
|
||||
)
|
||||
)
|
||||
self.draft_model_config._model_info = model_info
|
||||
self.draft_model_config._architecture = arch
|
||||
self.update_arch_()
|
||||
|
||||
if self.num_speculative_tokens is not None and hasattr(
|
||||
self.draft_model_config.hf_config, "num_lookahead_tokens"
|
||||
@@ -510,6 +554,17 @@ class SpeculativeConfig:
|
||||
if self.num_speculative_tokens is None:
|
||||
# Default to max value defined in draft model config.
|
||||
self.num_speculative_tokens = n_predict
|
||||
elif (
|
||||
self.method == "mtp"
|
||||
and self.enable_multi_layers_mtp
|
||||
and self.num_speculative_tokens > n_predict
|
||||
):
|
||||
logger.warning_once(
|
||||
"For multi_layer_eagle, num_speculative_tokens "
|
||||
"is greater than the layer_num, adjusting to "
|
||||
"layer_num"
|
||||
)
|
||||
self.num_speculative_tokens = n_predict
|
||||
elif (
|
||||
self.num_speculative_tokens > n_predict
|
||||
and self.num_speculative_tokens % n_predict != 0
|
||||
@@ -555,9 +610,17 @@ class SpeculativeConfig:
|
||||
)
|
||||
)
|
||||
|
||||
self.draft_pipeline_parallel_size = (
|
||||
SpeculativeConfig._verify_and_get_draft_pp(
|
||||
self.target_parallel_config,
|
||||
self.draft_pipeline_parallel_size,
|
||||
)
|
||||
)
|
||||
self.draft_parallel_config = (
|
||||
SpeculativeConfig.create_draft_parallel_config(
|
||||
self.target_parallel_config, self.draft_tensor_parallel_size
|
||||
self.target_parallel_config,
|
||||
self.draft_tensor_parallel_size,
|
||||
self.draft_pipeline_parallel_size,
|
||||
)
|
||||
)
|
||||
return self
|
||||
@@ -671,17 +734,61 @@ class SpeculativeConfig:
|
||||
)
|
||||
return speculative_draft_tensor_parallel_size
|
||||
|
||||
@staticmethod
|
||||
def _verify_and_get_draft_pp(
|
||||
target_parallel_config: ParallelConfig,
|
||||
speculative_draft_pipeline_parallel_size: int | None,
|
||||
) -> int:
|
||||
"""
|
||||
Verifies and adjusts the pipeline parallel size for a draft model
|
||||
specified using speculative_draft_pipeline_parallel_size.
|
||||
"""
|
||||
if speculative_draft_pipeline_parallel_size is None:
|
||||
return target_parallel_config.pipeline_parallel_size
|
||||
|
||||
if speculative_draft_pipeline_parallel_size not in (
|
||||
1,
|
||||
target_parallel_config.pipeline_parallel_size,
|
||||
):
|
||||
raise ValueError(
|
||||
f"{speculative_draft_pipeline_parallel_size=} cannot be "
|
||||
"other value than 1 or target model "
|
||||
f"pipeline_parallel_size="
|
||||
f"{target_parallel_config.pipeline_parallel_size}"
|
||||
)
|
||||
return speculative_draft_pipeline_parallel_size
|
||||
|
||||
def update_arch_(self):
|
||||
"""
|
||||
EagleConfig and ExtractHiddenStatesConfig update architectures, so update all
|
||||
architectures-related fields in self.draft_model_config
|
||||
"""
|
||||
self.draft_model_config.hf_text_config = get_hf_text_config(
|
||||
self.draft_model_config.hf_config
|
||||
)
|
||||
self.draft_model_config.model_arch_config = (
|
||||
self.draft_model_config.get_model_arch_config()
|
||||
)
|
||||
model_info, arch = self.draft_model_config.registry.inspect_model_cls(
|
||||
self.draft_model_config.architectures,
|
||||
self.draft_model_config,
|
||||
)
|
||||
self.draft_model_config._model_info = model_info
|
||||
self.draft_model_config._architecture = arch
|
||||
|
||||
@staticmethod
|
||||
def create_draft_parallel_config(
|
||||
target_parallel_config: ParallelConfig,
|
||||
speculative_draft_tensor_parallel_size: int,
|
||||
speculative_draft_pipeline_parallel_size: int,
|
||||
) -> ParallelConfig:
|
||||
"""Create a parallel config for use by the draft worker.
|
||||
|
||||
This is mostly a copy of the target parallel config, except the tp_size.
|
||||
This is mostly a copy of the target parallel config, except the tp/pp
|
||||
sizes used by the draft model.
|
||||
"""
|
||||
draft_parallel_config = ParallelConfig(
|
||||
pipeline_parallel_size=target_parallel_config.pipeline_parallel_size,
|
||||
pipeline_parallel_size=speculative_draft_pipeline_parallel_size,
|
||||
tensor_parallel_size=speculative_draft_tensor_parallel_size,
|
||||
distributed_executor_backend=target_parallel_config.distributed_executor_backend,
|
||||
max_parallel_loading_workers=target_parallel_config.max_parallel_loading_workers,
|
||||
@@ -699,6 +806,12 @@ class SpeculativeConfig:
|
||||
"'tensor_parallel_size' is not a valid argument in the "
|
||||
"speculative_config. Please pass 'draft_tensor_parallel_size' instead."
|
||||
)
|
||||
if self.pipeline_parallel_size is not None:
|
||||
raise ValueError(
|
||||
"'pipeline_parallel_size' is not a valid argument in the "
|
||||
"speculative_config. Please pass "
|
||||
"'draft_pipeline_parallel_size' instead."
|
||||
)
|
||||
|
||||
if self.num_speculative_tokens is None:
|
||||
raise ValueError(
|
||||
@@ -718,7 +831,7 @@ class SpeculativeConfig:
|
||||
self.draft_parallel_config
|
||||
)
|
||||
|
||||
eagle3_target_supported = [
|
||||
aux_hidden_states_supported = [
|
||||
"llama",
|
||||
"qwen",
|
||||
"minicpm",
|
||||
@@ -729,16 +842,16 @@ class SpeculativeConfig:
|
||||
"nemotron_h",
|
||||
]
|
||||
if (
|
||||
self.method == "eagle3"
|
||||
self.method in ("eagle3", "extract_hidden_states")
|
||||
and self.target_model_config
|
||||
and not any(
|
||||
supported_model in self.target_model_config.hf_text_config.model_type
|
||||
for supported_model in eagle3_target_supported
|
||||
for supported_model in aux_hidden_states_supported
|
||||
)
|
||||
):
|
||||
raise ValueError(
|
||||
f"Eagle3 is only supported for {eagle3_target_supported} models. " # noqa: E501
|
||||
f"Got {self.target_model_config.hf_text_config.model_type=}"
|
||||
f"{self.method} is only supported for {aux_hidden_states_supported}"
|
||||
f" models. Got {self.target_model_config.hf_text_config.model_type=}"
|
||||
)
|
||||
self.verify_equal_vocab_size_if_draft_model()
|
||||
return self
|
||||
@@ -782,8 +895,65 @@ class SpeculativeConfig:
    def uses_draft_model(self) -> bool:
        return self.method == "draft_model"

    def uses_extract_hidden_states(self) -> bool:
        return self.method == "extract_hidden_states"

    def needs_partial_pp_draft_remap(
        self, target_parallel_config: ParallelConfig
    ) -> bool:
        """Whether draft PP is smaller than target PP and needs rank remap."""
        if self.draft_parallel_config is None:
            return False
        return (
            target_parallel_config.pipeline_parallel_size
            > self.draft_parallel_config.pipeline_parallel_size
        )

    def resolve_partial_pp_draft_rank(
        self, target_parallel_config: ParallelConfig
    ) -> int:
        """Map a target rank to the local draft rank for partial-PP drafting.

        Currently this only supports running the draft model with `draft_pp=1`
        on the last target PP stage.
        """
        if not self.needs_partial_pp_draft_remap(target_parallel_config):
            return target_parallel_config.rank

        assert self.draft_parallel_config is not None
        draft_pp = self.draft_parallel_config.pipeline_parallel_size
        if draft_pp != 1:
            raise ValueError(
                "Partial pp drafter rank remapping only supports "
                "draft_pipeline_parallel_size=1 when target PP is larger."
            )

        target_tp = target_parallel_config.tensor_parallel_size
        draft_tp = self.draft_parallel_config.tensor_parallel_size
        if draft_tp != target_tp:
            raise ValueError(
                "Partial pp drafter rank remapping requires "
                "draft_tensor_parallel_size to equal target tensor_parallel_size. "
                f"Got draft_tp={draft_tp}, target_tp={target_tp}."
            )

        target_pp = target_parallel_config.pipeline_parallel_size
        target_rank = target_parallel_config.rank
        target_pp_rank = target_rank // target_tp
        target_tp_rank = target_rank % target_tp
        if target_pp_rank != target_pp - 1:
            raise ValueError(
                "Partial pp drafter should only run on the last "
                f"pipeline stage, but got pp rank {target_pp_rank} / {target_pp}"
            )
        return target_tp_rank

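To make the remapping above concrete: with a target configuration of `tensor_parallel_size=2` and `pipeline_parallel_size=2` (ranks 0..3) and `draft_pipeline_parallel_size=1`, only the last PP stage runs the drafter, and its ranks collapse onto draft TP ranks. A worked trace of the arithmetic in `resolve_partial_pp_draft_rank` (illustrative, just following the code above):

```python
target_tp, target_pp = 2, 2

for target_rank in range(target_tp * target_pp):
    target_pp_rank = target_rank // target_tp  # 0, 0, 1, 1
    target_tp_rank = target_rank % target_tp   # 0, 1, 0, 1
    if target_pp_rank == target_pp - 1:
        # Ranks 2 and 3 (the last stage) map to draft ranks 0 and 1.
        print(f"target rank {target_rank} -> draft rank {target_tp_rank}")
    else:
        print(f"target rank {target_rank} -> no drafter on this stage")
```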
|
||||
def __repr__(self) -> str:
|
||||
method = self.method
|
||||
model = None if method in ("ngram", "suffix") else self.draft_model_config.model
|
||||
model = (
|
||||
None
|
||||
if method in ("ngram", "suffix", "extract_hidden_states")
|
||||
else self.draft_model_config.model
|
||||
)
|
||||
num_spec_tokens = self.num_speculative_tokens
|
||||
return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})"
|
||||
|
||||
@@ -126,6 +126,9 @@ def enable_allreduce_rms_fusion(cfg: "VllmConfig") -> bool:
|
||||
# tp-dp combination broken:
|
||||
# https://github.com/vllm-project/vllm/issues/34458
|
||||
and cfg.parallel_config.data_parallel_size == 1
|
||||
# tp-pp combination broken:
|
||||
# https://github.com/vllm-project/vllm/issues/35426
|
||||
and cfg.parallel_config.pipeline_parallel_size == 1
|
||||
)
|
||||
|
||||
|
||||
@@ -857,7 +860,7 @@ class VllmConfig:
|
||||
self.compilation_config.pass_config.fuse_gemm_comms = False
|
||||
else:
|
||||
# Compute SP threshold early; disable if None (model too
|
||||
# small) before +rms_norm gets forced into custom_ops.
|
||||
# small for SP to be beneficial).
|
||||
pass_config = self.compilation_config.pass_config
|
||||
if pass_config.sp_min_token_num is None:
|
||||
from vllm.compilation.passes.fusion.sequence_parallelism import (
|
||||
@@ -880,15 +883,13 @@ class VllmConfig:
|
||||
self.compilation_config.pass_config.enable_sp = False
|
||||
self.compilation_config.pass_config.fuse_gemm_comms = False
|
||||
|
||||
if self.compilation_config.pass_config.enable_sp:
|
||||
if "-rms_norm" in self.compilation_config.custom_ops:
|
||||
logger.warning(
|
||||
"RMS norm force disabled, sequence parallelism might break"
|
||||
)
|
||||
else:
|
||||
self.compilation_config.custom_ops.append("+rms_norm")
|
||||
from vllm.utils.torch_utils import HAS_OPAQUE_TYPE
|
||||
|
||||
if self.compilation_config.fast_moe_cold_start is None:
|
||||
if HAS_OPAQUE_TYPE:
|
||||
# On torch >= 2.11 the hoisted OpaqueObject approach supersedes
|
||||
# fast_moe_cold_start, so force it off.
|
||||
self.compilation_config.fast_moe_cold_start = False
|
||||
elif self.compilation_config.fast_moe_cold_start is None:
|
||||
# resolve default behavior: try to be as safe as possible
|
||||
# this config is unsafe if any spec decoding draft model has a MOE.
|
||||
# We'll conservatively turn it off if we see spec decoding.
|
||||
@@ -907,9 +908,9 @@ class VllmConfig:
|
||||
):
|
||||
logger.warning_once(
|
||||
"Pooling models do not support full cudagraphs. "
|
||||
"Overriding cudagraph_mode to PIECEWISE."
|
||||
"Overriding cudagraph_mode to NONE."
|
||||
)
|
||||
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
|
||||
self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||
elif (
|
||||
model_config.is_encoder_decoder
|
||||
and self.compilation_config.cudagraph_mode
|
||||
@@ -924,6 +925,33 @@ class VllmConfig:
|
||||
CUDAGraphMode.FULL_DECODE_ONLY
|
||||
)
|
||||
|
||||
# Check if KV connector requires PIECEWISE mode for CUDA graphs
|
||||
if (
|
||||
self.kv_transfer_config is not None
|
||||
and self.kv_transfer_config.is_kv_transfer_instance
|
||||
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
||||
):
|
||||
# Lazy import to avoid circular dependencies
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import (
|
||||
KVConnectorFactory,
|
||||
)
|
||||
|
||||
connector_cls = KVConnectorFactory.get_connector_class(
|
||||
self.kv_transfer_config
|
||||
)
|
||||
if connector_cls.requires_piecewise_for_cudagraph(
|
||||
self.kv_transfer_config.kv_connector_extra_config
|
||||
):
|
||||
logger.warning_once(
|
||||
"KV connector %s requires PIECEWISE CUDA graph mode "
|
||||
"due to layerwise async operations that cannot be "
|
||||
"captured in CUDA graphs. "
|
||||
"Overriding cudagraph_mode from %s to PIECEWISE.",
|
||||
connector_cls.__name__,
|
||||
self.compilation_config.cudagraph_mode.name,
|
||||
)
|
||||
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
|
||||
|
||||
# disable cudagraph when enforce eager execution
|
||||
if self.model_config is not None and self.model_config.enforce_eager:
|
||||
logger.info("Cudagraph is disabled under eager mode")
|
||||
@@ -1114,6 +1142,20 @@ class VllmConfig:
|
||||
if not self.instance_id:
|
||||
self.instance_id = random_uuid()[:5]
|
||||
|
||||
def is_ixserver_connector(kv_transfer_config) -> bool:
|
||||
if kv_transfer_config is not None and hasattr(
|
||||
kv_transfer_config, "kv_connector"
|
||||
):
|
||||
connector = kv_transfer_config.kv_connector
|
||||
if isinstance(connector, str):
|
||||
connector_name = connector
|
||||
else:
|
||||
connector_name = getattr(
|
||||
type(connector), "__name__", str(connector)
|
||||
)
|
||||
return "IxServer" in connector_name
|
||||
return False
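As a quick illustration of the helper above (the connector class name below is hypothetical):

```python
# Sketch only: the check accepts both a plain connector name string and a
# connector instance; "IxServerKVConnector" is a made-up class name.
class IxServerKVConnector:
    pass

connector = IxServerKVConnector()
connector_name = getattr(type(connector), "__name__", str(connector))
assert "IxServer" in connector_name          # instance form
assert "IxServer" in "IxServerKVConnector"   # string form
```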
|
||||
|
||||
# Hybrid KV cache manager (HMA) runtime rules:
|
||||
# - Explicit enable (--no-disable-kv-cache-manager): error if runtime
|
||||
# disables it
|
||||
@@ -1154,7 +1196,10 @@ class VllmConfig:
|
||||
if self.scheduler_config.disable_hybrid_kv_cache_manager is None:
|
||||
# Default to disable HMA, but only if the user didn't express a preference.
|
||||
if self.kv_transfer_config is not None:
|
||||
if is_ixserver_connector(self.kv_transfer_config):
|
||||
pass
|
||||
# NOTE(Kuntai): turn HMA off for connector unless specifically enabled.
|
||||
else:
|
||||
need_disable_hybrid_kv_cache_manager = True
|
||||
logger.warning(
|
||||
"Turning off hybrid kv cache manager because "
|
||||
@@ -1169,6 +1214,11 @@ class VllmConfig:
|
||||
self.scheduler_config.disable_hybrid_kv_cache_manager = (
|
||||
need_disable_hybrid_kv_cache_manager
|
||||
)
|
||||
|
||||
else:
|
||||
self.scheduler_config.disable_hybrid_kv_cache_manager = (
|
||||
need_disable_hybrid_kv_cache_manager
|
||||
)
|
||||
elif (
|
||||
self.scheduler_config.disable_hybrid_kv_cache_manager is False
|
||||
and need_disable_hybrid_kv_cache_manager
|
||||
@@ -1466,22 +1516,22 @@ class VllmConfig:
|
||||
if compile_range_end is not None:
|
||||
computed_compile_ranges_split_points.append(compile_range_end)
|
||||
|
||||
# # Add the compile ranges for flashinfer
|
||||
# if compilation_config.pass_config.fuse_allreduce_rms:
|
||||
# tp_size = self.parallel_config.tensor_parallel_size
|
||||
# max_size = compilation_config.pass_config.flashinfer_max_size(tp_size)
|
||||
# if max_size is not None:
|
||||
# max_token_num = max_size // (
|
||||
# self.model_config.get_hidden_size()
|
||||
# * self.model_config.dtype.itemsize
|
||||
# )
|
||||
# if compile_range_end is not None and max_token_num < compile_range_end:
|
||||
# computed_compile_ranges_split_points.append(max_token_num)
|
||||
# else:
|
||||
# logger.debug(
|
||||
# "Max num batched tokens below allreduce-rms fusion threshold, "
|
||||
# "allreduce-rms fusion will be enabled for all num_tokens."
|
||||
# )
|
||||
# Add the compile ranges for flashinfer
|
||||
if compilation_config.pass_config.fuse_allreduce_rms:
|
||||
tp_size = self.parallel_config.tensor_parallel_size
|
||||
max_size = compilation_config.pass_config.flashinfer_max_size(tp_size)
|
||||
if max_size is not None:
|
||||
max_token_num = max_size // (
|
||||
self.model_config.get_hidden_size()
|
||||
* self.model_config.dtype.itemsize
|
||||
)
|
||||
if compile_range_end is not None and max_token_num < compile_range_end:
|
||||
computed_compile_ranges_split_points.append(max_token_num)
|
||||
else:
|
||||
logger.debug(
|
||||
"Max num batched tokens below allreduce-rms fusion threshold, "
|
||||
"allreduce-rms fusion will be enabled for all num_tokens."
|
||||
)
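To make the token-threshold formula above concrete, a hedged arithmetic example (the numbers are illustrative, not defaults):

```python
# Sketch: with a hypothetical 16 MiB flashinfer fusion limit, hidden size 4096
# and bf16 activations (2 bytes), fusion stays enabled up to 2048 tokens.
max_size = 16 * 1024 * 1024          # bytes, stand-in for flashinfer_max_size(tp)
hidden_size = 4096
dtype_itemsize = 2                   # torch.bfloat16
max_token_num = max_size // (hidden_size * dtype_itemsize)
assert max_token_num == 2048         # a compile-range split point would be added here
```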
|
||||
|
||||
# Add the compile ranges for sequence parallelism
|
||||
if compilation_config.pass_config.enable_sp:
|
||||
@@ -1618,6 +1668,7 @@ class VllmConfig:
|
||||
f"pipeline_parallel_size={self.parallel_config.pipeline_parallel_size}, " # noqa
|
||||
f"data_parallel_size={self.parallel_config.data_parallel_size}, " # noqa
|
||||
f"disable_custom_all_reduce={self.parallel_config.disable_custom_all_reduce}, " # noqa
|
||||
f"quantization={self.model_config.quantization}, "
|
||||
f"enforce_eager={self.model_config.enforce_eager}, "
|
||||
f"enable_return_routed_experts={self.model_config.enable_return_routed_experts}, " # noqa
|
||||
f"kv_cache_dtype={self.cache_config.cache_dtype}, "
|
||||
|
||||
@@ -9,5 +9,5 @@ from vllm.config.utils import config
|
||||
class WeightTransferConfig:
|
||||
"""Configuration for weight transfer during RL training."""
|
||||
|
||||
backend: Literal["nccl"] = "nccl"
|
||||
backend: Literal["nccl", "ipc"] = "nccl"
|
||||
"""The backend to use for weight transfer."""
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
import dataclasses
|
||||
import gc
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Callable, Iterator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
@@ -25,6 +25,7 @@ logger = init_logger(__name__)
|
||||
|
||||
|
||||
cumem_available = False
|
||||
libcudart: Any = None
|
||||
try:
|
||||
from vllm.cumem_allocator import (
|
||||
init_module,
|
||||
@@ -41,9 +42,7 @@ except ModuleNotFoundError:
|
||||
init_module = None
|
||||
python_create_and_map = None
|
||||
python_unmap_and_release = None
|
||||
CudaRTLibrary = None
|
||||
lib_name = None
|
||||
libcudart = None
|
||||
|
||||
# py_device, py_alignedSize, py_d_mem, py_p_memHandle
|
||||
HandleType = tuple[int, int, int, int]
|
||||
@@ -65,7 +64,8 @@ def unmap_and_release(allocation_handle: HandleType) -> None:
|
||||
|
||||
|
||||
def get_pluggable_allocator(
|
||||
python_malloc_fn: Callable[[int], int], python_free_func: Callable[[int, int], None]
|
||||
python_malloc_fn: Callable[[HandleType], None],
|
||||
python_free_func: Callable[[int], HandleType],
|
||||
) -> torch.cuda.memory.CUDAPluggableAllocator:
|
||||
init_module(python_malloc_fn, python_free_func)
|
||||
new_alloc = torch.cuda.memory.CUDAPluggableAllocator(
|
||||
@@ -76,8 +76,11 @@ def get_pluggable_allocator(
|
||||
|
||||
@contextmanager
|
||||
def use_memory_pool_with_allocator(
|
||||
python_malloc_fn: Callable[[int], int], python_free_func: Callable[[int, int], None]
|
||||
) -> None:
|
||||
python_malloc_fn: Callable[[HandleType], None],
|
||||
python_free_func: Callable[[int], HandleType],
|
||||
) -> Iterator[
|
||||
tuple[torch.cuda.memory.MemPool, torch.cuda.memory.CUDAPluggableAllocator]
|
||||
]:
|
||||
new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func)
|
||||
mem_pool = torch.cuda.memory.MemPool(new_alloc._allocator)
|
||||
with torch.cuda.memory.use_mem_pool(mem_pool):
|
||||
@@ -109,7 +112,7 @@ class CuMemAllocator:
|
||||
not work as expected.
|
||||
"""
|
||||
|
||||
instance: "CuMemAllocator" = None
|
||||
instance: "CuMemAllocator | None" = None
|
||||
default_tag: str = "default"
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -3,14 +3,13 @@
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed import get_dp_group, get_ep_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.flashinfer import has_flashinfer_all2all
|
||||
from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx
|
||||
from vllm.utils.import_utils import has_deep_ep, has_mori
|
||||
|
||||
from .base_device_communicator import All2AllManagerBase, Cache
|
||||
|
||||
@@ -32,8 +31,8 @@ class NaiveAll2AllManager(All2AllManagerBase):
|
||||
debugging.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
def __init__(self, cpu_group, tcp_store_group=None):
|
||||
super().__init__(cpu_group, tcp_store_group)
|
||||
|
||||
def naive_multicast(
|
||||
self,
|
||||
@@ -139,8 +138,8 @@ class AgRsAll2AllManager(All2AllManagerBase):
|
||||
all-gather (dispatch) and reduce-scatter (combine).
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
def __init__(self, cpu_group, tcp_store_group=None):
|
||||
super().__init__(cpu_group, tcp_store_group)
|
||||
|
||||
def dispatch_router_logits(
|
||||
self,
|
||||
@@ -235,107 +234,17 @@ class AgRsAll2AllManager(All2AllManagerBase):
|
||||
pass
|
||||
|
||||
|
||||
class PPLXAll2AllManager(All2AllManagerBase):
|
||||
"""
|
||||
All2All communication based on PPLX kernels.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
assert has_pplx(), (
|
||||
"pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md"
|
||||
" to install pplx_kernels."
|
||||
)
|
||||
super().__init__(cpu_group)
|
||||
|
||||
if self.internode:
|
||||
# inter-node communication needs nvshmem,
|
||||
# intra-node communication uses p2p mapping directly
|
||||
from pplx_kernels.nvshmem import ( # type: ignore[import-not-found]
|
||||
nvshmem_alloc_empty_unique_id,
|
||||
nvshmem_get_unique_id,
|
||||
nvshmem_init,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Initialize NVSHMEM for pplx_kernels: rank=%d, world size=%d",
|
||||
self.rank,
|
||||
self.world_size,
|
||||
)
|
||||
uid = (
|
||||
nvshmem_get_unique_id()
|
||||
if self.rank == 0
|
||||
else nvshmem_alloc_empty_unique_id()
|
||||
)
|
||||
dist.broadcast(
|
||||
uid,
|
||||
src=dist.get_process_group_ranks(self.cpu_group)[0],
|
||||
group=self.cpu_group,
|
||||
)
|
||||
logger.debug("PPLX NVSHMEM UID = %s", uid)
|
||||
nvshmem_init(uid, self.rank, self.world_size)
|
||||
|
||||
self.handle_cache = Cache()
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
import pplx_kernels as pplx # type: ignore[import-not-found]
|
||||
|
||||
return self.handle_cache.get_or_create(
|
||||
kwargs,
|
||||
pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode,
|
||||
)
|
||||
|
||||
def dispatch_router_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
def combine(
|
||||
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def destroy(self):
|
||||
with self.handle_cache._lock:
|
||||
for _, handle in self.handle_cache._cache.items():
|
||||
handle.destroy()
|
||||
|
||||
if self.internode:
|
||||
from pplx_kernels.nvshmem import (
|
||||
nvshmem_finalize, # type: ignore[import-not-found]
|
||||
)
|
||||
|
||||
logger.debug("PPLX NVSHMEM finalize")
|
||||
nvshmem_finalize()
|
||||
|
||||
|
||||
class DeepEPAll2AllManagerBase(All2AllManagerBase):
|
||||
"""
|
||||
All2All communication based on DeepEP High-Throughput kernels.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
def __init__(self, cpu_group, tcp_store_group=None):
|
||||
assert has_deep_ep(), (
|
||||
"DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md"
|
||||
" to install DeepEP kernels."
|
||||
) # noqa
|
||||
super().__init__(cpu_group)
|
||||
super().__init__(cpu_group, tcp_store_group)
|
||||
self.handle_cache = Cache()
|
||||
|
||||
# This is the DeepEP default. Stick to it till we can establish
|
||||
@@ -373,7 +282,10 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
|
||||
raise NotImplementedError
|
||||
|
||||
def destroy(self):
|
||||
pass
|
||||
with self.handle_cache._lock:
|
||||
for _, handle in self.handle_cache._cache.items():
|
||||
handle.destroy()
|
||||
self.handle_cache._cache.clear()
|
||||
|
||||
|
||||
class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
@@ -381,8 +293,8 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
All2All communication based on DeepEP High-Throughput kernels.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
def __init__(self, cpu_group, tcp_store_group=None):
|
||||
super().__init__(cpu_group, tcp_store_group)
|
||||
|
||||
def _make_all2all_kwargs(self) -> dict[Any, Any]:
|
||||
# Defaults for internode and intranode are taken from DeepEP tests.
|
||||
@@ -405,6 +317,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
num_rdma_bytes=num_rdma_bytes,
|
||||
low_latency_mode=False,
|
||||
num_qps_per_rank=num_qps_per_rank,
|
||||
explicitly_destroy=True,
|
||||
)
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
@@ -438,8 +351,8 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
All2All communication based on DeepEP Low-Latency kernels.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
def __init__(self, cpu_group, tcp_store_group=None):
|
||||
super().__init__(cpu_group, tcp_store_group)
|
||||
|
||||
def _make_all2all_kwargs(
|
||||
self,
|
||||
@@ -476,8 +389,9 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
num_rdma_bytes=num_rdma_bytes,
|
||||
low_latency_mode=True,
|
||||
num_qps_per_rank=num_qps_per_rank,
|
||||
allow_nvlink_for_low_latency_mode=True,
|
||||
allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL,
|
||||
# allow_nvlink_for_low_latency_mode=True,
|
||||
# allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL,
|
||||
explicitly_destroy=True,
|
||||
)
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
@@ -509,11 +423,11 @@ class FlashInferAllToAllManager(All2AllManagerBase):
|
||||
rank: int
|
||||
world_size: int
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
def __init__(self, cpu_group, tcp_store_group=None):
|
||||
assert has_flashinfer_all2all(), (
|
||||
"flashinfer all2all module not found. Please install/check flashinfer"
|
||||
) # noqa
|
||||
super().__init__(cpu_group)
|
||||
super().__init__(cpu_group, tcp_store_group)
|
||||
logger.debug(
|
||||
"Initialize for flashinfer All2All rank=%d, world size=%d",
|
||||
self.rank,
|
||||
|
||||
@@ -27,6 +27,7 @@ from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
KiB = 1024
|
||||
MiB = 1024 * 1024
|
||||
# Max size for each world size in case symmetric memory is available
|
||||
# For different SM architectures
|
||||
@@ -60,17 +61,44 @@ SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
|
||||
},
|
||||
}
|
||||
|
||||
# NCCL symmetric memory allreduce configuration based on H100 and GB200 benchmarks.
|
||||
# PyNCCL-symm outperforms custom_AR for small and large tensor sizes,
|
||||
# while custom_AR wins for mid-range sizes.
|
||||
#
|
||||
# Benchmark results (8 GPUs):
|
||||
# 2K - 16K: PyNCCL-symm wins (1.35x - 1.48x faster)
|
||||
# 32K - 64K: custom_AR wins
|
||||
# 128K - 1G: PyNCCL-symm wins (1.12x - 6.14x faster)
|
||||
#
|
||||
# Benchmark results (4 GPUs):
|
||||
# 2K - 16K: PyNCCL-symm wins (1.21x - 1.30x faster)
|
||||
# 32K - 256K: custom_AR wins (1.07x - 1.35x faster)
|
||||
# 512K - 1G: PyNCCL-symm wins (1.10x - 2.32x faster)
|
||||
#
|
||||
# The config defines ranges where custom_AR is preferred (symm_mem disabled).
|
||||
NCCL_SYMM_MEM_ALL_REDUCE_CONFIG: dict[str, Any] = {
|
||||
"min_world_size": 4,
|
||||
"thresholds": {
|
||||
4: 2 * MiB, # 2 MB
|
||||
8: 1 * MiB, # 1 MB
|
||||
# Ranges where custom_AR outperforms NCCL symm_mem: (lower_bound, upper_bound)
|
||||
# NCCL symm_mem will NOT be used for sizes in range: lower < size < upper
|
||||
"custom_ar_preferred_ranges": {
|
||||
4: (16 * KiB, 512 * KiB), # custom_AR wins for 32K-256K
|
||||
8: (16 * KiB, 128 * KiB), # custom_AR wins for 32K-64K
|
||||
},
|
||||
"always_use_above_world_size": 8, # Always use symm mem for world_size > 8
|
||||
}
|
||||
|
||||
|
||||
def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor) -> bool:
|
||||
"""
|
||||
Determine if NCCL symmetric memory allreduce should be used.
|
||||
|
||||
Based on H100 and GB200 benchmarks, NCCL symm_mem is preferred for:
|
||||
- Small tensors (≤16K): Lower latency than custom_AR
|
||||
- Large tensors (≥128K for 8 GPUs, ≥512K for 4 GPUs): Better bandwidth
|
||||
|
||||
Custom_AR is preferred for mid-range sizes where its P2P approach
|
||||
has lower overhead than the symm_mem copy-in/copy-out pattern.
|
||||
"""
|
||||
from vllm.distributed.device_communicators.pynccl_allocator import (
|
||||
is_symmetric_memory_enabled,
|
||||
)
|
||||
@@ -80,11 +108,20 @@ def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor)
|
||||
|
||||
if not is_symmetric_memory_enabled():
|
||||
return False
|
||||
|
||||
if world_size < NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["min_world_size"]:
|
||||
return False
|
||||
threshold = NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["thresholds"].get(world_size)
|
||||
if threshold is not None and input_tensor.nbytes >= threshold:
|
||||
return True
|
||||
|
||||
tensor_size = input_tensor.nbytes
|
||||
custom_ar_range = NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["custom_ar_preferred_ranges"].get(
|
||||
world_size
|
||||
)
|
||||
|
||||
if custom_ar_range is not None:
|
||||
lower_bound, upper_bound = custom_ar_range
|
||||
# Use symm_mem for small sizes (≤ lower_bound) and large sizes (≥ upper_bound)
|
||||
# Use custom_AR (not symm_mem) for mid-range sizes
|
||||
return tensor_size <= lower_bound or tensor_size >= upper_bound
|
||||
return world_size > NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["always_use_above_world_size"]
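A small, hedged check of the selection rule above using the 8-GPU range from the config (sizes in bytes; the behaviour mirrors the benchmarks quoted in the comments):

```python
# Sketch of the decision for world_size=8, where custom_AR is preferred strictly
# between 16 KiB and 128 KiB and NCCL symm_mem wins elsewhere.
KiB = 1024
lower, upper = 16 * KiB, 128 * KiB   # custom_ar_preferred_ranges[8]

def prefers_symm_mem(nbytes: int) -> bool:
    return nbytes <= lower or nbytes >= upper

assert prefers_symm_mem(8 * KiB)          # small tensor: PyNCCL-symm
assert not prefers_symm_mem(64 * KiB)     # mid-range: custom_AR
assert prefers_symm_mem(1024 * KiB)       # 1 MiB tensor: PyNCCL-symm
```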
|
||||
|
||||
|
||||
|
||||
@@ -30,8 +30,9 @@ class All2AllManagerBase:
|
||||
rank: int
|
||||
world_size: int
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
def __init__(self, cpu_group, tcp_store_group=None):
|
||||
self.cpu_group = cpu_group
|
||||
self.tcp_store_group = tcp_store_group
|
||||
|
||||
# compute some common properties
|
||||
from vllm.distributed.parallel_state import (
|
||||
@@ -48,12 +49,17 @@ class All2AllManagerBase:
|
||||
# when we create this object
|
||||
self.dp_rank = self.dp_group.rank_in_group
|
||||
self.dp_world_size = self.dp_group.world_size
|
||||
self.rank = dist.get_rank(cpu_group)
|
||||
self.world_size = dist.get_world_size(cpu_group)
|
||||
self.rank = cpu_group.rank()
|
||||
self.world_size = cpu_group.size()
|
||||
|
||||
# all2all communication often has separate implementations for
|
||||
# intra-node and inter-node communication
|
||||
if tcp_store_group is None:
|
||||
self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0))
|
||||
else:
|
||||
self.internode = not all(
|
||||
in_the_same_node_as(tcp_store_group, source_rank=0)
|
||||
)
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
# get a handle for the all2all communication,
|
||||
@@ -122,11 +128,30 @@ class DeviceCommunicatorBase:
|
||||
device: torch.device | None = None,
|
||||
device_group: ProcessGroup | None = None,
|
||||
unique_name: str = "",
|
||||
global_ranks: list[int] | None = None,
|
||||
global_world_size: int | None = None,
|
||||
):
|
||||
self.device = device or torch.device("cpu")
|
||||
self.cpu_group = cpu_group
|
||||
self.device_group = device_group
|
||||
self.unique_name = unique_name
|
||||
|
||||
# Check if this is a stateless process group
|
||||
from torch.distributed.distributed_c10d import _world
|
||||
|
||||
is_stateless = _world.pg_map.get(cpu_group, None) is None
|
||||
|
||||
if is_stateless:
|
||||
# For stateless groups, we can't use torch.distributed methods
|
||||
self.rank = cpu_group.rank()
|
||||
self.world_size = cpu_group.size()
|
||||
assert global_ranks is not None
|
||||
assert global_world_size is not None
|
||||
self.ranks = global_ranks
|
||||
self.global_rank = self.ranks[self.rank]
|
||||
self.global_world_size = global_world_size
|
||||
self.rank_in_group = self.rank
|
||||
else:
|
||||
self.rank = dist.get_rank(cpu_group)
|
||||
self.world_size = dist.get_world_size(cpu_group)
|
||||
self.ranks = dist.get_process_group_ranks(cpu_group)
|
||||
@@ -146,7 +171,7 @@ class DeviceCommunicatorBase:
|
||||
use_ep = config.parallel_config.data_parallel_size > 1
|
||||
all2all_backend = config.parallel_config.all2all_backend
|
||||
|
||||
self.is_ep_communicator = "ep" in unique_name
|
||||
self.is_ep_communicator = unique_name.split(":")[0] == "ep"
|
||||
self.use_all2all = self.is_ep_communicator and use_ep
|
||||
self.all2all_backend = all2all_backend
|
||||
self.all2all_manager: All2AllManagerBase | None = None
|
||||
@@ -175,9 +200,7 @@ class DeviceCommunicatorBase:
|
||||
group=self.device_group,
|
||||
async_op=True)
|
||||
else:
|
||||
torch.distributed.all_gather_into_tensor(output_tensor,
|
||||
input_,
|
||||
group=self.device_group)
|
||||
dist.all_gather_into_tensor(output_tensor, input_, group=self.device_group)
|
||||
# Reshape
|
||||
output_tensor = output_tensor.reshape((self.world_size,) + input_size)
|
||||
output_tensor = output_tensor.movedim(0, dim)
|
||||
@@ -263,10 +286,9 @@ class DeviceCommunicatorBase:
|
||||
group=self.device_group,
|
||||
async_op=True)
|
||||
else:
|
||||
torch.distributed.gather(input_,
|
||||
gather_list,
|
||||
dst=self.ranks[dst],
|
||||
group=self.device_group)
|
||||
torch.distributed.gather(
|
||||
input_, gather_list, dst=self.ranks[dst], group=self.device_group
|
||||
)
|
||||
if self.rank_in_group == dst:
|
||||
output_tensor = torch.cat(gather_list, dim=dim)
|
||||
else:
|
||||
@@ -292,6 +314,13 @@ class DeviceCommunicatorBase:
|
||||
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
|
||||
return tensor
|
||||
|
||||
def broadcast(self, tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
|
||||
"""Broadcast a tensor from source rank to all ranks."""
|
||||
if self.world_size == 1:
|
||||
return tensor
|
||||
torch.distributed.broadcast(tensor, self.ranks[src], self.device_group)
|
||||
return tensor
|
||||
|
||||
def destroy(self):
|
||||
pass
|
||||
|
||||
@@ -360,3 +389,6 @@ class DeviceCommunicatorBase:
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
return hidden_states
|
||||
|
||||
def batch_isend_irecv(self, p2p_ops: list):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -35,8 +35,15 @@ class CpuCommunicator(DeviceCommunicatorBase):
|
||||
)
|
||||
and hasattr(torch.ops._C, "init_shm_manager")
|
||||
and (unique_name.startswith("tp") or unique_name.startswith("pp"))
|
||||
and self._all_group_ranks_share_shm_group_name()
|
||||
):
|
||||
self.dist_module = _CPUSHMDistributed(self)
|
||||
elif unique_name.startswith("tp") or unique_name.startswith("pp"):
|
||||
logger.info(
|
||||
"CPU SHM communicator disabled for group %s: ranks do not share "
|
||||
"the same SHM group name, falling back to torch.distributed.",
|
||||
unique_name,
|
||||
)
|
||||
|
||||
if self.use_all2all:
|
||||
if self.all2all_backend != "naive": # type: ignore[has-type]
|
||||
@@ -52,6 +59,20 @@ class CpuCommunicator(DeviceCommunicatorBase):
|
||||
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
|
||||
logger.info("Using naive all2all manager.")
|
||||
|
||||
def _all_group_ranks_share_shm_group_name(self) -> bool:
|
||||
"""
|
||||
CPUSHM requires all ranks in this group to agree on one SHM group name.
|
||||
This is a lightweight consistency check for VLLM_DIST_IDENT/name inputs.
|
||||
"""
|
||||
local_name = _CPUSHMDistributed.make_group_name(self)
|
||||
names: list[str] = [""] * self.world_size
|
||||
torch.distributed.all_gather_object(
|
||||
names,
|
||||
local_name,
|
||||
group=self.device_group,
|
||||
)
|
||||
return len(set(names)) == 1
|
||||
|
||||
def all_reduce(self, input_):
|
||||
self.dist_module.all_reduce(input_, group=self.device_group)
|
||||
return input_
|
||||
@@ -193,16 +214,20 @@ class CpuCommunicator(DeviceCommunicatorBase):
|
||||
|
||||
class _CPUSHMDistributed:
|
||||
def __init__(self, communicator: CpuCommunicator):
|
||||
self.communicator = communicator
|
||||
|
||||
self.group_name = self.make_group_name(communicator)
|
||||
|
||||
self.handle = self._init_cpu_shm()
|
||||
|
||||
@staticmethod
|
||||
def make_group_name(communicator: CpuCommunicator) -> str:
|
||||
instance_identifier = os.environ["VLLM_DIST_IDENT"]
|
||||
unique_name = communicator.unique_name
|
||||
instance_identifier = f"{instance_identifier}-{unique_name}"
|
||||
self.communicator = communicator
|
||||
|
||||
group_ranks = [str(rank) for rank in self.communicator.ranks]
|
||||
group_ranks = [str(rank) for rank in communicator.ranks]
|
||||
shm_group_identifier = f"[{'-'.join(group_ranks)}]"
|
||||
self.group_name = f"{instance_identifier}-{shm_group_identifier}-cpushm"
|
||||
|
||||
self.handle = self._init_cpu_shm()
|
||||
return f"{instance_identifier}-{shm_group_identifier}-cpushm"
|
||||
|
||||
def _init_cpu_shm(self) -> int:
|
||||
thread_num_tensor = torch.tensor(
|
||||
|
||||
@@ -16,6 +16,7 @@ from vllm.distributed.device_communicators.pynccl_allocator import (
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..utils import StatelessProcessGroup
|
||||
from .base_device_communicator import DeviceCommunicatorBase
|
||||
import ixformer.distributed as ixfd
|
||||
import os
|
||||
@@ -29,8 +30,18 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
device: torch.device | None = None,
|
||||
device_group: ProcessGroup | None = None,
|
||||
unique_name: str = "",
|
||||
global_ranks: list[int] | None = None,
|
||||
global_world_size: int | None = None,
|
||||
tcp_store_group: StatelessProcessGroup | None = None,
|
||||
):
|
||||
super().__init__(cpu_group, device, device_group, unique_name)
|
||||
super().__init__(
|
||||
cpu_group,
|
||||
device,
|
||||
device_group,
|
||||
unique_name,
|
||||
global_ranks,
|
||||
global_world_size,
|
||||
)
|
||||
if "tp" not in unique_name:
|
||||
# custom allreduce or torch symm mem can be used only by tp
|
||||
use_custom_allreduce = False
|
||||
@@ -48,6 +59,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
self.use_flashinfer_allreduce = use_flashinfer_allreduce
|
||||
|
||||
self.use_vllm_comm = os.environ.get("VLLM_FORCE_NCCL_COMM", None) not in ["1", "Y", "y"]
|
||||
|
||||
# lazy import to avoid documentation build error
|
||||
from vllm.distributed.device_communicators.custom_all_reduce import (
|
||||
CustomAllreduce,
|
||||
@@ -64,7 +76,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
self.pynccl_comm: PyNcclCommunicator | None = None
|
||||
if self.world_size > 1:
|
||||
self.pynccl_comm = PyNcclCommunicator(
|
||||
group=self.cpu_group,
|
||||
group=self.cpu_group if tcp_store_group is None else tcp_store_group,
|
||||
device=self.device,
|
||||
)
|
||||
if is_symmetric_memory_enabled():
|
||||
@@ -109,23 +121,27 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
if self.all2all_backend == "naive":
|
||||
from .all2all import NaiveAll2AllManager
|
||||
|
||||
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
|
||||
self.all2all_manager = NaiveAll2AllManager(
|
||||
self.cpu_group, tcp_store_group
|
||||
)
|
||||
elif self.all2all_backend == "allgather_reducescatter":
|
||||
from .all2all import AgRsAll2AllManager
|
||||
|
||||
self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
|
||||
elif self.all2all_backend == "pplx":
|
||||
from .all2all import PPLXAll2AllManager
|
||||
|
||||
self.all2all_manager = PPLXAll2AllManager(self.cpu_group)
|
||||
self.all2all_manager = AgRsAll2AllManager(
|
||||
self.cpu_group, tcp_store_group
|
||||
)
|
||||
elif self.all2all_backend == "deepep_high_throughput":
|
||||
from .all2all import DeepEPHTAll2AllManager
|
||||
|
||||
self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group)
|
||||
self.all2all_manager = DeepEPHTAll2AllManager(
|
||||
self.cpu_group, tcp_store_group
|
||||
)
|
||||
elif self.all2all_backend == "deepep_low_latency":
|
||||
from .all2all import DeepEPLLAll2AllManager
|
||||
|
||||
self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
|
||||
self.all2all_manager = DeepEPLLAll2AllManager(
|
||||
self.cpu_group, tcp_store_group
|
||||
)
|
||||
elif self.all2all_backend == "mori":
|
||||
from .all2all import MoriAll2AllManager
|
||||
|
||||
@@ -133,7 +149,9 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
elif self.all2all_backend == "flashinfer_all2allv":
|
||||
from .all2all import FlashInferAllToAllManager
|
||||
|
||||
self.all2all_manager = FlashInferAllToAllManager(self.cpu_group)
|
||||
self.all2all_manager = FlashInferAllToAllManager(
|
||||
self.cpu_group, tcp_store_group
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown all2all backend: {self.all2all_backend}")
|
||||
|
||||
@@ -188,27 +206,12 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
return out
|
||||
if self.world_size == 1:
|
||||
return input_
|
||||
|
||||
if self.use_vllm_comm:
|
||||
# torch.ops.ixf_ops.vllm_all_reduce(input_, async_op=True)
|
||||
ixfd.all_reduce(input_, group=self.device_group, async_op=True)
|
||||
else:
|
||||
torch.distributed.all_reduce(input_, group=self.device_group)
|
||||
return input_
|
||||
pynccl_comm = self.pynccl_comm
|
||||
if pynccl_comm is None or pynccl_comm.disabled:
|
||||
out = input_.clone()
|
||||
torch.distributed.all_reduce(out, group=self.device_group)
|
||||
return out
|
||||
assert pynccl_comm is not None
|
||||
out = pynccl_comm.all_reduce(input_)
|
||||
if out is None:
|
||||
# fall back to the default all-reduce using PyTorch.
|
||||
# this usually happens during testing.
|
||||
# when we run the model, allreduce only happens for the TP
|
||||
# group, where we always have either custom allreduce or pynccl.
|
||||
out = input_.clone()
|
||||
torch.distributed.all_reduce(out, group=self.device_group)
|
||||
return out
|
||||
|
||||
def reduce_scatter(self, input_: torch.Tensor, dim: int = -1):
|
||||
world_size = self.world_size
|
||||
@@ -230,10 +233,8 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
output_shape, dtype=input_tensor.dtype, device=input_tensor.device
|
||||
)
|
||||
|
||||
# pynccl_comm.reduce_scatter(output, input_tensor)
|
||||
torch.distributed.reduce_scatter_tensor(output,
|
||||
input_tensor,
|
||||
group=self.device_group)
|
||||
# Perform reduce-scatter operation
|
||||
ixfd.reduce_scatter_tensor(output, input_tensor, group=self.device_group, async_op=True)
|
||||
|
||||
# Reshape before returning
|
||||
return output.movedim(0, dim).contiguous()
|
||||
@@ -278,12 +279,6 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
"""NOTE: `dst` is the local rank of the destination rank."""
|
||||
if dst is None:
|
||||
dst = (self.rank_in_group + 1) % self.world_size
|
||||
|
||||
pynccl_comm = self.pynccl_comm
|
||||
# if pynccl_comm is not None and not pynccl_comm.disabled:
|
||||
# pynccl_comm.send(tensor, dst)
|
||||
# else:
|
||||
# torch.distributed.send(tensor, self.ranks[dst], self.device_group)
|
||||
if self.use_vllm_comm:
|
||||
ixfd.send(tensor, self.ranks[dst], self.device_group)
|
||||
else:
|
||||
@@ -298,17 +293,24 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
src = (self.rank_in_group - 1) % self.world_size
|
||||
|
||||
tensor = torch.empty(size, dtype=dtype, device=self.device)
|
||||
# pynccl_comm = self.pynccl_comm
|
||||
# if pynccl_comm is not None and not pynccl_comm.disabled:
|
||||
# pynccl_comm.recv(tensor, src)
|
||||
# else:
|
||||
# torch.distributed.recv(tensor, self.ranks[src], self.device_group)
|
||||
if self.use_vllm_comm:
|
||||
ixfd.recv(tensor, self.ranks[src], self.device_group)
|
||||
else:
|
||||
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
|
||||
return tensor
|
||||
|
||||
def broadcast(self, tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
|
||||
"""Broadcast a tensor from source rank to all ranks."""
|
||||
if self.world_size == 1:
|
||||
return tensor
|
||||
|
||||
pynccl_comm = self.pynccl_comm
|
||||
if pynccl_comm is not None and not pynccl_comm.disabled:
|
||||
pynccl_comm.broadcast(tensor, src)
|
||||
return tensor
|
||||
else:
|
||||
raise ValueError("No PyNCCL communicator found")
|
||||
|
||||
def destroy(self):
|
||||
if self.pynccl_comm is not None:
|
||||
self.pynccl_comm = None
|
||||
@@ -319,7 +321,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
self.fi_ar_comm = None
|
||||
if self.all2all_manager is not None:
|
||||
self.all2all_manager.destroy()
|
||||
self.all2all_manager = None
|
||||
self.all2all_manager = None # type: ignore[assignment]
|
||||
|
||||
def all_gatherv(
|
||||
self,
|
||||
@@ -372,7 +374,6 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
def dispatch_router_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
extra_residual: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
@@ -409,16 +410,13 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
assert self.all2all_manager is not None
|
||||
# return self.all2all_manager.dispatch(
|
||||
# hidden_states,
|
||||
# topk_weights,
|
||||
# topk_ids,
|
||||
# is_sequence_parallel,
|
||||
# extra_tensors=extra_tensors,
|
||||
# )
|
||||
hidden_states, extra_residual, router_logits = self.all2all_manager.dispatch(
|
||||
hidden_states, extra_residual, router_logits)
|
||||
return hidden_states, extra_residual, router_logits
|
||||
return self.all2all_manager.dispatch(
|
||||
hidden_states,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
is_sequence_parallel,
|
||||
extra_tensors=extra_tensors,
|
||||
)
|
||||
|
||||
def combine(
|
||||
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
||||
@@ -432,3 +430,10 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
hidden_states,
|
||||
is_sequence_parallel,
|
||||
)
|
||||
|
||||
def batch_isend_irecv(self, p2p_ops: list):
|
||||
pynccl_comm = self.pynccl_comm
|
||||
if pynccl_comm is not None and not pynccl_comm.disabled:
|
||||
pynccl_comm.batch_isend_irecv(p2p_ops)
|
||||
else:
|
||||
raise ValueError("No PyNCCL communicator found")
|
||||
|
||||
@@ -312,10 +312,19 @@ class PyNcclCommunicator:
|
||||
)
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
if tensor.dtype in [
|
||||
torch.float8_e5m2,
|
||||
torch.float8_e4m3fn,
|
||||
torch.float8_e4m3fnuz,
|
||||
torch.float8_e5m2fnuz,
|
||||
]:
|
||||
nccl_dtype = ncclDataTypeEnum.from_torch(torch.uint8)
|
||||
else:
|
||||
nccl_dtype = ncclDataTypeEnum.from_torch(tensor.dtype)
|
||||
self.nccl.ncclSend(
|
||||
buffer_type(tensor.data_ptr()),
|
||||
tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(tensor.dtype),
|
||||
nccl_dtype,
|
||||
dst,
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
@@ -330,10 +339,19 @@ class PyNcclCommunicator:
|
||||
)
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
if tensor.dtype in [
|
||||
torch.float8_e5m2,
|
||||
torch.float8_e4m3fn,
|
||||
torch.float8_e4m3fnuz,
|
||||
torch.float8_e5m2fnuz,
|
||||
]:
|
||||
nccl_dtype = ncclDataTypeEnum.from_torch(torch.uint8)
|
||||
else:
|
||||
nccl_dtype = ncclDataTypeEnum.from_torch(tensor.dtype)
|
||||
self.nccl.ncclRecv(
|
||||
buffer_type(tensor.data_ptr()),
|
||||
tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(tensor.dtype),
|
||||
nccl_dtype,
|
||||
src,
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
@@ -384,3 +402,17 @@ class PyNcclCommunicator:
|
||||
|
||||
def deregister_comm_window(self, window):
|
||||
return self.nccl.ncclCommWindowDeregister(self.comm, window)
|
||||
|
||||
def batch_isend_irecv(self, p2p_ops: list, stream=None):
|
||||
if self.disabled:
|
||||
return
|
||||
if stream is None:
|
||||
stream = current_stream()
|
||||
self.group_start()
|
||||
for op in p2p_ops:
|
||||
if op.op is torch.distributed.isend:
|
||||
self.send(op.tensor, op.group_peer, stream)
|
||||
elif op.op is torch.distributed.irecv:
|
||||
self.recv(op.tensor, op.group_peer, stream)
|
||||
|
||||
self.group_end()
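A hedged usage sketch of the new batched P2P helper: it expects objects shaped like `torch.distributed.P2POp` (op, tensor, group_peer) and wraps the sends/receives in one NCCL group call; the helper below is illustrative and mirrors the pattern used by `batch_transfer_weights` later in this diff:

```python
# Sketch: building the op list that batch_isend_irecv consumes. Creating P2POp
# via object.__new__ skips __init__, which would require an initialized group.
import torch
from torch.distributed import P2POp

def make_p2p_ops(tensors, peer_rank, is_sender):
    ops = []
    for t in tensors:
        op = object.__new__(P2POp)
        op.op = torch.distributed.isend if is_sender else torch.distributed.irecv
        op.tensor = t
        op.group_peer = peer_rank
        ops.append(op)
    return ops

ops = make_p2p_ops([torch.zeros(4), torch.ones(4)], peer_rank=1, is_sender=True)
# communicator.batch_isend_irecv(ops)  # would issue group_start / send+recv / group_end
```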
|
||||
|
||||
0
vllm/distributed/elastic_ep/__init__.py
Normal file
529
vllm/distributed/elastic_ep/elastic_execute.py
Normal file
@@ -0,0 +1,529 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
import gc
|
||||
import weakref
|
||||
from collections.abc import Iterable, Sequence
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.distributed import P2POp
|
||||
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.cuda_graph import CUDAGraphWrapper
|
||||
from vllm.compilation.wrapper import reset_compile_wrapper
|
||||
from vllm.config import (
|
||||
CompilationMode,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.distributed import (
|
||||
get_dp_group,
|
||||
get_ep_group,
|
||||
get_pcp_group,
|
||||
get_tp_group,
|
||||
)
|
||||
from vllm.distributed.elastic_ep.standby_state import (
|
||||
create_standby_groups,
|
||||
get_standby_dp_group,
|
||||
get_standby_ep_group,
|
||||
pop_standby_groups,
|
||||
)
|
||||
from vllm.distributed.parallel_state import (
|
||||
_replace_active_groups,
|
||||
prepare_communication_buffer_for_model,
|
||||
)
|
||||
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoEParallelConfig
|
||||
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
|
||||
from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
|
||||
from vllm.v1.worker.workspace import lock_workspace, unlock_workspace
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def batch_transfer_weights(
|
||||
model: nn.Module,
|
||||
is_sender: bool,
|
||||
peer_rank: int,
|
||||
dp_group: StatelessGroupCoordinator,
|
||||
expert_weights: Sequence[Iterable[torch.Tensor]],
|
||||
) -> None:
|
||||
device_comm = dp_group.device_communicator
|
||||
if device_comm is None:
|
||||
raise ValueError("No device communicator found")
|
||||
|
||||
expert_weights_set = set()
|
||||
for weight_group in expert_weights:
|
||||
for weight in weight_group:
|
||||
expert_weights_set.add(weight.data_ptr())
|
||||
|
||||
state_dict = model.state_dict()
|
||||
all_params = []
|
||||
|
||||
for name, param in state_dict.items():
|
||||
if name.endswith("expert_map"):
|
||||
continue
|
||||
if param.data_ptr() not in expert_weights_set:
|
||||
all_params.append(param.data)
|
||||
|
||||
assert len(all_params) > 0
|
||||
p2p_ops = []
|
||||
for param in all_params:
|
||||
op = object.__new__(P2POp)
|
||||
if is_sender:
|
||||
op.op = torch.distributed.isend
|
||||
op.tensor = param
|
||||
else:
|
||||
op.op = torch.distributed.irecv
|
||||
op.tensor = param
|
||||
op.group_peer = peer_rank
|
||||
p2p_ops.append(op)
|
||||
device_comm.batch_isend_irecv(p2p_ops)
|
||||
|
||||
|
||||
def broadcast_expert_mapping(
|
||||
physical_to_logical: torch.Tensor | None,
|
||||
num_local_physical_experts: int | None,
|
||||
num_logical_experts: int | None,
|
||||
dp_group: StatelessGroupCoordinator,
|
||||
device: torch.device,
|
||||
src_rank: int = 0,
|
||||
) -> tuple[torch.Tensor, int, int]:
|
||||
if dp_group.rank_in_group == src_rank:
|
||||
assert physical_to_logical is not None
|
||||
assert num_local_physical_experts is not None
|
||||
assert num_logical_experts is not None
|
||||
assert physical_to_logical.dtype == torch.int64
|
||||
shape_tensor = torch.tensor(
|
||||
list(physical_to_logical.shape), dtype=torch.int64, device="cpu"
|
||||
)
|
||||
metadata_tensor = torch.tensor(
|
||||
[num_local_physical_experts, num_logical_experts],
|
||||
dtype=torch.int64,
|
||||
device="cpu",
|
||||
)
|
||||
else:
|
||||
shape_tensor = torch.empty(2, dtype=torch.int64, device="cpu")
|
||||
metadata_tensor = torch.empty(2, dtype=torch.int64, device="cpu")
|
||||
|
||||
shape_tensor = dp_group.tcp_store_group.broadcast(shape_tensor, src_rank)
|
||||
metadata_tensor = dp_group.tcp_store_group.broadcast(metadata_tensor, src_rank)
|
||||
|
||||
if dp_group.rank_in_group != src_rank:
|
||||
assert device is not None
|
||||
physical_to_logical = torch.empty(
|
||||
tuple(shape_tensor.tolist()),
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
|
||||
assert physical_to_logical is not None
|
||||
physical_to_logical = dp_group.broadcast(physical_to_logical, src_rank)
|
||||
num_local_physical_experts = int(metadata_tensor[0].item())
|
||||
num_logical_experts = int(metadata_tensor[1].item())
|
||||
|
||||
return physical_to_logical, num_local_physical_experts, num_logical_experts
|
||||
|
||||
|
||||
class ElasticEPScalingExecutor:
|
||||
def __init__(self, worker):
|
||||
self.worker_ref = weakref.ref(worker)
|
||||
self.reconfig_request = None
|
||||
|
||||
@property
|
||||
def worker(self):
|
||||
worker = self.worker_ref()
|
||||
if worker is None:
|
||||
raise RuntimeError("Worker has been garbage collected")
|
||||
return worker
|
||||
|
||||
def execute(self, execute_method: str, *args, **kwargs):
|
||||
method = getattr(self, execute_method, None)
|
||||
if method is None:
|
||||
raise ValueError(f"Unknown execute method: {execute_method}")
|
||||
return method(*args, **kwargs)
|
||||
|
||||
def create_standby_groups(
|
||||
self, reconfig_request: ReconfigureDistributedRequest
|
||||
) -> None:
|
||||
self.reconfig_request = reconfig_request
|
||||
new_dp_size = reconfig_request.new_data_parallel_size
|
||||
world_size = self.worker.vllm_config.parallel_config.world_size
|
||||
new_world_size_across_dp = world_size * new_dp_size
|
||||
updated_config = copy.copy(self.worker.vllm_config)
|
||||
updated_config.parallel_config = copy.deepcopy(
|
||||
self.worker.vllm_config.parallel_config
|
||||
)
|
||||
updated_config.parallel_config.data_parallel_size = new_dp_size
|
||||
with set_current_vllm_config(updated_config):
|
||||
create_standby_groups(
|
||||
new_dp_size=new_dp_size,
|
||||
new_world_size_across_dp=new_world_size_across_dp,
|
||||
master_ip=reconfig_request.new_data_parallel_master_ip,
|
||||
world_group_ports=reconfig_request.new_stateless_world_group_port_list,
|
||||
dp_group_ports=reconfig_request.new_stateless_dp_group_port_list,
|
||||
ep_group_ports=reconfig_request.new_stateless_ep_group_port_list,
|
||||
eplb_group_ports=reconfig_request.new_stateless_eplb_group_port_list,
|
||||
)
|
||||
self.worker.model_runner.eep_eplb_suppressed = True
|
||||
standby_ep_group = get_standby_ep_group()
|
||||
assert standby_ep_group is not None
|
||||
if standby_ep_group.rank == 0:
|
||||
logger.info("[Elastic EP] EPLB disabled during elastic scaling transition")
|
||||
|
||||
def transfer_weights(self, old_dp_size: int, new_dp_size: int) -> None:
|
||||
standby_dp_group = get_standby_dp_group()
|
||||
assert standby_dp_group is not None
|
||||
# Broadcast old_dp_size to all workers in standby group
|
||||
if standby_dp_group.rank_in_group < old_dp_size:
|
||||
old_dp_size_tensor = torch.tensor(
|
||||
[old_dp_size], dtype=torch.int64, device="cpu"
|
||||
)
|
||||
else:
|
||||
old_dp_size_tensor = torch.empty(1, dtype=torch.int64, device="cpu")
|
||||
old_dp_size_tensor = standby_dp_group.tcp_store_group.broadcast(
|
||||
old_dp_size_tensor, 0
|
||||
)
|
||||
|
||||
num_new_workers = new_dp_size - old_dp_size
|
||||
dp_rank = self.worker.vllm_config.parallel_config.data_parallel_rank
|
||||
|
||||
# Sender-receiver pairing: the first new_workers % old_dp_size
|
||||
# senders get (k+1) contiguous receivers, the rest get k
|
||||
# receivers.
|
||||
num_dst_per_sender = num_new_workers // old_dp_size
|
||||
remainder = num_new_workers % old_dp_size
|
||||
|
||||
if dp_rank < remainder:
|
||||
recv_begin = dp_rank * (num_dst_per_sender + 1)
|
||||
recv_end = recv_begin + num_dst_per_sender + 1
|
||||
else:
|
||||
recv_begin = (
|
||||
remainder * (num_dst_per_sender + 1)
|
||||
+ (dp_rank - remainder) * num_dst_per_sender
|
||||
)
|
||||
recv_end = recv_begin + num_dst_per_sender
|
||||
|
||||
ranks_to_send = list(range(old_dp_size + recv_begin, old_dp_size + recv_end))
|
||||
|
||||
model = self.worker.model_runner.get_model()
|
||||
for new_worker_rank in sorted(ranks_to_send):
|
||||
batch_transfer_weights(
|
||||
model=model,
|
||||
is_sender=True,
|
||||
peer_rank=new_worker_rank,
|
||||
dp_group=standby_dp_group,
|
||||
expert_weights=model.expert_weights,
|
||||
)
|
||||
torch.cuda.synchronize()
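The sender/receiver pairing above can be checked in isolation; a hedged sketch with made-up sizes (old_dp_size=3 scaling to new_dp_size=8, i.e. 5 new workers):

```python
# Sketch of the pairing: the first `remainder` senders take (k+1) contiguous new
# workers each, the rest take k. Values below are illustrative only.
def receivers_for_sender(dp_rank, old_dp_size, new_dp_size):
    num_new = new_dp_size - old_dp_size
    k, remainder = divmod(num_new, old_dp_size)
    if dp_rank < remainder:
        begin = dp_rank * (k + 1)
        end = begin + k + 1
    else:
        begin = remainder * (k + 1) + (dp_rank - remainder) * k
        end = begin + k
    return list(range(old_dp_size + begin, old_dp_size + end))

# Every new worker (ranks 3..7) ends up with exactly one sender.
assert receivers_for_sender(0, 3, 8) == [3, 4]
assert receivers_for_sender(1, 3, 8) == [5, 6]
assert receivers_for_sender(2, 3, 8) == [7]
```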
|
||||
|
||||
def broadcast_expert_mapping(self) -> None:
|
||||
standby_dp_group = get_standby_dp_group()
|
||||
assert standby_dp_group is not None
|
||||
model_config = self.worker.model_runner.model_config
|
||||
eplb_state = self.worker.model_runner.eplb_state
|
||||
assert eplb_state is not None
|
||||
eplb_model_state = eplb_state.model_states[model_config.compute_hash()]
|
||||
physical_to_logical = eplb_model_state.physical_to_logical_map
|
||||
num_physical_experts = physical_to_logical.shape[1]
|
||||
num_local_physical_experts = num_physical_experts // get_ep_group().world_size
|
||||
num_logical_experts = eplb_model_state.logical_replica_count.shape[1]
|
||||
broadcast_expert_mapping(
|
||||
physical_to_logical=physical_to_logical,
|
||||
num_local_physical_experts=num_local_physical_experts,
|
||||
num_logical_experts=num_logical_experts,
|
||||
dp_group=standby_dp_group,
|
||||
src_rank=0,
|
||||
device=self.worker.device,
|
||||
)
|
||||
|
||||
def switch_and_remove(self) -> None:
|
||||
_replace_active_groups(world=None, dp=None, ep=None, eplb=None, node_count=None)
|
||||
|
||||
def switch_and_prepare(self) -> None:
|
||||
old_dp_size = get_dp_group().world_size
|
||||
old_ep_size = get_ep_group().world_size
|
||||
|
||||
_replace_active_groups(**pop_standby_groups())
|
||||
|
||||
parallel_config = self.worker.vllm_config.parallel_config
|
||||
reconfig_request = self.reconfig_request
|
||||
assert reconfig_request is not None
|
||||
new_dp_size = reconfig_request.new_data_parallel_size
|
||||
new_ep_size = get_ep_group().world_size
|
||||
|
||||
parallel_config.data_parallel_size = new_dp_size
|
||||
if (
|
||||
reconfig_request.new_data_parallel_rank
|
||||
!= ReconfigureRankType.KEEP_CURRENT_RANK
|
||||
):
|
||||
parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
|
||||
if (
|
||||
reconfig_request.new_data_parallel_rank_local
|
||||
!= ReconfigureRankType.KEEP_CURRENT_RANK
|
||||
):
|
||||
parallel_config.data_parallel_rank_local = (
|
||||
reconfig_request.new_data_parallel_rank_local
|
||||
)
|
||||
parallel_config.data_parallel_master_ip = (
|
||||
reconfig_request.new_data_parallel_master_ip
|
||||
)
|
||||
parallel_config.data_parallel_master_port = (
|
||||
reconfig_request.new_data_parallel_master_port
|
||||
)
|
||||
|
||||
# Reconfigure MoE modules with new EP size
|
||||
moe_modules = [
|
||||
module
|
||||
for module in self.worker.model_runner.model.modules()
|
||||
if (
|
||||
module.__class__.__name__ == "FusedMoE"
|
||||
or module.__class__.__name__ == "SharedFusedMoE"
|
||||
)
|
||||
]
|
||||
num_local_experts = moe_modules[0].moe_config.num_local_experts
|
||||
assert all(
|
||||
module.moe_config.num_local_experts == num_local_experts
|
||||
for module in moe_modules
|
||||
), "All MoE modules must have the same number of experts"
|
||||
for module in moe_modules:
|
||||
module.moe_config.num_experts = num_local_experts * new_ep_size
|
||||
module.global_num_experts = module.moe_config.num_experts
|
||||
tp_size = get_tp_group().world_size
|
||||
is_sequence_parallel = parallel_config.use_sequence_parallel_moe
|
||||
sp_size = tp_size if is_sequence_parallel else 1
|
||||
module.moe_parallel_config = FusedMoEParallelConfig.make(
|
||||
tp_size_=tp_size,
|
||||
pcp_size_=get_pcp_group().world_size,
|
||||
dp_size_=get_dp_group().world_size,
|
||||
sp_size_=sp_size,
|
||||
vllm_parallel_config=parallel_config,
|
||||
)
|
||||
module.moe_config.moe_parallel_config = module.moe_parallel_config
|
||||
|
||||
# Update EPLB state
|
||||
eplb_state = self.worker.model_runner.eplb_state
|
||||
assert eplb_state is not None
|
||||
model_config = self.worker.model_runner.model_config
|
||||
eplb_model_state = eplb_state.model_states[model_config.compute_hash()]
|
||||
|
||||
num_physical_experts = num_local_experts * new_ep_size
|
||||
num_logical_experts = eplb_model_state.logical_replica_count.shape[1]
|
||||
parallel_config.eplb_config.num_redundant_experts = (
|
||||
num_physical_experts - num_logical_experts
|
||||
)
|
||||
old_physical_to_logical = eplb_model_state.physical_to_logical_map
|
||||
num_moe_layers = old_physical_to_logical.shape[0]
|
||||
num_local_experts = eplb_model_state.expert_load_pass.shape[1] // old_ep_size
|
||||
if new_dp_size > old_dp_size:
|
||||
expanded_physical_to_logical = torch.full(
|
||||
(num_moe_layers, num_local_experts * new_ep_size),
|
||||
-1,
|
||||
dtype=old_physical_to_logical.dtype,
|
||||
device=old_physical_to_logical.device,
|
||||
)
|
||||
expanded_physical_to_logical[:, : num_local_experts * old_ep_size] = (
|
||||
old_physical_to_logical
|
||||
)
|
||||
eplb_model_state.physical_to_logical_map = expanded_physical_to_logical
|
||||
|
||||
old_num_physical_experts = eplb_model_state.expert_load_pass.shape[1]
|
||||
pad_size = num_physical_experts - old_num_physical_experts
|
||||
if new_dp_size > old_dp_size:
|
||||
assert pad_size > 0
|
||||
expanded_expert_load_pass = F.pad(
|
||||
eplb_model_state.expert_load_pass, (0, pad_size), value=0
|
||||
)
|
||||
expanded_expert_load_window = F.pad(
|
||||
eplb_model_state.expert_load_window, (0, pad_size), value=0
|
||||
)
|
||||
eplb_model_state.expert_load_pass = expanded_expert_load_pass
|
||||
eplb_model_state.expert_load_window = expanded_expert_load_window
|
||||
eplb_state.num_valid_physical_experts = old_num_physical_experts
|
||||
else:
|
||||
assert pad_size < 0
|
||||
eplb_model_state.expert_load_pass = eplb_model_state.expert_load_pass[
|
||||
:, :num_physical_experts
|
||||
]
|
||||
eplb_model_state.expert_load_window = eplb_model_state.expert_load_window[
|
||||
:, :, :num_physical_experts
|
||||
]
|
||||
eplb_state.num_valid_physical_experts = num_physical_experts
|
||||
|
||||
model = self.worker.model_runner.get_model()
|
||||
model.expert_weights = []
|
||||
with set_current_vllm_config(self.worker.vllm_config):
|
||||
model.set_eplb_state(
|
||||
eplb_model_state.expert_load_pass,
|
||||
eplb_model_state.logical_to_physical_map,
|
||||
eplb_model_state.logical_replica_count,
|
||||
)
|
||||
model.update_physical_experts_metadata(
|
||||
num_physical_experts=num_physical_experts,
|
||||
num_local_physical_experts=num_local_experts,
|
||||
)
|
||||
# Force re-creation of the modular kernel (and all2all manager)
|
||||
# for the new EP size by resetting quant_method to base
|
||||
for module in moe_modules:
|
||||
if hasattr(module.quant_method, "old_quant_method"):
|
||||
module.quant_method = module.quant_method.old_quant_method
|
||||
module.runner = module._init_runner()
|
||||
prepare_communication_buffer_for_model(self.worker.model_runner.model)
|
||||
if (
|
||||
self.worker.vllm_config.compilation_config.mode
|
||||
== CompilationMode.STOCK_TORCH_COMPILE
|
||||
):
|
||||
# NOTE(yongji): when using stock torch.compile,
|
||||
# torch.compile is triggered during GPUModelRunner's load_model()
|
||||
# TODO(yongji): check whether we need to re-trigger torch.compile here?
|
||||
# any changes to the tensor shapes in execution should already
|
||||
# be handled internally by torch.compile.
|
||||
backend = self.worker.vllm_config.compilation_config.init_backend(
|
||||
self.worker.vllm_config
|
||||
)
|
||||
compilation_counter.stock_torch_compile_count += 1
|
||||
self.worker.model_runner.model.compile(fullgraph=True, backend=backend)
|
||||
|
||||
# release all previously captured CUDA graphs
|
||||
if isinstance(self.worker.model_runner.model, CUDAGraphWrapper):
|
||||
wrapper = self.worker.model_runner.model
|
||||
wrapper.concrete_cudagraph_entries = {}
|
||||
elif isinstance(self.worker.model_runner.model, UBatchWrapper):
|
||||
raise RuntimeError("DBO is not yet supported in elastic EP")
|
||||
|
||||
multi_block_table = self.worker.model_runner.input_batch.block_table
|
||||
saved_block_tables: list[tuple[torch.Tensor, torch.Tensor]] = []
|
||||
for bt in multi_block_table.block_tables:
|
||||
saved_block_tables.append(
|
||||
(bt.block_table.gpu.clone(), bt.block_table.cpu.clone())
|
||||
)
|
||||
multi_block_table.clear()
|
||||
|
||||
# reset the compile wrapper
|
||||
torch.compiler.reset()
|
||||
with set_current_vllm_config(self.worker.vllm_config):
|
||||
reset_compile_wrapper(self.worker.model_runner.get_model())
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.empty_cache()
|
||||
unlock_workspace()
|
||||
self.worker.compile_or_warm_up_model()
|
||||
lock_workspace()
|
||||
|
||||
for bt, (saved_gpu, saved_cpu) in zip(
|
||||
multi_block_table.block_tables, saved_block_tables
|
||||
):
|
||||
bt.block_table.gpu.copy_(saved_gpu)
|
||||
bt.block_table.cpu.copy_(saved_cpu)
|
||||
|
||||
def perform_eplb_reshuffle(self, new_dp_size: int | None = None) -> None:
|
||||
if get_ep_group().rank == 0:
|
||||
logger.info("[Elastic EP] Starting expert resharding...")
|
||||
|
||||
eplb_state = self.worker.model_runner.eplb_state
|
||||
assert eplb_state is not None
|
||||
|
||||
model_config = self.worker.model_runner.model_config
|
||||
eplb_model_state = eplb_state.model_states[model_config.compute_hash()]
|
||||
is_async_enabled = eplb_state.is_async
|
||||
eplb_state.is_async = False
|
||||
if new_dp_size is None:
|
||||
eplb_state.rearrange()
|
||||
else:
|
||||
# scale down
|
||||
parallel_config = self.worker.vllm_config.parallel_config
|
||||
tp_size = parallel_config.tensor_parallel_size
|
||||
old_ep_size = parallel_config.data_parallel_size * tp_size
|
||||
new_ep_size = new_dp_size * tp_size
|
||||
|
||||
rank_mapping = {
|
||||
old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
|
||||
for old_ep_rank in range(old_ep_size)
|
||||
}
|
||||
|
||||
eplb_state.rearrange(rank_mapping=rank_mapping)
|
||||
# NOTE(yongji): check whether we need to synchronize here
|
||||
torch.cuda.synchronize()
|
||||
# reset expert_rearrangement_step to ensure all ranks are synchronized
|
||||
eplb_state.expert_rearrangement_step = 0
|
||||
eplb_state.num_valid_physical_experts = (
|
||||
eplb_model_state.physical_to_logical_map.shape[1]
|
||||
)
|
||||
eplb_state.is_async = is_async_enabled
|
||||
self.worker.model_runner.eep_eplb_suppressed = False
|
||||
if get_ep_group().rank == 0:
|
||||
logger.info("[Elastic EP] Expert resharding completed")
|
||||
|
||||
def receive_weights(self) -> None:
|
||||
dp_group = get_dp_group()
|
||||
assert isinstance(dp_group, StatelessGroupCoordinator)
|
||||
new_dp_size = dp_group.world_size
|
||||
dp_rank = self.worker.vllm_config.parallel_config.data_parallel_rank
|
||||
|
||||
# Receive old_dp_size broadcasted during transfer_weights
|
||||
old_dp_size_tensor = torch.empty(1, dtype=torch.int64, device="cpu")
|
||||
old_dp_size_tensor = dp_group.tcp_store_group.broadcast(old_dp_size_tensor, 0)
|
||||
old_dp_size = int(old_dp_size_tensor[0].item())
|
||||
|
||||
# Calculate which existing worker will send to this new worker
|
||||
num_new_workers = new_dp_size - old_dp_size
|
||||
new_worker_idx = dp_rank - old_dp_size
|
||||
num_dst_per_sender = num_new_workers // old_dp_size
|
||||
remainder = num_new_workers % old_dp_size
|
||||
|
||||
if new_worker_idx < remainder * (num_dst_per_sender + 1):
|
||||
sender_rank = new_worker_idx // (num_dst_per_sender + 1)
|
||||
else:
|
||||
sender_rank = (
|
||||
remainder
|
||||
+ (new_worker_idx - remainder * (num_dst_per_sender + 1))
|
||||
// num_dst_per_sender
|
||||
)
|
||||
|
||||
model = self.worker.model_runner.get_model()
|
||||
batch_transfer_weights(
|
||||
model=model,
|
||||
is_sender=False,
|
||||
peer_rank=sender_rank,
|
||||
dp_group=dp_group,
|
||||
expert_weights=model.expert_weights,
|
||||
)
|
||||
torch.cuda.synchronize()
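
The sender-selection arithmetic in `receive_weights()` above can be read as a standalone function; the following sketch is illustrative only (the function name and example sizes are assumptions, not part of the patch):

```python
def sender_for_new_worker(dp_rank: int, old_dp_size: int, new_dp_size: int) -> int:
    # New workers are indexed relative to the old DP size; each existing
    # worker serves a contiguous slice of them, with the remainder spread
    # over the first `remainder` senders.
    num_new_workers = new_dp_size - old_dp_size
    new_worker_idx = dp_rank - old_dp_size
    num_dst_per_sender = num_new_workers // old_dp_size
    remainder = num_new_workers % old_dp_size
    if new_worker_idx < remainder * (num_dst_per_sender + 1):
        return new_worker_idx // (num_dst_per_sender + 1)
    return remainder + (new_worker_idx - remainder * (num_dst_per_sender + 1)) // num_dst_per_sender


# Scaling DP from 2 to 5: new ranks 2, 3, 4 receive weights from ranks 0, 0, 1.
print([sender_for_new_worker(r, old_dp_size=2, new_dp_size=5) for r in (2, 3, 4)])
```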
|
||||
|
||||
def receive_expert_mapping(self) -> tuple[torch.Tensor, int, int]:
|
||||
dp_group = get_dp_group()
|
||||
assert isinstance(dp_group, StatelessGroupCoordinator)
|
||||
physical_to_logical, num_local_physical_experts, num_logical_experts = (
|
||||
broadcast_expert_mapping(
|
||||
physical_to_logical=None,
|
||||
num_local_physical_experts=None,
|
||||
num_logical_experts=None,
|
||||
dp_group=dp_group,
|
||||
src_rank=0,
|
||||
device=self.worker.device,
|
||||
)
|
||||
)
|
||||
num_moe_layers = physical_to_logical.shape[0]
|
||||
new_dp_size = get_dp_group().world_size
|
||||
tp_size = self.worker.vllm_config.parallel_config.tensor_parallel_size
|
||||
new_ep_size = new_dp_size * tp_size
|
||||
expanded_physical_to_logical = torch.full(
|
||||
(num_moe_layers, num_local_physical_experts * new_ep_size),
|
||||
-1,
|
||||
dtype=physical_to_logical.dtype,
|
||||
device=physical_to_logical.device,
|
||||
)
|
||||
old_num_physical_experts = physical_to_logical.shape[1]
|
||||
expanded_physical_to_logical[:, :old_num_physical_experts] = physical_to_logical
|
||||
return (
|
||||
expanded_physical_to_logical,
|
||||
num_logical_experts,
|
||||
old_num_physical_experts,
|
||||
)
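
A small illustration (assumed shapes, not from the patch) of the padding step in `receive_expert_mapping()`: physical expert slots that belong to the newly added EP ranks stay `-1` until the EPLB reshuffle assigns them.

```python
import torch

# 2 MoE layers, old EP size 1 with 2 local physical experts per rank.
physical_to_logical = torch.tensor([[0, 1], [1, 0]])
num_local_physical_experts, new_ep_size = 2, 2

expanded = torch.full(
    (physical_to_logical.shape[0], num_local_physical_experts * new_ep_size),
    -1,
    dtype=physical_to_logical.dtype,
)
expanded[:, : physical_to_logical.shape[1]] = physical_to_logical
print(expanded)
# tensor([[ 0,  1, -1, -1],
#         [ 1,  0, -1, -1]])
```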
|
||||
|
||||
def prepare_new_worker(self) -> None:
|
||||
with set_current_vllm_config(self.worker.vllm_config):
|
||||
prepare_communication_buffer_for_model(self.worker.model_runner.get_model())
|
||||
563
vllm/distributed/elastic_ep/elastic_state.py
Normal file
@@ -0,0 +1,563 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import enum
|
||||
import time
|
||||
import weakref
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
|
||||
import torch.distributed
|
||||
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.distributed import (
|
||||
sched_yield,
|
||||
stateless_destroy_torch_distributed_process_group,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.engine import (
|
||||
EEPNotificationType,
|
||||
ReconfigureDistributedRequest,
|
||||
ReconfigureRankType,
|
||||
)
|
||||
from vllm.v1.engine.core import DPEngineCoreProc
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
WorkerType = Literal["existing", "new", "removing"]
|
||||
|
||||
|
||||
class ScaleUpExistingEngineState(enum.IntEnum):
|
||||
WAIT_NEW_CORE_ENGINES_INIT = 0
|
||||
CREATE_STANDBY_GROUPS = 1
|
||||
TRANSFER_EXPERT_MAPPING = 2
|
||||
WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT = 3
|
||||
TRANSFER_WEIGHTS = 4
|
||||
SYNC_KV_CACHE_MEMORY_SIZE = 5
|
||||
SWITCH_AND_PREPARE = 6
|
||||
EPLB_RESHUFFLE = 7
|
||||
COMPLETE = 8
|
||||
|
||||
|
||||
class ScaleUpNewEngineState(enum.IntEnum):
|
||||
PREPARE = 0
|
||||
EPLB_RESHUFFLE = 1
|
||||
COMPLETE = 2
|
||||
|
||||
|
||||
class ScaleDownRemainingEngineState(enum.IntEnum):
|
||||
PREPARE = 0
|
||||
EPLB_RESHUFFLE = 1
|
||||
SWITCH_AND_PREPARE = 2
|
||||
COMPLETE = 3
|
||||
|
||||
|
||||
class ScaleDownRemovingEngineState(enum.IntEnum):
|
||||
PREPARE = 0
|
||||
EPLB_RESHUFFLE = 1
|
||||
COMPLETE = 2
|
||||
|
||||
|
||||
class _BarrierTimeoutError(RuntimeError):
|
||||
"""
|
||||
Exception raised for timeout
|
||||
in the first stage of our two-staged
|
||||
TCPStore based barrier to synchronize the
|
||||
execution of all engines in the DP group.
|
||||
"""
|
||||
|
||||
|
||||
class ElasticEPScalingState:
|
||||
def __init__(
|
||||
self,
|
||||
model_executor: "Executor",
|
||||
engine_core: "DPEngineCoreProc",
|
||||
vllm_config: "VllmConfig",
|
||||
new_parallel_config: ParallelConfig,
|
||||
worker_type: WorkerType,
|
||||
scale_type: Literal["scale_up", "scale_down"],
|
||||
reconfig_request: ReconfigureDistributedRequest | None = None,
|
||||
):
|
||||
self.model_executor_ref = weakref.ref(model_executor)
|
||||
self.engine_core_ref = weakref.ref(engine_core)
|
||||
self.vllm_config = vllm_config
|
||||
self.old_dp_group = self.engine_core.dp_group if worker_type != "new" else None
|
||||
self.old_dp_store = self.engine_core.dp_store if worker_type != "new" else None
|
||||
self.new_parallel_config: ParallelConfig = new_parallel_config
|
||||
self.new_dp_group: torch.distributed.ProcessGroup | None = (
|
||||
self.engine_core.dp_group if worker_type == "new" else None
|
||||
)
|
||||
self.new_dp_store = self.engine_core.dp_store if worker_type == "new" else None
|
||||
self.worker_type = worker_type
|
||||
self.scale_type = scale_type
|
||||
self.reconfig_request = reconfig_request
|
||||
|
||||
if scale_type == "scale_up":
|
||||
self.state = (
|
||||
ScaleUpNewEngineState.PREPARE
|
||||
if worker_type == "new"
|
||||
else ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT
|
||||
)
|
||||
else:
|
||||
self.state = (
|
||||
ScaleDownRemovingEngineState.PREPARE
|
||||
if worker_type == "removing"
|
||||
else ScaleDownRemainingEngineState.PREPARE
|
||||
)
|
||||
|
||||
@property
|
||||
def model_executor(self) -> "Executor":
|
||||
model_executor = self.model_executor_ref()
|
||||
if model_executor is None:
|
||||
raise RuntimeError("Model executor has been garbage collected")
|
||||
return model_executor
|
||||
|
||||
@property
|
||||
def engine_core(self) -> "DPEngineCoreProc":
|
||||
engine_core = self.engine_core_ref()
|
||||
if engine_core is None:
|
||||
raise RuntimeError("Engine core has been garbage collected")
|
||||
return engine_core
|
||||
|
||||
def progress(self) -> bool:
|
||||
if self.scale_type == "scale_up":
|
||||
return (
|
||||
self._progress_new_engine()
|
||||
if self.worker_type == "new"
|
||||
else self._progress_existing_engine()
|
||||
)
|
||||
return (
|
||||
self._progress_removing_engine()
|
||||
if self.worker_type == "removing"
|
||||
else self._progress_remaining_engine()
|
||||
)
|
||||
|
||||
def _execute_tcp_store_barrier(
|
||||
self, dp_store, group_rank, group_size, barrier_id, timeout=None
|
||||
):
|
||||
arrival_key = f"arrival_{barrier_id}_{group_rank}"
|
||||
dp_store.set(arrival_key, b"1")
|
||||
|
||||
start_time = time.time()
|
||||
processes_arrived: set[int] = set()
|
||||
|
||||
while len(processes_arrived) < group_size:
|
||||
if (
|
||||
timeout is not None
|
||||
and time.time() - start_time > timeout.total_seconds()
|
||||
):
|
||||
raise _BarrierTimeoutError(
|
||||
f"Barrier timed out after {timeout.total_seconds()} seconds"
|
||||
)
|
||||
|
||||
for i in range(group_size):
|
||||
if i in processes_arrived:
|
||||
continue
|
||||
|
||||
key = f"arrival_{barrier_id}_{i}"
|
||||
present = dp_store.check([key])
|
||||
if present:
|
||||
processes_arrived.add(i)
|
||||
|
||||
if len(processes_arrived) < group_size:
|
||||
sched_yield()
|
||||
|
||||
def _staged_barrier(self, use_new_group: bool, barrier_name: str) -> bool:
|
||||
"""
|
||||
Execute a two-staged barrier to synchronize all engines in the DP group.
|
||||
|
||||
Some DP EngineCores may receive the reconfiguration notifications
|
||||
later than others, and already proceed to engine step (model forward)
|
||||
in the busy loop.
|
||||
In this case, EngineCores that already proceed to reconfiguration
|
||||
should skip reconfiguration and execute model forward for one more
|
||||
step, so in the next step, all EngineCores will be synchronized.
|
||||
We use a two-staged barrier to achieve this. The first time each
|
||||
EngineCore executes the barrier, if a timeout is reached before the
|
||||
barrier completes, that means some EngineCores have already entered
|
||||
engine step. The EngineCores that timed out will then proceed to
|
||||
engine step, and will synchronize with the other EngineCores in the
|
||||
next step with a barrier without timeout.
|
||||
"""
|
||||
dp_store = self.new_dp_store if use_new_group else self.old_dp_store
|
||||
dp_group = self.new_dp_group if use_new_group else self.old_dp_group
|
||||
assert dp_group is not None
|
||||
|
||||
group_rank = dp_group.rank()
|
||||
group_size = dp_group.size()
|
||||
barrier_id = f"eep_barrier_{barrier_name}"
|
||||
sync_key = f"{barrier_id}_sync"
|
||||
|
||||
# TODO(yongji): figure out appropriate timeout for the barrier
|
||||
timeout = None if dp_store.check([sync_key]) else timedelta(seconds=5)
|
||||
|
||||
try:
|
||||
self._execute_tcp_store_barrier(
|
||||
dp_store, group_rank, group_size, barrier_id, timeout=timeout
|
||||
)
|
||||
torch.distributed.barrier(dp_group)
|
||||
if group_rank == 0:
|
||||
dp_store.delete_key(sync_key)
|
||||
for i in range(group_size):
|
||||
dp_store.delete_key(f"arrival_{barrier_id}_{i}")
|
||||
return True
|
||||
except _BarrierTimeoutError as e:
|
||||
if timeout is None:
|
||||
raise RuntimeError("Unexpected timeout encountered") from e
|
||||
dp_store.compare_set(sync_key, "", b"1")
|
||||
return False
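
The control flow of the two-staged barrier is easier to see outside the distributed code. The following single-process sketch mimics it with a plain dict standing in for the TCP store; all names here are illustrative and none are vLLM APIs.

```python
import time


def staged_barrier(store: dict, rank: int, size: int, barrier_id: str,
                   timeout_s: float | None) -> bool:
    """Return True if all ranks arrived; False if the first stage timed out."""
    store[f"arrival_{barrier_id}_{rank}"] = b"1"
    start = time.time()
    arrived: set[int] = set()
    while len(arrived) < size:
        if timeout_s is not None and time.time() - start > timeout_s:
            # First stage timed out: flag the sync key so the retry in the
            # next engine step runs the barrier again without a timeout.
            store[f"{barrier_id}_sync"] = b"1"
            return False
        for i in range(size):
            if f"arrival_{barrier_id}_{i}" in store:
                arrived.add(i)
    return True


store: dict = {}
print(staged_barrier(store, rank=0, size=1, barrier_id="demo", timeout_s=1.0))   # True
print(staged_barrier(store, rank=0, size=2, barrier_id="demo2", timeout_s=0.1))  # False
```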
|
||||
|
||||
def _progress_existing_engine(self) -> bool:
|
||||
state = self.state
|
||||
|
||||
if state == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT:
|
||||
return False
|
||||
|
||||
elif state == ScaleUpExistingEngineState.CREATE_STANDBY_GROUPS:
|
||||
# NOTE(yongji): wait for all existing workers to receive the request
|
||||
if (
|
||||
int(self.old_dp_store.get("eep_barrier_engine_count"))
|
||||
< self.old_dp_group.size()
|
||||
):
|
||||
return False
|
||||
if not self._staged_barrier(
|
||||
use_new_group=False, barrier_name="create_standby_groups"
|
||||
):
|
||||
return False
|
||||
if self.old_dp_group.rank() == 0:
|
||||
self.old_dp_store.delete_key("eep_barrier_engine_count")
|
||||
self._create_standby_groups()
|
||||
self.state = ScaleUpExistingEngineState.TRANSFER_EXPERT_MAPPING
|
||||
return True
|
||||
|
||||
elif state == ScaleUpExistingEngineState.TRANSFER_EXPERT_MAPPING:
|
||||
self._transfer_expert_mapping()
|
||||
self.state = ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT
|
||||
return True
|
||||
|
||||
elif state == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT:
|
||||
return False
|
||||
|
||||
elif state == ScaleUpExistingEngineState.TRANSFER_WEIGHTS:
|
||||
if (
|
||||
int(self.old_dp_store.get("eep_barrier_engine_count"))
|
||||
< self.old_dp_group.size()
|
||||
):
|
||||
return False
|
||||
if not self._staged_barrier(
|
||||
use_new_group=False, barrier_name="transfer_weights"
|
||||
):
|
||||
return False
|
||||
if self.old_dp_group.rank() == 0:
|
||||
self.old_dp_store.delete_key("eep_barrier_engine_count")
|
||||
self._transfer_weights()
|
||||
self.state = ScaleUpExistingEngineState.SYNC_KV_CACHE_MEMORY_SIZE
|
||||
return True
|
||||
|
||||
elif state == ScaleUpExistingEngineState.SYNC_KV_CACHE_MEMORY_SIZE:
|
||||
self._sync_kv_cache_memory_size()
|
||||
self.state = ScaleUpExistingEngineState.SWITCH_AND_PREPARE
|
||||
return True
|
||||
|
||||
elif state == ScaleUpExistingEngineState.SWITCH_AND_PREPARE:
|
||||
self._switch_and_prepare()
|
||||
self.state = ScaleUpExistingEngineState.EPLB_RESHUFFLE
|
||||
self.new_dp_store.add("eep_barrier_engine_count", 1)
|
||||
return True
|
||||
|
||||
elif state == ScaleUpExistingEngineState.EPLB_RESHUFFLE:
|
||||
assert self.new_dp_group is not None
|
||||
if (
|
||||
int(self.new_dp_store.get("eep_barrier_engine_count"))
|
||||
< self.new_dp_group.size()
|
||||
):
|
||||
return False
|
||||
if not self._staged_barrier(
|
||||
use_new_group=True, barrier_name="eplb_reshuffle"
|
||||
):
|
||||
return False
|
||||
if self.new_dp_group.rank() == 0:
|
||||
self.new_dp_store.delete_key("eep_barrier_engine_count")
|
||||
self._eplb_reshuffle()
|
||||
self.state = ScaleUpExistingEngineState.COMPLETE
|
||||
self._update_parallel_config()
|
||||
return True
|
||||
|
||||
else:
|
||||
assert self.state == ScaleUpExistingEngineState.COMPLETE
|
||||
return True
|
||||
|
||||
def _progress_new_engine(self) -> bool:
|
||||
state = self.state
|
||||
assert self.new_dp_group is not None
|
||||
|
||||
if state == ScaleUpNewEngineState.PREPARE:
|
||||
tensor = torch.tensor([0, 0, 0], dtype=torch.int32, device="cpu")
|
||||
torch.distributed.all_reduce(
|
||||
tensor,
|
||||
op=torch.distributed.ReduceOp.MAX,
|
||||
group=self.new_dp_group,
|
||||
)
|
||||
data = tensor.tolist()
|
||||
self.engine_core.engines_running = bool(data[0])
|
||||
self.engine_core.current_wave = int(data[1])
|
||||
self.engine_core.step_counter = int(data[2])
|
||||
self.state = ScaleUpNewEngineState.EPLB_RESHUFFLE
|
||||
self.new_dp_store.add("eep_barrier_engine_count", 1)
|
||||
return True
|
||||
|
||||
elif state == ScaleUpNewEngineState.EPLB_RESHUFFLE:
|
||||
if (
|
||||
int(self.new_dp_store.get("eep_barrier_engine_count"))
|
||||
< self.new_dp_group.size()
|
||||
):
|
||||
return False
|
||||
if not self._staged_barrier(
|
||||
use_new_group=True, barrier_name="eplb_reshuffle"
|
||||
):
|
||||
return False
|
||||
assert self.new_dp_group.rank() > 0
|
||||
self._eplb_reshuffle()
|
||||
self.state = ScaleUpNewEngineState.COMPLETE
|
||||
return True
|
||||
|
||||
else:
|
||||
assert self.state == ScaleUpNewEngineState.COMPLETE
|
||||
return True
|
||||
|
||||
def _progress_remaining_engine(self) -> bool:
|
||||
state = self.state
|
||||
|
||||
if state == ScaleDownRemainingEngineState.PREPARE:
|
||||
self.state = ScaleDownRemainingEngineState.EPLB_RESHUFFLE
|
||||
self.old_dp_store.add("eep_barrier_engine_count", 1)
|
||||
return True
|
||||
|
||||
elif state == ScaleDownRemainingEngineState.EPLB_RESHUFFLE:
|
||||
if (
|
||||
int(self.old_dp_store.get("eep_barrier_engine_count"))
|
||||
< self.old_dp_group.size()
|
||||
):
|
||||
return False
|
||||
if not self._staged_barrier(
|
||||
use_new_group=False, barrier_name="eplb_reshuffle"
|
||||
):
|
||||
return False
|
||||
if self.old_dp_group.rank() == 0:
|
||||
self.old_dp_store.delete_key("eep_barrier_engine_count")
|
||||
self._eplb_reshuffle_before_scale_down()
|
||||
self.state = ScaleDownRemainingEngineState.SWITCH_AND_PREPARE
|
||||
# NOTE(yongji): currently, after the EPLB reshuffle that redistributes
# experts to the remaining workers, the workers to be removed initiate
# shutdown immediately; existing workers can no longer execute forward
# steps using the old setup. In the future, we may keep the removing
# workers alive a bit longer, e.g. to drain in-batch requests.
|
||||
self._create_standby_groups()
|
||||
self._switch_and_prepare()
|
||||
self._update_parallel_config()
|
||||
self.state = ScaleDownRemainingEngineState.COMPLETE
|
||||
return True
|
||||
|
||||
else:
|
||||
assert self.state == ScaleDownRemainingEngineState.COMPLETE
|
||||
return True
|
||||
|
||||
def _progress_removing_engine(self) -> bool:
|
||||
state = self.state
|
||||
|
||||
if state == ScaleDownRemovingEngineState.PREPARE:
|
||||
self.state = ScaleDownRemovingEngineState.EPLB_RESHUFFLE
|
||||
self.old_dp_store.add("eep_barrier_engine_count", 1)
|
||||
return True
|
||||
|
||||
if state == ScaleDownRemovingEngineState.EPLB_RESHUFFLE:
|
||||
if (
|
||||
int(self.old_dp_store.get("eep_barrier_engine_count"))
|
||||
< self.old_dp_group.size()
|
||||
):
|
||||
return False
|
||||
if not self._staged_barrier(
|
||||
use_new_group=False, barrier_name="eplb_reshuffle"
|
||||
):
|
||||
return False
|
||||
assert self.old_dp_group.rank() > 0
|
||||
self._eplb_reshuffle_before_scale_down()
|
||||
self._switch_and_remove()
|
||||
self.state = ScaleDownRemovingEngineState.COMPLETE
|
||||
self.engine_core._eep_send_engine_core_notification(
|
||||
EEPNotificationType.SHUTDOWN_COMPLETE
|
||||
)
|
||||
self.engine_core.shutdown()
|
||||
return True
|
||||
|
||||
else:
|
||||
assert self.state == ScaleDownRemovingEngineState.COMPLETE
|
||||
return True
|
||||
|
||||
def handle_notification(self, notification_type: EEPNotificationType):
|
||||
assert self.worker_type != "new"
|
||||
if (
|
||||
notification_type == EEPNotificationType.NEW_CORE_ENGINES_INIT_READY
|
||||
and self.state == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT
|
||||
):
|
||||
self.old_dp_store.add("eep_barrier_engine_count", 1)
|
||||
self.state = ScaleUpExistingEngineState.CREATE_STANDBY_GROUPS
|
||||
elif (
|
||||
notification_type == EEPNotificationType.NEW_CORE_ENGINES_WEIGHTS_INIT_READY
|
||||
and self.state
|
||||
== ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT
|
||||
):
|
||||
self.old_dp_store.add("eep_barrier_engine_count", 1)
|
||||
self.state = ScaleUpExistingEngineState.TRANSFER_WEIGHTS
|
||||
|
||||
def is_complete(self) -> bool:
|
||||
if self.scale_type == "scale_up":
|
||||
return (
|
||||
self.state == ScaleUpNewEngineState.COMPLETE
|
||||
if self.worker_type == "new"
|
||||
else self.state == ScaleUpExistingEngineState.COMPLETE
|
||||
)
|
||||
return (
|
||||
self.state == ScaleDownRemovingEngineState.COMPLETE
|
||||
if self.worker_type == "removing"
|
||||
else self.state == ScaleDownRemainingEngineState.COMPLETE
|
||||
)
|
||||
|
||||
def _create_standby_groups(self):
|
||||
self.new_dp_group, self.new_dp_store = (
|
||||
self.new_parallel_config.stateless_init_dp_group(return_store=True)
|
||||
)
|
||||
self.model_executor.collective_rpc(
|
||||
"elastic_ep_execute", args=("create_standby_groups", self.reconfig_request)
|
||||
)
|
||||
if self.old_dp_group.rank() == 0:
|
||||
logger.info("[Elastic EP] Created standby communication groups")
|
||||
|
||||
def _transfer_weights(self):
|
||||
assert self.reconfig_request is not None
|
||||
old_dp_size = self.old_dp_group.size()
|
||||
new_dp_size = self.reconfig_request.new_data_parallel_size
|
||||
|
||||
self.model_executor.collective_rpc(
|
||||
"elastic_ep_execute", args=("transfer_weights", old_dp_size, new_dp_size)
|
||||
)
|
||||
if self.old_dp_group.rank() == 0:
|
||||
logger.info("[Elastic EP] Transferred weights to new workers")
|
||||
|
||||
def _transfer_expert_mapping(self):
|
||||
self.model_executor.collective_rpc(
|
||||
"elastic_ep_execute", args=("broadcast_expert_mapping",)
|
||||
)
|
||||
if self.old_dp_group.rank() == 0:
|
||||
logger.info("[Elastic EP] Broadcasted expert mapping to new workers")
|
||||
|
||||
def _sync_kv_cache_memory_size(self):
|
||||
assert self.engine_core.available_gpu_memory_for_kv_cache > 0
|
||||
assert self.new_dp_group is not None
|
||||
ParallelConfig.sync_kv_cache_memory_size(
|
||||
self.new_dp_group,
|
||||
self.engine_core.available_gpu_memory_for_kv_cache,
|
||||
)
|
||||
if self.old_dp_group.rank() == 0:
|
||||
logger.info("[Elastic EP] Synced KV cache memory size to new workers")
|
||||
|
||||
def _switch_and_prepare(self):
|
||||
self.model_executor.collective_rpc(
|
||||
"elastic_ep_execute", args=("switch_and_prepare",)
|
||||
)
|
||||
old_dp_group = self.old_dp_group
|
||||
stateless_destroy_torch_distributed_process_group(old_dp_group)
|
||||
assert self.new_dp_group is not None
|
||||
new_dp_group = self.new_dp_group
|
||||
self.engine_core.dp_group = new_dp_group
|
||||
self.engine_core.dp_rank = new_dp_group.rank()
|
||||
self.engine_core.dp_store = self.new_dp_store
|
||||
engines_running = int(self.engine_core.engines_running)
|
||||
current_wave = self.engine_core.current_wave
|
||||
step_counter = self.engine_core.step_counter
|
||||
tensor = torch.tensor(
|
||||
[engines_running, current_wave, step_counter],
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
)
|
||||
torch.distributed.all_reduce(
|
||||
tensor, op=torch.distributed.ReduceOp.MAX, group=new_dp_group
|
||||
)
|
||||
data = tensor.tolist()
|
||||
self.engine_core.engines_running = bool(data[0])
|
||||
self.engine_core.current_wave = int(data[1])
|
||||
self.engine_core.step_counter = int(data[2])
|
||||
if new_dp_group.rank() == 0:
|
||||
self.engine_core._eep_send_engine_core_notification(
|
||||
EEPNotificationType.RECONFIGURE_FINISHED
|
||||
)
|
||||
logger.info("[Elastic EP] Switched to new setup")
|
||||
|
||||
def _eplb_reshuffle(self):
|
||||
self.model_executor.collective_rpc(
|
||||
"elastic_ep_execute", args=("perform_eplb_reshuffle",)
|
||||
)
|
||||
assert self.new_dp_group is not None
|
||||
if self.new_dp_group.rank() == 0:
|
||||
logger.info("[Elastic EP] EPLB reshuffle completed")
|
||||
|
||||
def _eplb_reshuffle_before_scale_down(self):
|
||||
assert self.reconfig_request is not None
|
||||
self.model_executor.collective_rpc(
|
||||
"elastic_ep_execute",
|
||||
args=(
|
||||
"perform_eplb_reshuffle",
|
||||
self.reconfig_request.new_data_parallel_size,
|
||||
),
|
||||
)
|
||||
if self.old_dp_group.rank() == 0:
|
||||
logger.info("[Elastic EP] EPLB reshuffle completed")
|
||||
|
||||
def _switch_and_remove(self):
|
||||
self.model_executor.collective_rpc(
|
||||
"elastic_ep_execute", args=("switch_and_remove",)
|
||||
)
|
||||
|
||||
def _update_parallel_config(self):
|
||||
assert self.reconfig_request is not None
|
||||
reconfig_request = self.reconfig_request
|
||||
parallel_config = self.vllm_config.parallel_config
|
||||
parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
|
||||
if (
|
||||
reconfig_request.new_data_parallel_rank
|
||||
!= ReconfigureRankType.KEEP_CURRENT_RANK
|
||||
):
|
||||
parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
|
||||
if (
|
||||
reconfig_request.new_data_parallel_rank_local
|
||||
!= ReconfigureRankType.KEEP_CURRENT_RANK
|
||||
):
|
||||
parallel_config.data_parallel_rank_local = (
|
||||
reconfig_request.new_data_parallel_rank_local
|
||||
)
|
||||
parallel_config.data_parallel_master_ip = (
|
||||
reconfig_request.new_data_parallel_master_ip
|
||||
)
|
||||
parallel_config.data_parallel_master_port = (
|
||||
reconfig_request.new_data_parallel_master_port
|
||||
)
|
||||
parallel_config._data_parallel_master_port_list = (
|
||||
reconfig_request.new_data_parallel_master_port_list
|
||||
)
|
||||
parallel_config._stateless_world_group_port_list = (
|
||||
reconfig_request.new_stateless_world_group_port_list
|
||||
)
|
||||
parallel_config._stateless_dp_group_port_list = (
|
||||
reconfig_request.new_stateless_dp_group_port_list
|
||||
)
|
||||
parallel_config._stateless_ep_group_port_list = (
|
||||
reconfig_request.new_stateless_ep_group_port_list
|
||||
)
|
||||
parallel_config._stateless_eplb_group_port_list = (
|
||||
reconfig_request.new_stateless_eplb_group_port_list
|
||||
)
|
||||
117
vllm/distributed/elastic_ep/standby_state.py
Normal file
@@ -0,0 +1,117 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
|
||||
from vllm.distributed.parallel_state import (
|
||||
_init_stateless_group,
|
||||
_node_count,
|
||||
get_pp_group,
|
||||
get_tp_group,
|
||||
get_world_group,
|
||||
)
|
||||
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
|
||||
|
||||
_STANDBY_WORLD: StatelessGroupCoordinator | None = None
|
||||
_STANDBY_WORLD_NODE_COUNT: int | None = None
|
||||
_STANDBY_DP: StatelessGroupCoordinator | None = None
|
||||
_STANDBY_EP: StatelessGroupCoordinator | None = None
|
||||
_STANDBY_EPLB: StatelessGroupCoordinator | None = None
|
||||
|
||||
|
||||
def get_standby_dp_group() -> StatelessGroupCoordinator | None:
|
||||
return _STANDBY_DP
|
||||
|
||||
|
||||
def get_standby_ep_group() -> StatelessGroupCoordinator | None:
|
||||
return _STANDBY_EP
|
||||
|
||||
|
||||
def get_standby_eplb_group() -> StatelessGroupCoordinator | None:
|
||||
return _STANDBY_EPLB
|
||||
|
||||
|
||||
def get_standby_world_group() -> StatelessGroupCoordinator | None:
|
||||
return _STANDBY_WORLD
|
||||
|
||||
|
||||
def create_standby_groups(
|
||||
new_dp_size: int,
|
||||
new_world_size_across_dp: int,
|
||||
master_ip: str,
|
||||
world_group_ports: list[list[int]],
|
||||
dp_group_ports: list[list[int]],
|
||||
ep_group_ports: list[list[int]],
|
||||
eplb_group_ports: list[list[int]] | None = None,
|
||||
backend: str | None = None,
|
||||
) -> None:
|
||||
global \
|
||||
_STANDBY_WORLD, \
|
||||
_STANDBY_WORLD_NODE_COUNT, \
|
||||
_STANDBY_DP, \
|
||||
_STANDBY_EP, \
|
||||
_STANDBY_EPLB
|
||||
|
||||
assert new_world_size_across_dp == torch.distributed.get_world_size() * new_dp_size
|
||||
world_group = get_world_group()
|
||||
assert isinstance(world_group, StatelessGroupCoordinator)
|
||||
backend = backend or world_group.backend
|
||||
|
||||
standby_world_ranks = [list(range(new_world_size_across_dp))]
|
||||
_STANDBY_WORLD = _init_stateless_group(
|
||||
standby_world_ranks,
|
||||
"world",
|
||||
world_group_ports,
|
||||
master_ip,
|
||||
backend,
|
||||
use_device_communicator=False,
|
||||
)
|
||||
_STANDBY_WORLD_NODE_COUNT = _node_count(_STANDBY_WORLD.tcp_store_group)
|
||||
|
||||
tp_size = get_tp_group().world_size
|
||||
pp_size = get_pp_group().world_size
|
||||
|
||||
all_ranks = torch.arange(new_world_size_across_dp).reshape(
|
||||
-1, new_dp_size, pp_size, tp_size
|
||||
)
|
||||
standby_dp_ranks = all_ranks.transpose(1, 3).reshape(-1, new_dp_size).unbind(0)
|
||||
standby_dp_ranks = [x.tolist() for x in standby_dp_ranks]
|
||||
_STANDBY_DP = _init_stateless_group(
|
||||
standby_dp_ranks, "dp", dp_group_ports, master_ip, backend
|
||||
)
|
||||
|
||||
standby_ep_ranks = (
|
||||
all_ranks.transpose(1, 2).reshape(-1, new_dp_size * tp_size).unbind(0)
|
||||
)
|
||||
standby_ep_ranks = [x.tolist() for x in standby_ep_ranks]
|
||||
_STANDBY_EP = _init_stateless_group(
|
||||
standby_ep_ranks, "ep", ep_group_ports, master_ip, backend
|
||||
)
|
||||
|
||||
if eplb_group_ports is not None:
|
||||
_STANDBY_EPLB = _init_stateless_group(
|
||||
standby_ep_ranks, "eplb", eplb_group_ports, master_ip, backend
|
||||
)
|
||||
|
||||
|
||||
def pop_standby_groups() -> dict:
|
||||
"""Return all standby groups and clear the standby state."""
|
||||
global \
|
||||
_STANDBY_WORLD, \
|
||||
_STANDBY_WORLD_NODE_COUNT, \
|
||||
_STANDBY_DP, \
|
||||
_STANDBY_EP, \
|
||||
_STANDBY_EPLB
|
||||
|
||||
result = dict(
|
||||
world=_STANDBY_WORLD,
|
||||
dp=_STANDBY_DP,
|
||||
ep=_STANDBY_EP,
|
||||
eplb=_STANDBY_EPLB,
|
||||
node_count=_STANDBY_WORLD_NODE_COUNT,
|
||||
)
|
||||
_STANDBY_WORLD = None
|
||||
_STANDBY_WORLD_NODE_COUNT = None
|
||||
_STANDBY_DP = None
|
||||
_STANDBY_EP = None
|
||||
_STANDBY_EPLB = None
|
||||
return result
|
||||
@@ -24,7 +24,6 @@ logger = init_logger(__name__)
|
||||
|
||||
def start_async_worker(
|
||||
state: "EplbState",
|
||||
rank_mapping: dict[int, int] | None = None,
|
||||
is_profile: bool = False,
|
||||
) -> threading.Thread:
|
||||
eplb_group = get_eplb_group().device_group
|
||||
@@ -45,7 +44,6 @@ def start_async_worker(
|
||||
eplb_group=eplb_group,
|
||||
cuda_stream=cuda_stream,
|
||||
is_profile=is_profile,
|
||||
rank_mapping=rank_mapping,
|
||||
)
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - diagnostic path
|
||||
@@ -107,7 +105,6 @@ async def transfer_run_periodically(
|
||||
eplb_group: ProcessGroup,
|
||||
cuda_stream: torch.cuda.Stream,
|
||||
is_profile: bool = False,
|
||||
rank_mapping: dict[int, int] | None = None,
|
||||
) -> None:
|
||||
while True:
|
||||
await asyncio.to_thread(state.rearrange_event.wait)
|
||||
@@ -176,7 +173,6 @@ async def transfer_run_periodically(
|
||||
ep_group=eplb_group,
|
||||
is_profile=is_profile,
|
||||
cuda_stream=cuda_stream,
|
||||
rank_mapping=rank_mapping,
|
||||
)
|
||||
event = torch.cuda.Event(blocking=False)
|
||||
cuda_stream.record_event(event)
|
||||
|
||||
@@ -40,6 +40,7 @@ from vllm.distributed.parallel_state import (
|
||||
get_node_count,
|
||||
in_the_same_node_as,
|
||||
)
|
||||
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
|
||||
from vllm.distributed.utils import StatelessProcessGroup
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models.interfaces import MixtureOfExperts
|
||||
@@ -159,7 +160,7 @@ class EplbModelState:
|
||||
|
||||
NOTE: The expert_load_view now records load for all physical experts
|
||||
rather than just local experts. This ensures consistent load statistics
|
||||
across different dispatch methods (naive all-to-all, DeepEP, pplx-kernels).
|
||||
across different dispatch methods (naive all-to-all, DeepEP).
|
||||
The recorded load will be multiplied by dp_size when using naive all-to-all
|
||||
due to each DP rank contributing the same token set to the calculation.
|
||||
See:
|
||||
@@ -302,6 +303,14 @@ class EplbState:
|
||||
"""
|
||||
CUDA device index for the async EPLB worker thread.
|
||||
"""
|
||||
self.num_valid_physical_experts: int = 0
|
||||
"""
|
||||
Number of valid physical experts.
|
||||
This is the number of physical experts that are
|
||||
actually mapped to logical experts. In elastic EP,
|
||||
newly started EP ranks may not have physical experts
|
||||
mapped yet.
|
||||
"""
|
||||
if self.device.type == "cuda":
|
||||
self.cuda_device_index = self.device.index
|
||||
if self.cuda_device_index is None and torch.cuda.is_available():
|
||||
@@ -367,9 +376,6 @@ class EplbState:
|
||||
self,
|
||||
model: MixtureOfExperts,
|
||||
model_config: ModelConfig,
|
||||
global_expert_load: torch.Tensor | None = None,
|
||||
old_global_expert_indices: torch.Tensor | None = None,
|
||||
rank_mapping: dict[int, int] | None = None,
|
||||
):
|
||||
"""
|
||||
Build the initial EPLB state.
|
||||
@@ -462,75 +468,15 @@ class EplbState:
|
||||
)
|
||||
self.expert_rearrangement_step_interval = eplb_step_interval
|
||||
|
||||
# Set the policy based on the selected eplb algorithm type.
|
||||
policy_type = self.parallel_config.eplb_config.policy
|
||||
self.policy = EPLB_POLICIES[policy_type]
|
||||
logger.debug("Selected EPLB policy: %s", policy_type)
|
||||
if global_expert_load is not None:
|
||||
ep_group = get_ep_group().device_group
|
||||
assert global_expert_load.shape == (
|
||||
model.num_moe_layers,
|
||||
model.num_logical_experts,
|
||||
)
|
||||
assert global_expert_load.dtype == torch.int64
|
||||
|
||||
num_replicas = model.num_physical_experts
|
||||
num_groups = model.num_expert_groups
|
||||
num_nodes = get_node_count()
|
||||
num_gpus = ep_group.size()
|
||||
|
||||
if num_gpus % num_nodes != 0:
|
||||
num_nodes = 1
|
||||
logger.warning_once(
|
||||
f"num_gpus % num_nodes != 0, "
|
||||
"not using hierarchical rearrangement algorithm.\n"
|
||||
f"{num_gpus=}, {num_nodes=}"
|
||||
)
|
||||
|
||||
# Get new expert mappings
|
||||
(
|
||||
new_physical_to_logical_map,
|
||||
new_logical_to_physical_map,
|
||||
new_logical_replica_count,
|
||||
) = self.policy.rebalance_experts(
|
||||
global_expert_load,
|
||||
num_replicas,
|
||||
num_groups,
|
||||
num_nodes,
|
||||
num_gpus,
|
||||
)
|
||||
|
||||
max_physical_slots = new_logical_to_physical_map.shape[-1]
|
||||
assert max_physical_slots <= logical_to_physical_map.shape[-1]
|
||||
new_logical_to_physical_map = torch.nn.functional.pad(
|
||||
new_logical_to_physical_map,
|
||||
(0, logical_to_physical_map.shape[-1] - max_physical_slots),
|
||||
value=-1,
|
||||
)
|
||||
physical_to_logical_map = new_physical_to_logical_map.to(self.device)
|
||||
logical_to_physical_map.copy_(new_logical_to_physical_map)
|
||||
logical_replica_count.copy_(new_logical_replica_count)
|
||||
else:
|
||||
new_physical_to_logical_map = None
|
||||
|
||||
new_logical_to_physical_map = None
|
||||
|
||||
new_logical_replica_count = None
|
||||
model.set_eplb_state(
|
||||
expert_load_pass,
|
||||
logical_to_physical_map,
|
||||
logical_replica_count,
|
||||
)
|
||||
if global_expert_load is not None:
|
||||
rearrange_expert_weights_inplace(
|
||||
old_global_expert_indices,
|
||||
new_physical_to_logical_map,
|
||||
model.expert_weights,
|
||||
ep_group,
|
||||
False,
|
||||
rank_mapping,
|
||||
)
|
||||
self.expert_rearrangement_step = 0
|
||||
|
||||
expert_buffer = [torch.empty_like(w) for w in model.expert_weights[0]]
|
||||
|
||||
@@ -561,11 +507,12 @@ class EplbState:
|
||||
recv_dst_rows=np.array([]),
|
||||
),
|
||||
cuda_device_index=self.cuda_device_index,
|
||||
new_physical_to_logical_map=new_physical_to_logical_map,
|
||||
new_logical_to_physical_map=new_logical_to_physical_map,
|
||||
new_logical_replica_count=new_logical_replica_count,
|
||||
new_physical_to_logical_map=None,
|
||||
new_logical_to_physical_map=None,
|
||||
new_logical_replica_count=None,
|
||||
)
|
||||
self.model_states[model_config.compute_hash()] = model_state
|
||||
self.num_valid_physical_experts = model.num_physical_experts
|
||||
|
||||
def step(
|
||||
self,
|
||||
@@ -696,8 +643,6 @@ class EplbState:
|
||||
def rearrange(
|
||||
self,
|
||||
is_profile: bool = False,
|
||||
execute_shuffle: bool = True,
|
||||
global_expert_loads: list[torch.Tensor] | None = None,
|
||||
rank_mapping: dict[int, int] | None = None,
|
||||
) -> torch.Tensor | None:
|
||||
"""
|
||||
@@ -707,12 +652,6 @@ class EplbState:
|
||||
is_profile (bool): If `True`, perform a dummy rearrangement.
|
||||
This is used in `profile_run` to reserve enough memory,
|
||||
no memory movement will be performed. Default is False.
|
||||
execute_shuffle (bool): If `True`, execute the shuffle
|
||||
in elastic expert parallel (EEP). Default is True.
|
||||
global_expert_loads (list[torch.Tensor] | None): The global expert
|
||||
loads when scaling is done in EEP.
|
||||
List of expert loads for the main and drafter
|
||||
(when spec decode is used) models.
|
||||
rank_mapping (dict[int, int] | None): The rank mapping
|
||||
when scaling is done in EEP.
|
||||
"""
|
||||
@@ -734,18 +673,12 @@ class EplbState:
|
||||
"(profile)" if is_profile else "",
|
||||
)
|
||||
|
||||
if global_expert_loads is None:
|
||||
# Map the physical expert load to global logical experts
|
||||
global_expert_load_windows = []
|
||||
if not execute_shuffle:
|
||||
num_models = torch.tensor(
|
||||
[len(self.model_states)], dtype=torch.int32, device="cpu"
|
||||
)
|
||||
torch.distributed.broadcast(
|
||||
num_models, group=get_ep_group().cpu_group, group_src=0
|
||||
)
|
||||
|
||||
for eplb_model_state in self.model_states.values():
|
||||
expert_load_window = eplb_model_state.expert_load_window[
|
||||
:, :, : self.num_valid_physical_experts
|
||||
]
|
||||
logical_expert_load_window = torch.zeros(
|
||||
self.expert_load_window_size,
|
||||
eplb_model_state.model.num_moe_layers,
|
||||
@@ -755,46 +688,19 @@ class EplbState:
|
||||
)
|
||||
logical_expert_load_window.scatter_add_(
|
||||
dim=-1,
|
||||
index=eplb_model_state.physical_to_logical_map.unsqueeze(0)
|
||||
.expand_as(eplb_model_state.expert_load_window)
|
||||
index=eplb_model_state.physical_to_logical_map[
|
||||
:, : self.num_valid_physical_experts
|
||||
]
|
||||
.unsqueeze(0)
|
||||
.expand_as(expert_load_window)
|
||||
.long(),
|
||||
src=eplb_model_state.expert_load_window,
|
||||
)
|
||||
|
||||
if not execute_shuffle:
|
||||
metadata = torch.tensor(
|
||||
[
|
||||
eplb_model_state.model.num_moe_layers,
|
||||
eplb_model_state.model.num_logical_experts,
|
||||
eplb_model_state.physical_to_logical_map.shape[1],
|
||||
],
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
)
|
||||
torch.distributed.broadcast(
|
||||
metadata, group=get_ep_group().cpu_group, group_src=0
|
||||
src=expert_load_window,
|
||||
)
|
||||
|
||||
global_expert_load_window = logical_expert_load_window.sum(dim=0)
|
||||
global_expert_load_windows.append(global_expert_load_window)
|
||||
# Perform all-reduce to get the expert load across all ranks for each model
|
||||
global_expert_load_windows = self._allreduce_list(
|
||||
global_expert_load_windows
|
||||
)
|
||||
if not execute_shuffle:
|
||||
for eplb_model_state, global_expert_load_window in zip(
|
||||
self.model_states.values(), global_expert_load_windows
|
||||
):
|
||||
# (num_moe_layers, old_num_physical_experts)
|
||||
old_global_expert_indices = eplb_model_state.physical_to_logical_map
|
||||
torch.distributed.broadcast(
|
||||
old_global_expert_indices, group=ep_group, group_src=0
|
||||
)
|
||||
if not execute_shuffle:
|
||||
return global_expert_load_windows
|
||||
else:
|
||||
assert execute_shuffle
|
||||
global_expert_load_windows = global_expert_loads
|
||||
global_expert_load_windows = self._allreduce_list(global_expert_load_windows)
|
||||
|
||||
# TODO(bowen): Treat differently for prefill and decode nodes
|
||||
eplb_model_state = next(iter(self.model_states.values()))
|
||||
@@ -806,8 +712,10 @@ class EplbState:
|
||||
# NOTE(yongji): scale down, we need to rebalance the experts on
|
||||
# remaining GPUs, transfer the experts while we haven't shutdown
|
||||
# the GPUs to be released.
|
||||
cpu_group = get_ep_group().cpu_group
|
||||
num_nodes = _node_count_with_rank_mapping(cpu_group, rank_mapping)
|
||||
coordinator = get_ep_group()
|
||||
assert isinstance(coordinator, StatelessGroupCoordinator)
|
||||
tcp_store_group = coordinator.tcp_store_group
|
||||
num_nodes = _node_count_with_rank_mapping(tcp_store_group, rank_mapping)
|
||||
num_gpus = sum(new_rank != -1 for new_rank in rank_mapping.values())
|
||||
num_replicas = (
|
||||
num_replicas // ep_group.size() * num_gpus
|
||||
@@ -933,7 +841,6 @@ class EplbState:
|
||||
if self.async_worker is None:
|
||||
self.async_worker = start_async_worker(
|
||||
self,
|
||||
rank_mapping=rank_mapping,
|
||||
is_profile=is_profile,
|
||||
)
|
||||
|
||||
@@ -1089,83 +996,6 @@ class EplbState:
|
||||
model_state.new_logical_to_physical_map = None
|
||||
model_state.new_logical_replica_count = None
|
||||
|
||||
@staticmethod
|
||||
def recv_state() -> tuple[list[torch.Tensor], list[torch.Tensor]]:
|
||||
"""
|
||||
Receive the expert load and old placement from the master rank.
|
||||
"""
|
||||
ep_group = get_ep_group()
|
||||
num_models = torch.empty(1, dtype=torch.int32, device="cpu")
|
||||
torch.distributed.broadcast(num_models, group=ep_group.cpu_group, group_src=0)
|
||||
num_models = num_models.item()
|
||||
global_expert_loads = []
|
||||
old_global_expert_indices_per_model = []
|
||||
for _ in range(num_models):
|
||||
metadata = torch.empty(3, dtype=torch.int32, device="cpu")
|
||||
torch.distributed.broadcast(metadata, group=ep_group.cpu_group, group_src=0)
|
||||
num_moe_layers, num_logical_experts, num_old_physical_experts = (
|
||||
metadata.tolist()
|
||||
)
|
||||
global_expert_load = torch.zeros(
|
||||
(num_moe_layers, num_logical_experts),
|
||||
dtype=torch.int64,
|
||||
device=ep_group.device,
|
||||
)
|
||||
all_reduce(global_expert_load, group=ep_group.device_group)
|
||||
old_global_expert_indices = torch.empty(
|
||||
(num_moe_layers, num_old_physical_experts),
|
||||
dtype=torch.int64,
|
||||
device=ep_group.device,
|
||||
)
|
||||
torch.distributed.broadcast(
|
||||
old_global_expert_indices,
|
||||
group=ep_group.device_group,
|
||||
group_src=0,
|
||||
)
|
||||
global_expert_loads.append(global_expert_load)
|
||||
old_global_expert_indices_per_model.append(old_global_expert_indices)
|
||||
return global_expert_loads, old_global_expert_indices_per_model
|
||||
|
||||
@classmethod
|
||||
def get_eep_state(
|
||||
cls, parallel_config: ParallelConfig
|
||||
) -> tuple[
|
||||
list[torch.Tensor] | None,
|
||||
list[torch.Tensor] | None,
|
||||
dict[int, int] | None,
|
||||
]:
|
||||
num_local_physical_experts = torch.empty(1, dtype=torch.int32, device="cpu")
|
||||
torch.distributed.broadcast(
|
||||
num_local_physical_experts,
|
||||
group=get_ep_group().cpu_group,
|
||||
group_src=0,
|
||||
)
|
||||
num_local_physical_experts = int(num_local_physical_experts.item())
|
||||
new_ep_size = get_ep_group().world_size
|
||||
global_expert_loads, old_global_expert_indices_per_model = (
|
||||
EplbState.recv_state()
|
||||
)
|
||||
|
||||
# EP configuration for all models has to be the same so as eplb config
|
||||
num_logical_experts = global_expert_loads[0].shape[1]
|
||||
parallel_config.eplb_config.num_redundant_experts = (
|
||||
num_local_physical_experts * new_ep_size - num_logical_experts
|
||||
)
|
||||
assert (
|
||||
old_global_expert_indices_per_model[0].shape[1] % num_local_physical_experts
|
||||
== 0
|
||||
)
|
||||
old_ep_size = (
|
||||
old_global_expert_indices_per_model[0].shape[1]
|
||||
// num_local_physical_experts
|
||||
)
|
||||
rank_mapping = {old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)}
|
||||
return (
|
||||
global_expert_loads,
|
||||
old_global_expert_indices_per_model,
|
||||
rank_mapping,
|
||||
)
|
||||
|
||||
def _allreduce_list(self, tensor_list: list[torch.Tensor]) -> list[torch.Tensor]:
|
||||
"""
|
||||
All-reduce a list of tensors.
|
||||
@@ -1203,6 +1033,60 @@ class EplbState:
|
||||
load_pass_list.append(eplb_model_state.expert_load_pass.clone())
|
||||
return self._allreduce_list(load_pass_list)
|
||||
|
||||
@classmethod
|
||||
def from_mapping(
|
||||
cls,
|
||||
model: MixtureOfExperts,
|
||||
model_config: ModelConfig,
|
||||
device: torch.device,
|
||||
parallel_config: ParallelConfig,
|
||||
expanded_physical_to_logical: torch.Tensor,
|
||||
num_valid_physical_experts: int,
|
||||
) -> "EplbState":
|
||||
eplb_state = cls(
|
||||
parallel_config=parallel_config,
|
||||
device=device,
|
||||
)
|
||||
eplb_state.add_model(
|
||||
model=model,
|
||||
model_config=model_config,
|
||||
)
|
||||
eplb_state.num_valid_physical_experts = num_valid_physical_experts
|
||||
num_moe_layers = expanded_physical_to_logical.shape[0]
|
||||
num_physical_experts = expanded_physical_to_logical.shape[1]
|
||||
eplb_model_state = eplb_state.model_states[model_config.compute_hash()]
|
||||
eplb_model_state.physical_to_logical_map.copy_(expanded_physical_to_logical)
|
||||
|
||||
logical_to_physical_map = torch.full(
|
||||
(
|
||||
num_moe_layers,
|
||||
model.num_logical_experts,
|
||||
eplb_model_state.logical_to_physical_map.shape[2],
|
||||
),
|
||||
-1,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
logical_replica_count = torch.zeros(
|
||||
(num_moe_layers, model.num_logical_experts),
|
||||
dtype=torch.int64,
|
||||
)
|
||||
expanded_physical_to_logical_numpy = expanded_physical_to_logical.cpu().numpy()
|
||||
for layer_idx in range(num_moe_layers):
|
||||
for phys_idx in range(num_physical_experts):
|
||||
logical_idx = expanded_physical_to_logical_numpy[layer_idx, phys_idx]
|
||||
if logical_idx >= 0:
|
||||
replica_idx = logical_replica_count[layer_idx, logical_idx]
|
||||
logical_to_physical_map[layer_idx, logical_idx, replica_idx] = (
|
||||
phys_idx
|
||||
)
|
||||
logical_replica_count[layer_idx, logical_idx] += 1
|
||||
|
||||
logical_to_physical_map = logical_to_physical_map.to(device)
|
||||
logical_replica_count = logical_replica_count.to(device)
|
||||
eplb_model_state.logical_to_physical_map.copy_(logical_to_physical_map)
|
||||
eplb_model_state.logical_replica_count.copy_(logical_replica_count)
|
||||
return eplb_state
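
The nested loop in `from_mapping()` above inverts a physical-to-logical expert map into a replica table; a tiny self-contained example of the same inversion (toy shapes, not from the patch):

```python
import torch

physical_to_logical = torch.tensor([[0, 1, 0, 2]])  # 1 layer, 4 physical slots
num_logical, max_replicas = 3, 4

logical_to_physical = torch.full((1, num_logical, max_replicas), -1, dtype=torch.int64)
replica_count = torch.zeros((1, num_logical), dtype=torch.int64)

for layer in range(physical_to_logical.shape[0]):
    for phys in range(physical_to_logical.shape[1]):
        logical = int(physical_to_logical[layer, phys])
        if logical >= 0:
            slot = int(replica_count[layer, logical])
            logical_to_physical[layer, logical, slot] = phys
            replica_count[layer, logical] += 1

print(logical_to_physical[0])  # logical expert 0 lives in physical slots 0 and 2
print(replica_count[0])        # tensor([2, 1, 1])
```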
|
||||
|
||||
|
||||
@dataclass
|
||||
class EplbLayerState:
|
||||
|
||||
@@ -19,6 +19,8 @@ from torch.distributed import (
|
||||
get_global_rank,
|
||||
)
|
||||
|
||||
from vllm.distributed.parallel_state import get_ep_group
|
||||
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -249,10 +251,18 @@ def move_to_buffer(
|
||||
b[dst].copy_(w[src_local], non_blocking=True)
|
||||
|
||||
p2p_ops: list[P2POp] = []
|
||||
if isinstance(get_ep_group(), StatelessGroupCoordinator):
|
||||
ep_group = get_ep_group()
|
||||
is_stateless = True
|
||||
else:
|
||||
is_stateless = False
|
||||
|
||||
# Pre-compute global ranks mapping
|
||||
# Pre-compute global ranks mapping (only needed for non-stateless groups)
|
||||
ep_size = ep_group.size()
|
||||
rank_to_global = {rank: get_global_rank(ep_group, rank) for rank in range(ep_size)}
|
||||
if not is_stateless:
|
||||
rank_to_global = {
|
||||
rank: get_global_rank(ep_group, rank) for rank in range(ep_size)
|
||||
}
|
||||
|
||||
# 2. Post sends
|
||||
if send_count > 0:
|
||||
@@ -284,6 +294,14 @@ def move_to_buffer(
|
||||
if recver_pos < len(ranks_to_recv):
|
||||
recv_ranks.append(ranks_to_recv[recver_pos])
|
||||
for dst in recv_ranks:
|
||||
if is_stateless:
|
||||
for w in expert_weights:
|
||||
op = object.__new__(P2POp)
|
||||
op.op = torch.distributed.isend
|
||||
op.tensor = w[src]
|
||||
op.group_peer = dst
|
||||
p2p_ops.append(op)
|
||||
else:
|
||||
dst_global = rank_to_global[dst]
|
||||
p2p_ops += [
|
||||
P2POp(
|
||||
@@ -321,6 +339,14 @@ def move_to_buffer(
|
||||
src = ranks_to_send[recver_pos // num_dst_per_sender]
|
||||
else:
|
||||
src = ranks_to_send[recver_pos - remainder_start]
|
||||
if is_stateless:
|
||||
for b in expert_weights_buffers:
|
||||
op = object.__new__(P2POp)
|
||||
op.op = torch.distributed.irecv
|
||||
op.tensor = b[dst]
|
||||
op.group_peer = src
|
||||
p2p_ops.append(op)
|
||||
else:
|
||||
src_global = rank_to_global[src]
|
||||
p2p_ops += [
|
||||
P2POp(
|
||||
@@ -334,10 +360,16 @@ def move_to_buffer(
|
||||
# 4. Execute the P2P operations. The real communication happens here.
|
||||
if p2p_ops and cuda_stream is not None:
|
||||
with torch.cuda.stream(cuda_stream):
|
||||
if is_stateless:
|
||||
ep_group.device_communicator.batch_isend_irecv(p2p_ops)
|
||||
else:
|
||||
reqs = batch_isend_irecv(p2p_ops)
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
elif p2p_ops:
|
||||
if is_stateless:
|
||||
ep_group.device_communicator.batch_isend_irecv(p2p_ops)
|
||||
else:
|
||||
reqs = batch_isend_irecv(p2p_ops)
|
||||
for req in reqs:
|
||||
req.wait()
|
||||
|
||||
@@ -209,6 +209,10 @@ class KVConnectorKVEvents(ABC):
|
||||
def clear_events(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def merge(self, other: "KVConnectorKVEvents") -> "KVConnectorKVEvents":
|
||||
self.add_events(other.get_all_events())
|
||||
return self
|
||||
|
||||
|
||||
class EventPublisher(ABC):
|
||||
"""Lightweight publisher for EventBatch batches with data parallelism
|
||||
|
||||
@@ -149,6 +149,12 @@ KVConnectorFactory.register_connector(
|
||||
"ExampleConnector",
|
||||
)
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"ExampleHiddenStatesConnector",
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.example_hidden_states_connector",
|
||||
"ExampleHiddenStatesConnector",
|
||||
)
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"P2pNcclConnector",
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_connector",
|
||||
|
||||
@@ -413,7 +413,20 @@ class TpKVTopology:
|
||||
f"by local tensor parallel size {self.tp_size}."
|
||||
)
|
||||
# P TP > D TP case, return the ratio as negative
|
||||
return -remote_tp_size // self.tp_size
|
||||
return remote_tp_size // self.tp_size
|
||||
|
||||
def pp_ratio(
|
||||
self,
|
||||
remote_pp_size: int,
|
||||
) -> int:
|
||||
"""
|
||||
Calculate the pipeline parallel ratio between local and remote PP.
|
||||
"""
|
||||
assert self.pp_size % remote_pp_size == 0 or remote_pp_size % self.pp_size == 0, (
f"Local pipeline parallel size {self.pp_size} is not divisible "
f"by remote pipeline parallel size {remote_pp_size} or vice versa."
)
if self.pp_size % remote_pp_size == 0:
return self.pp_size // remote_pp_size
return remote_pp_size // self.pp_size
|
||||
|
||||
def block_size_ratio(
|
||||
self,
|
||||
@@ -457,6 +470,7 @@ class TpKVTopology:
|
||||
def get_target_remote_ranks(
|
||||
self,
|
||||
remote_tp_size: int,
|
||||
remote_pp_size: int
|
||||
) -> list[int]:
|
||||
"""
|
||||
Get the remote TP rank (on P) that the current local TP rank
|
||||
@@ -464,19 +478,36 @@ class TpKVTopology:
|
||||
read from multiple remote ranks.
|
||||
"""
|
||||
tp_ratio = self.tp_ratio(remote_tp_size)
|
||||
if tp_ratio > 0:
|
||||
return [self.tp_rank // tp_ratio]
|
||||
pp_ratio = self.pp_ratio(remote_pp_size)
|
||||
target_pp_rank_list = []
|
||||
target_tp_rank_list = []
|
||||
if self.pp_size < remote_pp_size:
|
||||
for i in range(pp_ratio):
|
||||
target_pp_rank_list.append(self.pp_rank * pp_ratio + i)
|
||||
else:
|
||||
target_pp_rank_list.append(self.pp_rank // pp_ratio)
|
||||
|
||||
# P TP > D TP case, D reads from |tp_ratio| remote workers.
|
||||
tp_ratio = -tp_ratio
|
||||
return [self.tp_rank * tp_ratio + i for i in range(tp_ratio)]
|
||||
if self.tp_size < remote_tp_size:
|
||||
for i in range(tp_ratio):
|
||||
target_tp_rank_list.append(self.tp_rank * tp_ratio + i)
|
||||
else:
|
||||
target_tp_rank_list.append(self.tp_rank // tp_ratio)
|
||||
|
||||
target_rank_list = []
|
||||
for pp_rank in target_pp_rank_list:
|
||||
for tp_rank in target_tp_rank_list:
|
||||
target_rank = pp_rank * remote_tp_size + tp_rank
|
||||
target_rank_list.append((target_rank, pp_rank, tp_rank))
|
||||
|
||||
return target_rank_list
|
||||
|
||||
def get_target_remote_ranks_from_engine_id(
|
||||
self,
|
||||
remote_engine_id: EngineId,
|
||||
) -> list[int]:
|
||||
remote_tp_size = self.remote_tp_size[remote_engine_id]
|
||||
return self.get_target_remote_ranks(remote_tp_size)
|
||||
remote_pp_size = self.remote_pp_size[remote_engine_id]
|
||||
return self.get_target_remote_ranks(remote_tp_size, remote_pp_size)
|
||||
|
||||
|
||||
def get_current_attn_backend(vllm_config: VllmConfig):
|
||||
|
||||
@@ -543,6 +543,28 @@ class KVConnectorBase_V1(ABC):
|
||||
)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def requires_piecewise_for_cudagraph(cls, extra_config: dict[str, Any]) -> bool:
|
||||
"""
|
||||
Check if this connector requires PIECEWISE CUDA graph mode.
|
||||
|
||||
Connectors that use asynchronous layer-by-layer operations
|
||||
(wait_for_layer_load/save_kv_layer) should override this method
|
||||
to return True when those operations are enabled. These operations
|
||||
cannot be captured in CUDA graphs and will be skipped during replay,
|
||||
causing data races. PIECEWISE mode allows Python code to execute
|
||||
between graph pieces, ensuring proper synchronization.
|
||||
|
||||
Args:
|
||||
extra_config: The kv_connector_extra_config dict from
|
||||
KVTransferConfig.
|
||||
|
||||
Returns:
|
||||
True if this connector requires PIECEWISE CUDA graph mode,
|
||||
False otherwise.
|
||||
"""
|
||||
return False
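
A hypothetical connector override of this new hook, shown only to illustrate the intended usage; the class name and the `enable_async_layer_save` config key are made up for this example.

```python
from typing import Any

from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1


class MyLayerwiseConnector(KVConnectorBase_V1):  # hypothetical subclass
    @classmethod
    def requires_piecewise_for_cudagraph(cls, extra_config: dict[str, Any]) -> bool:
        # Force PIECEWISE CUDA graphs only when async layer-by-layer saving
        # (an assumed extra-config flag) is enabled for this connector.
        return bool(extra_config.get("enable_async_layer_save", False))
```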
|
||||
|
||||
def get_finished_count(self) -> int | None:
|
||||
"""
|
||||
Get the count of requests expected to complete send/receive operations
|
||||
|
||||
@@ -17,6 +17,7 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadata
|
||||
from vllm.utils.hashing import safe_hash
|
||||
from vllm.v1.attention.backend import AttentionMetadata
|
||||
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -118,12 +119,12 @@ class ExampleConnector(KVConnectorBase_V1):
|
||||
The number of elements in kv_caches and layer_names should be
|
||||
the same.
|
||||
"""
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
|
||||
def inject_kv_into_layer(
|
||||
dst_kv_cache_layer: torch.Tensor,
|
||||
src_kv_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> None:
|
||||
"""Inject the KV cache into the layer.
|
||||
|
||||
@@ -145,6 +146,10 @@ class ExampleConnector(KVConnectorBase_V1):
|
||||
num_pages * page_size, -1
|
||||
)
|
||||
dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
|
||||
elif isinstance(attn_metadata, TritonAttentionMetadata):
|
||||
block_idxs = slot_mapping // self._block_size
|
||||
offsets = slot_mapping % self._block_size
|
||||
dst_kv_cache_layer[block_idxs, :, offsets] = src_kv_cache
|
||||
else:
|
||||
num_pages = dst_kv_cache_layer_shape[1]
|
||||
page_size = dst_kv_cache_layer_shape[2]
|
||||
@@ -186,7 +191,13 @@ class ExampleConnector(KVConnectorBase_V1):
|
||||
layer_name, request.token_ids, request.mm_hashes
|
||||
)
|
||||
kv_cache = safetensors.torch.load_file(filename)["kv_cache"].cuda()
|
||||
inject_kv_into_layer(kv_cache_layer, kv_cache, request.slot_mapping)
|
||||
if isinstance(attn_metadata, dict):
|
||||
inject_kv_into_layer(
|
||||
kv_cache_layer,
|
||||
kv_cache,
|
||||
request.slot_mapping,
|
||||
attn_metadata[layer_name],
|
||||
)
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
"""Blocking until the KV for a specific layer is loaded into vLLM's
|
||||
@@ -229,6 +240,10 @@ class ExampleConnector(KVConnectorBase_V1):
|
||||
if isinstance(attn_metadata, MLACommonMetadata):
|
||||
num_pages, page_size = layer.shape[0], layer.shape[1]
|
||||
return layer.reshape(num_pages * page_size, -1)[slot_mapping, ...]
|
||||
elif isinstance(attn_metadata, TritonAttentionMetadata):
|
||||
block_idxs = slot_mapping // self._block_size
|
||||
offsets = slot_mapping % self._block_size
|
||||
return layer[block_idxs, :, offsets]
|
||||
num_pages, page_size = layer.shape[1], layer.shape[2]
|
||||
return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, ...]
|
||||
|
||||
|
||||
@@ -0,0 +1,354 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1,
|
||||
KVConnectorMetadata,
|
||||
KVConnectorRole,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backend import AttentionMetadata
|
||||
from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import Request
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def extract_from_kv_cache(
|
||||
kv_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
num_tokens: int,
|
||||
) -> torch.Tensor:
|
||||
"""Extract data from KV cache
|
||||
Assume the shape of the kv_cache is (num_pages, page_size, num_heads, head_size)
|
||||
"""
|
||||
|
||||
padded_kv = kv_cache.flatten(0, 1)[slot_mapping]
|
||||
# shape: [len(slot_mapping), num_heads, head_size]
|
||||
return padded_kv[:num_tokens] # shape: [num_tokens, num_heads, head_size]
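
A toy usage of the extraction above, with shapes chosen arbitrarily for illustration:

```python
import torch

# (num_pages, page_size, num_heads, head_size) = (2, 4, 1, 2)
kv_cache = torch.arange(2 * 4 * 1 * 2, dtype=torch.float32).reshape(2, 4, 1, 2)
slot_mapping = torch.tensor([5, 6, 7, 0])  # last slot is padding

extracted = kv_cache.flatten(0, 1)[slot_mapping][:3]  # keep the 3 real tokens
print(extracted.shape)  # torch.Size([3, 1, 2])
```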
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReqMeta:
|
||||
# Request ID
|
||||
req_id: str
|
||||
# Request filename
|
||||
filename: str
|
||||
# Request tokens
|
||||
token_ids: torch.Tensor
|
||||
# Slot mappings, should have the same length as token_ids
|
||||
slot_mapping: torch.Tensor
|
||||
# Whether this request is a new request or partially computed already
|
||||
new_req: bool
|
||||
|
||||
@staticmethod
|
||||
def make_meta(
|
||||
req_id: str,
|
||||
filename: str,
|
||||
token_ids: list[int],
|
||||
block_ids: list[int],
|
||||
block_size: int,
|
||||
new_req: bool,
|
||||
) -> "ReqMeta":
|
||||
token_ids_tensor = torch.tensor(token_ids)
|
||||
block_ids_tensor = torch.tensor(block_ids)
|
||||
num_blocks = block_ids_tensor.shape[0]
|
||||
block_offsets = torch.arange(0, block_size)
|
||||
slot_mapping = (
|
||||
block_offsets.reshape((1, block_size))
|
||||
+ block_ids_tensor.reshape((num_blocks, 1)) * block_size
|
||||
)
|
||||
slot_mapping = slot_mapping.flatten()
|
||||
return ReqMeta(
|
||||
req_id=req_id,
|
||||
filename=filename,
|
||||
token_ids=token_ids_tensor,
|
||||
slot_mapping=slot_mapping,
|
||||
new_req=new_req,
|
||||
)
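
The slot-mapping arithmetic in `make_meta` maps each token of a request to a flat slot index in the paged KV cache: block `b` contributes slots `b * block_size` through `b * block_size + block_size - 1`. A minimal standalone sketch of what it evaluates to, with made-up block IDs for illustration only:

```python
import torch

block_size = 4
block_ids = torch.tensor([7, 2])  # illustrative block IDs for one request
block_offsets = torch.arange(0, block_size)
slot_mapping = (
    block_offsets.reshape((1, block_size))
    + block_ids.reshape((-1, 1)) * block_size
).flatten()
# Block 7 owns slots 28..31, block 2 owns slots 8..11.
print(slot_mapping.tolist())  # [28, 29, 30, 31, 8, 9, 10, 11]
```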


@dataclass
class ExampleHiddenStatesConnectorMetadata(KVConnectorMetadata):
    requests: list[ReqMeta] = field(default_factory=list)

    def add_request(
        self,
        req_id: str,
        filename: str,
        token_ids: list[int],
        block_ids: list[int],
        block_size: int,
        new_req: bool = True,
    ) -> None:
        self.requests.append(
            ReqMeta.make_meta(
                req_id, filename, token_ids, block_ids, block_size, new_req
            )
        )


class ExampleHiddenStatesConnector(KVConnectorBase_V1):
    """
    Simple debug implementation of a HiddenStatesConnector.

    Simply extracts the hidden states from the kv cache and stores them to disk.
    Must be used in conjunction with the `extract_hidden_states` spec decoding method.
    """

    @property
    def prefer_cross_layer_blocks(self) -> bool:
        """
        Indicates whether this connector prefers KV blocks that hold KV data for all
        layers, which can speed up KV data transfers. Defaults to False.
        """
        # Must be False so that drafter kv cache isn't merged with verifier's
        return False

    def __init__(
        self,
        vllm_config: "VllmConfig",
        role: KVConnectorRole,
        kv_cache_config: Optional["KVCacheConfig"] = None,
    ):
        super().__init__(
            vllm_config=vllm_config,
            role=role,
            kv_cache_config=kv_cache_config,
        )
        self._block_size = vllm_config.cache_config.block_size
        self._storage_path = self._kv_transfer_config.get_from_extra_config(
            "shared_storage_path", "/tmp"
        )
        self.cache_layers: list[str] = []  # set by self.register_kv_caches
        logger.info(self._kv_transfer_config)
        logger.info("Shared storage path is %s", self._storage_path)

        assert self._vllm_config.speculative_config is not None, (
            "ExampleHiddenStatesConnector only works when using "
            "'extract_hidden_states' speculative method"
        )
        spec_config = self._vllm_config.speculative_config.draft_model_config.hf_config
        self.num_hidden_states = len(
            getattr(spec_config, "eagle_aux_hidden_state_layer_ids", [])
        )

        self._request_filenames: dict[str, str] = {}
        self._active_requests: dict[str, NewRequestData] = {}
        self._req_blocks: dict[str, list[int]] = {}

    # ==============================
    # Worker-side methods
    # ==============================
    def start_load_kv(self, *args, **kwargs: Any) -> None:
        pass  # Empty implementation of abstract method

    def wait_for_layer_load(self, layer_name: str) -> None:
        pass  # Empty implementation of abstract method

    def wait_for_save(self):
        pass  # Empty implementation of abstract method

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        from vllm.model_executor.models.extract_hidden_states import (
            CacheOnlyAttentionLayer,
        )

        # Filter layers to only include CacheOnlyAttentionLayers
        layers = get_layers_from_vllm_config(
            self._vllm_config, CacheOnlyAttentionLayer, list(kv_caches.keys())
        )
        self.cache_layers = list(layers.keys())
        assert len(self.cache_layers) == 1, (
            f"Expected 1 CacheOnlyAttentionLayer, got {len(self.cache_layers)}"
        )

    def save_kv_layer(
        self,
        layer_name: str,
        kv_layer: torch.Tensor,
        attn_metadata: AttentionMetadata,
        **kwargs: Any,
    ) -> None:
        """Start saving the KV cache of the layer from vLLM's paged buffer
        to the connector.

        Args:
            layer_name (str): the name of the layer.
            kv_layer (torch.Tensor): the paged KV buffer of the current
                layer in vLLM.
            attn_metadata (AttentionMetadata): the attention metadata.
            **kwargs: additional arguments for the save operation.
        """
        if layer_name not in self.cache_layers:
            return

        from vllm.model_executor.models.extract_hidden_states import (
            CacheOnlyAttentionMetadata,
        )

        assert isinstance(attn_metadata, CacheOnlyAttentionMetadata), (
            "ExampleHiddenStatesConnector only supports CacheOnlyAttentionBackend"
        )

        connector_metadata = self._get_connector_metadata()
        assert isinstance(connector_metadata, ExampleHiddenStatesConnectorMetadata)

        os.makedirs(self._storage_path, exist_ok=True)
        for request in connector_metadata.requests:
            hidden_states = extract_from_kv_cache(
                kv_layer, request.slot_mapping, request.token_ids.shape[0]
            )
            tensors = {
                "hidden_states": hidden_states.detach().cpu(),
                "token_ids": request.token_ids.detach().cpu(),
            }
            safetensors.torch.save_file(tensors, request.filename)

    # ==============================
    # Scheduler-side methods
    # ==============================

    def get_num_new_matched_tokens(
        self,
        request: "Request",
        num_computed_tokens: int,
    ) -> tuple[int | None, bool]:
        """
        Get number of new tokens that can be loaded from the
        external KV cache beyond the num_computed_tokens.

        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request

        Returns:
            the number of tokens that can be loaded from the
            external KV cache beyond what is already computed.
        """
        # This connector is store-only, so we don't need to load any tokens
        return 0, False

    def update_state_after_alloc(
        self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
    ):
        # Usually used to handle allocation of new blocks for requests that are loading
        # tokens from connector's external kv cache. We never load from external cache
        # so this is a no-op.
        assert num_external_tokens == 0, "This connector is store-only"

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        """Build the connector metadata for this step.

        This function should NOT modify any fields in the scheduler_output.
        Also, calling this function will reset the state of the connector.

        Args:
            scheduler_output (SchedulerOutput): the scheduler output object.
        """
        meta = ExampleHiddenStatesConnectorMetadata()
        for new_req in scheduler_output.scheduled_new_reqs:
            token_ids = new_req.prompt_token_ids or []
            filename = os.path.join(self._storage_path, f"{new_req.req_id}.safetensors")
            meta.add_request(
                new_req.req_id,
                filename=filename,
                token_ids=token_ids,
                block_ids=new_req.block_ids[0],
                block_size=self._block_size,
            )
            self._request_filenames[new_req.req_id] = filename
            self._active_requests[new_req.req_id] = new_req
            self._req_blocks[new_req.req_id] = list(new_req.block_ids[0])

        cached_reqs = scheduler_output.scheduled_cached_reqs
        for i, req_id in enumerate(cached_reqs.req_ids):
            if req_id not in self._active_requests:
                continue

            new_block_ids = cached_reqs.new_block_ids[i]

            cached_req = self._active_requests[req_id]
            req_block_ids = self._req_blocks[req_id]

            assert new_block_ids is not None
            block_ids = new_block_ids[0]

            req_block_ids.extend(block_ids)
            filename = os.path.join(self._storage_path, f"{req_id}.safetensors")

            meta.add_request(
                req_id=req_id,
                filename=filename,
                token_ids=cached_req.prompt_token_ids or [],
                block_ids=req_block_ids,
                block_size=self._block_size,
                new_req=False,
            )

        return meta

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, dict[str, Any] | None]:
        """
        Called exactly once when a request has finished, before its blocks are
        freed.

        The connector may assume responsibility for freeing the blocks
        asynchronously by returning True.

        Returns:
            True if the request is being saved/sent asynchronously and blocks
            should not be freed until the request_id is returned from
            get_finished().
            Optional KVTransferParams to be included in the request outputs
            returned by the engine.
        """
        req_id = request.request_id
        req_filename = self._request_filenames.pop(req_id, None)
        _ = self._active_requests.pop(req_id, None)
        _ = self._req_blocks.pop(req_id, None)

        return False, {"hidden_states_path": req_filename}

    @classmethod
    def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None:
        """
        Get the required KV cache layout for this connector.
        Args:
            vllm_config (VllmConfig): the vllm config.

        Returns:
            str: the required KV cache layout. e.g. HND, or NHD.
            None if the connector does not require a specific layout.
        """

        if cls is KVConnectorBase_V1:
            raise TypeError(
                "get_required_kvcache_layout should not be called "
                "on the abstract base class"
            )
        # NHD means we have (num_tokens, num_heads)
        # HND means we have (num_heads, num_tokens)
        # For now, we only support NHD layout since this keeps the
        # hidden states for each token together in memory.
        # HND is primarily used when sharding heads across devices.
        return "NHD"

@@ -70,6 +70,16 @@ class LMCacheKVEvents(KVConnectorKVEvents):


class LMCacheConnectorV1(KVConnectorBase_V1):
    @classmethod
    def requires_piecewise_for_cudagraph(cls, extra_config: dict[str, Any]) -> bool:
        """
        LMCache requires PIECEWISE CUDA graph mode when layerwise
        operations are enabled. The wait_for_layer_load and save_kv_layer
        methods perform actual async synchronization that cannot be
        captured in CUDA graphs.
        """
        return extra_config.get("use_layerwise", False)

    def __init__(
        self,
        vllm_config: "VllmConfig",

@@ -173,6 +173,29 @@ class MooncakeConnector(KVConnectorBase_V1):
            self.connector_scheduler = None
            self.connector_worker = MooncakeConnectorWorker(vllm_config, self.engine_id)

    ############################################################
    # Class Methods
    ############################################################
    @classmethod
    def get_required_kvcache_layout(cls, vllm_config: VllmConfig):
        if vllm_config.model_config is None:
            logger.warning_once(
                "Unable to detect current VLLM config. "
                "Fallback to default kv cache layout."
            )
            return None
        use_mla = vllm_config.model_config.use_mla
        if use_mla:
            # return None when we have mla
            # as the layout should not matter in that case,
            # which falls back to the default behavior.
            return None
        logger.info_once(
            "MooncakeConnector setting KV cache layout to HND for better xfer performance."
        )
        return "HND"

    ############################################################
    # Scheduler Side Methods
    ############################################################
@@ -941,7 +964,13 @@ class MooncakeConnectorWorker:
            assert tensor_size_bytes == curr_tensor_size_bytes, (
                "All kv cache tensors must have the same size"
            )
            kernel_block_size = cache.shape[-2 if self.use_mla else -3]

            cache_layout = get_kv_cache_layout()
            if cache_layout == "HND":
                kernel_block_size = cache.shape[-2]
            else:
                kernel_block_size = cache.shape[-3]

            assert self.block_size == kernel_block_size
            kv_data_ptrs.append(base_addr)
            kv_data_lens.append(tensor_size_bytes)

@@ -112,6 +112,21 @@ class MultiConnector(KVConnectorBase_V1):
    - Save to all connectors.
    """

    @classmethod
    def requires_piecewise_for_cudagraph(cls, extra_config: dict[str, Any]) -> bool:
        """
        MultiConnector requires PIECEWISE CUDA graph mode if any of its
        child connectors require it.
        """
        connectors_config = extra_config.get("connectors", [])
        for conn_config in connectors_config:
            temp_ktc = KVTransferConfig(**conn_config)
            connector_cls = KVConnectorFactory.get_connector_class(temp_ktc)
            child_extra_config = conn_config.get("kv_connector_extra_config", {})
            if connector_cls.requires_piecewise_for_cudagraph(child_extra_config):
                return True
        return False
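
For reference, the `extra_config` inspected above is the MultiConnector's own `kv_connector_extra_config`. A hypothetical example of the structure it walks; the connector names and options below are illustrative, not taken from this diff:

```python
# Hypothetical MultiConnector extra_config; only the shape matters here.
extra_config = {
    "connectors": [
        {
            "kv_connector": "LMCacheConnectorV1",
            "kv_role": "kv_both",
            "kv_connector_extra_config": {"use_layerwise": True},
        },
        {
            "kv_connector": "SharedStorageConnector",
            "kv_role": "kv_both",
            "kv_connector_extra_config": {"shared_storage_path": "/tmp"},
        },
    ]
}
# With the layerwise LMCache child above, requires_piecewise_for_cudagraph()
# would return True for the whole MultiConnector.
```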

    def __init__(
        self,
        vllm_config: "VllmConfig",

File diff suppressed because it is too large
@@ -25,6 +25,7 @@ If you only need to use the distributed environment without model/pipeline

import contextlib
import gc
import os
import pickle
import weakref
from collections import namedtuple
@@ -33,7 +34,7 @@ from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from datetime import timedelta
from multiprocessing import shared_memory
from typing import Any, Protocol
from typing import TYPE_CHECKING, Any, Protocol
from unittest.mock import patch

import torch
@@ -54,6 +55,10 @@ from vllm.utils.system_utils import suppress_stdout
from vllm.utils.torch_utils import (
    direct_register_custom_op,
)

if TYPE_CHECKING:
    from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator

import ixformer.distributed as ixfd
import vllm._custom_ops as ops

@@ -327,6 +332,8 @@ class GroupCoordinator:
        self.rank = torch.distributed.get_rank()
        self.local_rank = local_rank

        use_vllm_comm = os.environ.get("VLLM_FORCE_NCCL_COMM", None) not in {"1", "Y", "y"}

        self_device_group = None
        self_cpu_group = None

@@ -339,7 +346,7 @@ class GroupCoordinator:
            with suppress_stdout():
                cpu_group = torch.distributed.new_group(ranks, backend="gloo")
            if self.rank in ranks:
                self.ixfd_group = ixfd.init_comm_with_store(device_group)
                self.ixfd_group = ixfd.init_comm_with_store(device_group) if use_vllm_comm else None
                self.ranks = ranks
                self.world_size = len(ranks)
                self.rank_in_group = ranks.index(self.rank)
@@ -372,8 +379,7 @@ class GroupCoordinator:
            self.device_communicator = device_comm_cls(
                cpu_group=self.cpu_group,
                device=self.device,
                # device_group=self.device_group,
                device_group=self.ixfd_group if envs.VLLM_FORCE_NCCL_COMM else self.device_group,
                device_group=self.ixfd_group if use_vllm_comm else self.device_group,
                unique_name=self.unique_name,
            )

@@ -385,11 +391,6 @@ class GroupCoordinator:
                self.cpu_group, 1 << 22, 6
            )

        from vllm.platforms import current_platform

        # self.use_custom_op_call = (
        #     current_platform.is_cuda_alike() or current_platform.is_tpu()
        # )
        self.use_custom_op_call = False

        self.use_cpu_custom_send_recv = current_platform.is_cpu() and hasattr(
@@ -468,14 +469,12 @@ class GroupCoordinator:
        # only cuda uses this function,
        # so we don't abstract it into the base class
        maybe_ca_context = nullcontext()
        # from vllm.distributed.device_communicators.cuda_communicator import (
        #     CudaCommunicator,
        # )
        from vllm.distributed.device_communicators.base_device_communicator import DeviceCommunicatorBase
        from vllm.distributed.device_communicators.cuda_communicator import (
            CudaCommunicator,
        )

        if self.device_communicator is not None:
            # assert isinstance(self.device_communicator, CudaCommunicator)
            assert isinstance(self.device_communicator, DeviceCommunicatorBase)
            assert isinstance(self.device_communicator, CudaCommunicator)
            ca_comm = self.device_communicator.ca_comm
            if ca_comm is not None:
                maybe_ca_context = ca_comm.capture()  # type: ignore
@@ -608,9 +607,9 @@ class GroupCoordinator:
                src=self.ranks[src],
                group=self.device_group)
        else:
            torch.distributed.broadcast(input_,
                src=self.ranks[src],
                group=self.device_group)
            torch.distributed.broadcast(
                input_, src=self.ranks[src], group=self.device_group
            )
        return input_

    def broadcast_object(self, obj: Any | None = None, src: int = 0):
@@ -764,10 +763,9 @@ class GroupCoordinator:
                    group=group,
                    async_op=True)
            else:
                handle = torch.distributed.broadcast(tensor,
                    src=self.ranks[src],
                    group=group,
                    async_op=True)
                handle = torch.distributed.broadcast(
                    tensor, src=self.ranks[src], group=group, async_op=True
                )
            async_handles.append(handle)
        for async_handle in async_handles:
            async_handle.wait()
@@ -802,10 +800,8 @@ class GroupCoordinator:
                        async_op=True)
                else:
                    handle = torch.distributed.broadcast(
                        tensor,
                        src=self.ranks[src],
                        group=group,
                        async_op=True)
                        tensor, src=self.ranks[src], group=group, async_op=True
                    )
                    async_handles.append(handle)
                tensor_dict[key] = tensor
            else:
@@ -876,6 +872,10 @@ class GroupCoordinator:
        if self.world_size <= 1:
            return []

        if dst is None:
            dst = (self.rank_in_group + 1) % self.world_size
        assert dst < self.world_size, f"Invalid dst rank ({dst})"

        if self.use_cpu_custom_send_recv:
            if self.device_communicator is None:
                raise ValueError("No device communicator found")
@@ -893,10 +893,6 @@ class GroupCoordinator:
        group = self.device_group
        metadata_group = self.cpu_group

        if dst is None:
            dst = (self.rank_in_group + 1) % self.world_size
        assert dst < self.world_size, f"Invalid dst rank ({dst})"

        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        self.send_object(metadata_list, dst=dst)

@@ -917,6 +913,7 @@ class GroupCoordinator:
                handle = torch.distributed.isend(
                    tensor, dst=self.ranks[dst], group=comm_group
                )

            if tensor.is_cuda:
                tensor.record_stream(torch.cuda.current_stream(tensor.device))
            handles.append(handle)
@@ -973,6 +970,11 @@ class GroupCoordinator:
    ]:
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return None, [], []

        if src is None:
            src = (self.rank_in_group - 1) % self.world_size
        assert src < self.world_size, f"Invalid src rank ({src})"

        if self.use_cpu_custom_send_recv:
            if self.device_communicator is None:
                raise ValueError("No device communicator found")
@@ -990,10 +992,6 @@ class GroupCoordinator:
        group = self.device_group
        metadata_group = self.cpu_group

        if src is None:
            src = (self.rank_in_group - 1) % self.world_size
        assert src < self.world_size, f"Invalid src rank ({src})"

        recv_metadata_list = self.recv_object(src=src)
        tensor_dict: dict[str, Any] = {}
        handles: list[Handle] = []
@@ -1072,14 +1070,13 @@ class GroupCoordinator:
        return self.device_communicator.recv(size, dtype, src)

    def destroy(self):
        if hasattr(self, "device_group"):
            # torch.distributed.destroy_process_group(self.device_group)
            if self.device_group is not None:
                if self.device_communicator and self.device_communicator.use_vllm_comm:
                    ixfd.destroy_process_group(self.device_group)
                else:
                    torch.distributed.destroy_process_group(self.device_group)
            del self.device_group
        if hasattr(self, "cpu_group"):
            self.device_group = None
        if self.cpu_group is not None:
            torch.distributed.destroy_process_group(self.cpu_group)
            del self.cpu_group
        if self.device_communicator is not None:
@@ -1094,7 +1091,6 @@ class GroupCoordinator:
    def dispatch_router_logits(
        self,
        hidden_states: torch.Tensor,
        extra_residual: torch.Tensor,
        router_logits: torch.Tensor,
        is_sequence_parallel: bool = False,
        extra_tensors: list[torch.Tensor] | None = None,
@@ -1105,13 +1101,12 @@ class GroupCoordinator:
        if self.device_communicator is not None:
            return self.device_communicator.dispatch_router_logits(
                hidden_states,
                extra_residual,
                router_logits,
                is_sequence_parallel,
                extra_tensors,
            )
        else:
            return hidden_states, extra_residual, router_logits
            return hidden_states, router_logits

    def dispatch(
        self,
@@ -1189,6 +1184,55 @@ def init_model_parallel_group(
    )


def _init_stateless_group(
    group_ranks: list[list[int]],
    group_name: str,
    group_ports: list[list[int]],
    host: str,
    backend: str,
    use_device_communicator: bool = True,
) -> "StatelessGroupCoordinator":
    """Create a StatelessGroupCoordinator with the given parameters."""
    from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator

    world = get_world_group()
    return StatelessGroupCoordinator(
        group_ranks=group_ranks,
        local_rank=world.local_rank,
        torch_distributed_backend=backend,
        use_device_communicator=use_device_communicator,
        group_name=group_name,
        host=host,
        group_ports=group_ports,
        global_rank=world.rank,
        global_world_size=world.world_size,
    )


def _replace_active_groups(
    *,
    world: GroupCoordinator | None,
    dp: GroupCoordinator | None,
    ep: GroupCoordinator | None,
    eplb: GroupCoordinator | None,
    node_count: int | None,
) -> None:
    """Destroy the current DP/EP/WORLD/EPLB groups and replace them.

    Destruction is collective — all ranks in the old groups must call this
    function together. Pass all-``None`` to tear down without replacement.
    """
    global _WORLD, _DP, _EP, _EPLB, _NODE_COUNT
    for group in (_DP, _EP, _WORLD, _EPLB):
        if group is not None:
            group.destroy()
    _WORLD = world
    _DP = dp
    _EP = ep
    _EPLB = eplb
    _NODE_COUNT = node_count


_TP: GroupCoordinator | None = None


@@ -1286,6 +1330,39 @@ def set_custom_all_reduce(enable: bool):
    _ENABLE_CUSTOM_ALL_REDUCE = enable


def _init_elastic_ep_world(
    config, local_rank: int, backend: str, rank: int, world_size: int
) -> None:
    from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator

    global _WORLD, _NODE_COUNT
    assert _WORLD is None, "world group already initialized"
    parallel_config = config.parallel_config
    global_rank = parallel_config.data_parallel_rank * world_size + rank
    global_world_size = parallel_config.world_size_across_dp
    all_ranks = list(range(global_world_size))
    group_ranks = [all_ranks[i : i + 1] for i in range(global_world_size)]
    if global_rank in all_ranks:
        group_ranks = [all_ranks]
    group_ports = [parallel_config.get_next_stateless_world_group_port()]
    world = StatelessGroupCoordinator(
        group_ranks=group_ranks,
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_device_communicator=False,
        group_name="world",
        host=parallel_config.data_parallel_master_ip,
        group_ports=group_ports,
        global_rank=global_rank,
        global_world_size=global_world_size,
    )
    assert parallel_config.nnodes_within_dp == 1, (
        "Elastic EP is not supported with multi-node TP/PP"
    )
    _NODE_COUNT = _node_count(world.tcp_store_group)
    _WORLD = world


def init_distributed_environment(
    world_size: int = -1,
    rank: int = -1,
@@ -1305,6 +1382,7 @@ def init_distributed_environment(
    from vllm.config import get_current_vllm_config_or_none

    config = get_current_vllm_config_or_none()
    enable_elastic_ep = config is not None and config.parallel_config.enable_elastic_ep
    if (
        config is not None
        and config.parallel_config.distributed_executor_backend != "external_launcher"
@@ -1312,6 +1390,7 @@ def init_distributed_environment(
            config.parallel_config.nnodes > 1
            or config.parallel_config.data_parallel_size > 1
        )
        and not enable_elastic_ep
    ):
        parallel_config = config.parallel_config
        # adjust to take into account data parallelism
@@ -1365,6 +1444,18 @@ def init_distributed_environment(
            rank=rank,
            timeout=timeout,
        )
        if enable_elastic_ep:
            tp_pp_cpu_group = torch.distributed.new_group(
                backend="gloo", timeout=timeout
            )
            if _node_count(tp_pp_cpu_group) > 1:
                # NOTE(yongji): StatelessGroupCoordinator uses data_parallel_master_ip
                # to initialize all DP/EP groups, hence all ranks within TP/PP group
                # must reside on the same node
                raise RuntimeError(
                    "Elastic EP is not yet supported with multi-node TP/PP"
                )

    # set the local rank
    # local_rank is not available in torch ProcessGroup,
    # see https://github.com/pytorch/pytorch/issues/122816
@@ -1373,6 +1464,9 @@ def init_distributed_environment(
    # setting, where we can use rank as local rank
    local_rank = envs.LOCAL_RANK if distributed_init_method == "env://" else rank
    global _WORLD, _NODE_COUNT, _INNER_DP_WORLD
    if enable_elastic_ep:
        _init_elastic_ep_world(config, local_rank, backend, rank, world_size)
        return
    if _WORLD is None:
        ranks = list(range(torch.distributed.get_world_size()))
        _WORLD = init_world_group(ranks, local_rank, backend)
@@ -1436,16 +1530,33 @@ def initialize_model_parallel(
    """
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()
    backend = backend or torch.distributed.get_backend(get_world_group().device_group)

    data_parallel_size = 1
    from vllm.config import get_current_vllm_config_or_none
    from vllm.config import get_current_vllm_config

    config = get_current_vllm_config_or_none()
    if config is not None:
    config = get_current_vllm_config()
    data_parallel_size = config.parallel_config.data_parallel_size
    enable_elastic_ep = config.parallel_config.enable_elastic_ep
    if enable_elastic_ep:
        # Use stateless world group for global information
        world_size = get_world_group().world_size
        rank = get_world_group().rank
        backend = backend or "nccl"
        tp_pp_pcp_size = (
            tensor_model_parallel_size
            * pipeline_model_parallel_size
            * prefill_context_model_parallel_size
        )
        local_all_ranks = torch.arange(tp_pp_pcp_size).reshape(
            pipeline_model_parallel_size,
            prefill_context_model_parallel_size,
            tensor_model_parallel_size,
        )
    else:
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()
        backend = backend or torch.distributed.get_backend(
            get_world_group().device_group
        )

    # the layout order is: ExternalDP x DP x PP x TP
    # ExternalDP is the data parallel group that is not part of the model,
@@ -1469,7 +1580,9 @@ def initialize_model_parallel(
    assert _TP is None, "tensor model parallel group is already initialized"
    group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]

    if enable_elastic_ep:
        group_ranks = local_all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
        group_ranks = [x.tolist() for x in group_ranks]
    # message queue broadcaster is only used in tensor model parallel group
    _TP = init_model_parallel_group(
        group_ranks,
@@ -1488,6 +1601,11 @@ def initialize_model_parallel(
    # TP group into tp_size//dcp_size DCP groups.
    group_ranks = all_ranks.reshape(-1, decode_context_model_parallel_size).unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]
    if enable_elastic_ep:
        group_ranks = local_all_ranks.reshape(
            -1, decode_context_model_parallel_size
        ).unbind(0)
        group_ranks = [x.tolist() for x in group_ranks]
    _DCP = init_model_parallel_group(
        group_ranks,
        get_world_group().local_rank,
@@ -1504,6 +1622,13 @@ def initialize_model_parallel(
        .unbind(0)
    )
    group_ranks = [x.tolist() for x in group_ranks]
    if enable_elastic_ep:
        group_ranks = (
            local_all_ranks.transpose(1, 2)
            .reshape(-1, prefill_context_model_parallel_size)
            .unbind(0)
        )
        group_ranks = [x.tolist() for x in group_ranks]
    _PCP = init_model_parallel_group(
        group_ranks, get_world_group().local_rank, backend, group_name="pcp"
    )
@@ -1515,6 +1640,13 @@ def initialize_model_parallel(
        all_ranks.transpose(2, 4).reshape(-1, pipeline_model_parallel_size).unbind(0)
    )
    group_ranks = [x.tolist() for x in group_ranks]
    if enable_elastic_ep:
        group_ranks = (
            local_all_ranks.transpose(0, 2)
            .reshape(-1, pipeline_model_parallel_size)
            .unbind(0)
        )
        group_ranks = [x.tolist() for x in group_ranks]
    _PP = init_model_parallel_group(
        group_ranks, get_world_group().local_rank, backend, group_name="pp"
    )
@@ -1523,6 +1655,19 @@ def initialize_model_parallel(
    assert _DP is None, "data parallel group is already initialized"
    group_ranks = all_ranks.transpose(1, 4).reshape(-1, data_parallel_size).unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]
    if enable_elastic_ep:
        parallel_config = config.parallel_config
        dp_ports = [
            parallel_config.get_next_stateless_dp_group_port() for _ in group_ranks
        ]
        _DP = _init_stateless_group(
            group_ranks,
            "dp",
            dp_ports,
            parallel_config.data_parallel_master_ip,
            backend,
        )
    else:
        _DP = init_model_parallel_group(
            group_ranks, get_world_group().local_rank, backend, group_name="dp"
        )
@@ -1530,7 +1675,7 @@ def initialize_model_parallel(
    global _EP
    assert _EP is None, "expert parallel group is already initialized"
    # Don't create EP group for dense models.
    if config is None or config.model_config is None or config.model_config.is_moe:
    if config.model_config is None or config.model_config.is_moe:
        group_ranks = (
            all_ranks.transpose(1, 2)
            .reshape(
@@ -1542,6 +1687,19 @@ def initialize_model_parallel(
            .unbind(0)
        )
        group_ranks = [x.tolist() for x in group_ranks]
        if enable_elastic_ep:
            parallel_config = config.parallel_config
            ep_ports = [
                parallel_config.get_next_stateless_ep_group_port() for _ in group_ranks
            ]
            _EP = _init_stateless_group(
                group_ranks,
                "ep",
                ep_ports,
                parallel_config.data_parallel_master_ip,
                backend,
            )
        else:
            _EP = init_model_parallel_group(
                group_ranks, get_world_group().local_rank, backend, group_name="ep"
            )
@@ -1557,9 +1715,24 @@ def initialize_model_parallel(
        and config.parallel_config is not None
        and config.parallel_config.enable_eplb
    ):
        # Reuse the same group_ranks from EP
        if enable_elastic_ep:
            eplb_ports = [
                parallel_config.get_next_stateless_eplb_group_port()
                for _ in group_ranks
            ]
            _EPLB = _init_stateless_group(
                group_ranks,
                "eplb",
                eplb_ports,
                parallel_config.data_parallel_master_ip,
                backend,
            )
        else:
            _EPLB = init_model_parallel_group(
                group_ranks, get_world_group().local_rank, backend, group_name="eplb"
                group_ranks,
                get_world_group().local_rank,
                backend,
                group_name="eplb",
            )
    # If no EP group needed, _EP remains None
    # If no EPLB group needed, _EPLB remains None
@@ -1590,7 +1763,11 @@ def ensure_model_parallel_initialized(
    or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
    values if the model parallel groups are initialized.
    """
    backend = backend or torch.distributed.get_backend(get_world_group().device_group)
    world_group = get_world_group()
    if hasattr(world_group, "backend"):
        backend = backend or world_group.backend
    else:
        backend = backend or torch.distributed.get_backend(world_group.device_group)
    if not model_parallel_is_initialized():
        initialize_model_parallel(
            tensor_model_parallel_size,

322
vllm/distributed/stateless_coordinator.py
Normal file
@@ -0,0 +1,322 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional

import torch
from torch.distributed import Backend, ProcessGroup

from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator
from vllm.distributed.parallel_state import (
    GroupCoordinator,
    TensorMetadata,
    _get_unique_name,
    _register_group,
    _split_tensor_dict,
)
from vllm.distributed.utils import (
    StatelessProcessGroup,
    stateless_destroy_torch_distributed_process_group,
    stateless_init_torch_distributed_process_group,
)
from vllm.logger import init_logger
from vllm.utils.import_utils import resolve_obj_by_qualname

logger = init_logger(__name__)


class StatelessGroupCoordinator(GroupCoordinator):
    """
    A stateless version of the GroupCoordinator class in parallel_state.
    It creates CPU, device, and TCPStore-based communication groups
    that are independent of PyTorch's WORLD group. Hence,
    communication groups with a different set of participant GPUs
    can be created without destroying the existing ones.
    """

    def __init__(
        self,
        group_ranks: list[list[int]],
        local_rank: int,
        torch_distributed_backend: str | Backend,
        use_device_communicator: bool,
        use_message_queue_broadcaster: bool = False,
        group_name: str | None = None,
        host: str = "127.0.0.1",
        group_ports: list[list[int]] | None = None,
        global_rank: int = 0,
        global_world_size: int = 1,
    ):
        group_name = group_name or "anonymous"
        self.unique_name = _get_unique_name(group_name)
        _register_group(self)

        self.rank = global_rank
        self.local_rank = local_rank

        self_device_group = None
        self_cpu_group = None
        self_tcp_store_group = None

        from vllm.platforms import current_platform

        backend = str(torch_distributed_backend)
        self.backend = backend
        assert group_ports is not None, "group_ports is not provided"
        for idx, ranks in enumerate(group_ranks):
            if self.rank in ranks:
                self.ranks = ranks
                self.world_size = len(ranks)
                self.rank_in_group = ranks.index(self.rank)

                ports = group_ports[idx]
                device_port = ports[0]
                cpu_port = ports[1]
                tcp_store_port = ports[2]

                device_group = stateless_init_torch_distributed_process_group(
                    host=host,
                    port=device_port,
                    rank=self.rank_in_group,
                    world_size=self.world_size,
                    backend=backend,
                    group_name=f"{self.unique_name}_device",
                )
                cpu_group = stateless_init_torch_distributed_process_group(
                    host=host,
                    port=cpu_port,
                    rank=self.rank_in_group,
                    world_size=self.world_size,
                    backend="gloo",
                    group_name=f"{self.unique_name}_cpu",
                )
                tcp_store_group = StatelessProcessGroup.create(
                    host=host,
                    port=tcp_store_port,
                    rank=self.rank_in_group,
                    world_size=self.world_size,
                )

                self_device_group = device_group
                self_cpu_group = cpu_group
                self_tcp_store_group = tcp_store_group

        assert self_cpu_group is not None
        assert self_device_group is not None
        assert self_tcp_store_group is not None

        self.cpu_group = self_cpu_group
        self.device_group = self_device_group
        self.tcp_store_group = self_tcp_store_group

        if current_platform.is_cuda_alike():
            self.device = torch.device(f"cuda:{local_rank}")
        elif current_platform.is_xpu():
            self.device = torch.device(f"xpu:{local_rank}")
        elif current_platform.is_out_of_tree():
            self.device = torch.device(f"{current_platform.device_name}:{local_rank}")
        else:
            self.device = torch.device("cpu")

        self.use_device_communicator = use_device_communicator
        self.device_communicator = None
        if use_device_communicator and self.world_size > 1:
            device_comm_cls = resolve_obj_by_qualname(
                current_platform.get_device_communicator_cls()
            )
            assert device_comm_cls == CudaCommunicator
            self.device_communicator = CudaCommunicator(
                cpu_group=self.cpu_group,
                device=self.device,
                device_group=self.device_group,
                unique_name=self.unique_name,
                global_ranks=self.ranks,
                global_world_size=global_world_size,
                tcp_store_group=self.tcp_store_group,
            )

        self.mq_broadcaster = None

        self.use_custom_op_call = (
            current_platform.is_cuda_alike() or current_platform.is_tpu()
        )
        self.use_cpu_custom_send_recv = False

    def destroy(self):
        if self.device_communicator:
            self.device_communicator.destroy()
        if self.device_group:
            stateless_destroy_torch_distributed_process_group(self.device_group)
        if self.cpu_group:
            stateless_destroy_torch_distributed_process_group(self.cpu_group)

    def size(self) -> int:
        """Return the world size of this group."""
        return self.world_size

    def broadcast(self, input_: torch.Tensor, src: int = 0):
        if self.world_size == 1:
            return input_

        if self.device_communicator and input_.is_cuda:
            return self.device_communicator.broadcast(input_, src)
        else:
            return self.tcp_store_group.broadcast(input_, src)

    def broadcast_object(self, obj=None, src: int = 0):
        if self.world_size == 1:
            return obj
        return self.tcp_store_group.broadcast_obj(obj, src)

    def broadcast_object_list(
        self, obj_list: list[Any], src: int = 0, group: ProcessGroup | None = None
    ):
        assert src < self.world_size

        if self.world_size == 1:
            return obj_list

        if self.rank_in_group == src:
            for obj in obj_list:
                self.tcp_store_group.broadcast_obj(obj, src)
        else:
            for i in range(len(obj_list)):
                obj_list[i] = self.tcp_store_group.broadcast_obj(None, src)

        return obj_list

    def broadcast_tensor_dict(
        self,
        tensor_dict: dict[str, torch.Tensor | Any] | None = None,
        src: int = 0,
        group: ProcessGroup | None = None,
        metadata_group: ProcessGroup | None = None,
    ) -> dict[str, torch.Tensor | Any] | None:
        if self.world_size == 1:
            return tensor_dict

        if self.rank_in_group == src:
            assert isinstance(tensor_dict, dict), (
                f"Expecting a dictionary, got {type(tensor_dict)}"
            )
            metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        else:
            metadata_list = None
            tensor_list = []

        recv_metadata_list: list[tuple[str, Any]] = self.tcp_store_group.broadcast_obj(
            metadata_list, src
        )

        if self.rank_in_group != src:
            tensor_dict = {}
            for key, value in recv_metadata_list:
                if isinstance(value, TensorMetadata):
                    tensor = torch.empty(
                        value.size, dtype=value.dtype, device=value.device
                    )
                    tensor_list.append(tensor)
                    tensor_dict[key] = tensor
                else:
                    tensor_dict[key] = value

        for tensor in tensor_list:
            if tensor.numel() == 0:
                continue
            if self.device_communicator and tensor.is_cuda:
                tensor.copy_(self.device_communicator.broadcast(tensor, src))
            else:
                tensor.copy_(self.tcp_store_group.broadcast(tensor, src))

        return tensor_dict

    def send_object(self, obj, dst: int) -> None:
        assert dst < self.world_size
        assert dst != self.rank_in_group
        self.tcp_store_group.send_obj(obj, dst)

    def recv_object(self, src: int):
        assert src < self.world_size
        assert src != self.rank_in_group
        return self.tcp_store_group.recv_obj(src)

    def send_tensor_dict(
        self,
        tensor_dict: dict[str, torch.Tensor | Any],
        dst: int | None = None,
        all_gather_group: Optional["GroupCoordinator"] = None,
        all_gather_tensors: dict[str, bool] | None = None,
    ) -> dict[str, torch.Tensor | Any] | None:
        if self.world_size == 1:
            return tensor_dict

        if dst is None:
            dst = (self.rank_in_group + 1) % self.world_size
        assert dst < self.world_size

        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        self.tcp_store_group.send_obj(metadata_list, dst)

        for tensor in tensor_list:
            if tensor.numel() == 0:
                continue
            if self.device_communicator and tensor.is_cuda:
                self.device_communicator.send(tensor, dst)
            else:
                self.tcp_store_group.send(tensor, dst)

        return None

    def recv_tensor_dict(
        self,
        src: int | None = None,
        all_gather_group: Optional["GroupCoordinator"] = None,
        all_gather_tensors: dict[str, bool] | None = None,
    ) -> dict[str, torch.Tensor | Any] | None:
        if self.world_size == 1:
            return None

        if src is None:
            src = (self.rank_in_group - 1) % self.world_size
        assert src < self.world_size

        recv_metadata_list = self.tcp_store_group.recv_obj(src)
        tensor_dict = {}
        for key, value in recv_metadata_list:
            if isinstance(value, TensorMetadata):
                tensor = torch.empty(value.size, dtype=value.dtype, device=value.device)
                if tensor.numel() > 0:
                    if self.device_communicator and tensor.is_cuda:
                        tensor = self.device_communicator.recv(
                            tensor.size(), tensor.dtype, src
                        )
                    else:
                        tensor = self.tcp_store_group.recv(tensor, src)
                tensor_dict[key] = tensor
            else:
                tensor_dict[key] = value
        return tensor_dict

    def barrier(self):
        self.tcp_store_group.barrier()

    def gather(
        self, input_: torch.Tensor, dst: int = 0, dim: int = -1
    ) -> torch.Tensor | None:
        if self.world_size == 1:
            return input_

        if self.device_communicator is None:
            raise ValueError("No device communicator found")

        if self.rank_in_group == dst:
            gathered_list = [torch.empty_like(input_) for _ in range(self.world_size)]
            gathered_list[self.rank_in_group] = input_
            for src_rank in range(self.world_size):
                if src_rank != self.rank_in_group:
                    gathered_list[src_rank] = self.device_communicator.recv(
                        input_.size(), input_.dtype, src_rank
                    )
            return torch.cat(gathered_list, dim=dim)
        else:
            self.device_communicator.send(input_, dst)
            return None
@@ -18,7 +18,7 @@ from datetime import timedelta
from typing import Any

import torch
from torch.distributed import ProcessGroup, TCPStore
from torch.distributed import ProcessGroup, Store, TCPStore
from torch.distributed.distributed_c10d import (
    Backend,
    PrefixStore,
@@ -228,6 +228,55 @@ class StatelessProcessGroup:
            gathered_objs.append(recv_obj)
        return gathered_objs

    def broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
        """Broadcast a tensor from source rank to all other ranks."""
        if self.rank == src:
            tensor_bytes = pickle.dumps(tensor)
            self.expire_data()
            key = f"broadcast_tensor/{src}/{self.broadcast_send_counter}"
            self.store.set(key, tensor_bytes)
            self.broadcast_send_counter += 1
            self.entries.append((key, time.time()))
            return tensor
        else:
            key = f"broadcast_tensor/{src}/{self.broadcast_recv_src_counter[src]}"
            tensor = pickle.loads(self.store.get(key))
            self.broadcast_recv_src_counter[src] += 1
            return tensor

    def send(self, tensor: torch.Tensor, dst: int):
        """Send a tensor to a destination rank."""
        self.expire_data()
        key = f"send_tensor/{dst}/{self.send_dst_counter[dst]}"
        self.store.set(key, pickle.dumps(tensor))
        self.send_dst_counter[dst] += 1
        self.entries.append((key, time.time()))

    def recv(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
        """Receive a tensor from a source rank."""
        key = f"send_tensor/{self.rank}/{self.recv_src_counter[src]}"
        received = pickle.loads(self.store.get(key))
        self.recv_src_counter[src] += 1
        tensor.copy_(received)
        return tensor

    def all_reduce(
        self, tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM
    ) -> torch.Tensor:
        """All-reduce a tensor across all ranks."""
        tensors = self.all_gather_obj(tensor)
        result = tensors[0].clone()
        for t in tensors[1:]:
            if op == torch.distributed.ReduceOp.SUM:
                result.add_(t)
            elif op == torch.distributed.ReduceOp.PRODUCT:
                result.mul_(t)
            elif op == torch.distributed.ReduceOp.MAX:
                result = torch.maximum(result, t)
            elif op == torch.distributed.ReduceOp.MIN:
                result = torch.minimum(result, t)
        return result
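
These TCPStore-backed helpers make `StatelessProcessGroup` usable as a minimal tensor channel without touching the global `torch.distributed` state. A hedged sketch of how two processes might use them; the host/port values are placeholders, and one copy of `run` would be launched per rank:

```python
import torch
from vllm.distributed.utils import StatelessProcessGroup

def run(rank: int) -> None:
    # Both ranks join the same TCPStore; rank 0 hosts it.
    pg = StatelessProcessGroup.create(
        host="127.0.0.1", port=29600, rank=rank, world_size=2
    )
    t = torch.full((4,), float(rank))
    # Every rank ends up with rank 0's tensor.
    t = pg.broadcast(t, src=0)
    # Store-based all-reduce added above: sums the per-rank tensors.
    total = pg.all_reduce(torch.ones(4) * (rank + 1))
    pg.barrier()
```

Note that these helpers move tensors through the pickled store rather than NCCL, so they are meant for small control-plane tensors, not bulk activation traffic.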

    def barrier(self, timeout: float = 30.0):
        """A robust barrier to synchronize all ranks.

@@ -448,8 +497,14 @@ def init_gloo_process_group(


def stateless_init_torch_distributed_process_group(
    host: str, port: int, rank: int, world_size: int, backend: str
) -> ProcessGroup:
    host: str,
    port: int,
    rank: int,
    world_size: int,
    backend: str,
    group_name: str | None = None,
    return_store: bool = False,
) -> ProcessGroup | tuple[ProcessGroup, Store]:
    """
    A replacement for `torch.distributed.init_process_group` that does not
    pollute the global state. The created ProcessGroup object can be used for
@@ -496,25 +551,35 @@ def stateless_init_torch_distributed_process_group(
    # Use a PrefixStore to avoid accidental overrides of keys used by
    # different systems (e.g. RPC) in case the store is multi-tenant.
    prefix_store = PrefixStore(init_method, store)
    try:

    if backend == "gloo":
        pg = init_gloo_process_group(
            prefix_store=prefix_store,
            group_rank=group_rank,
            group_size=group_size,
            timeout=timeout,
        )
    else:
        from vllm.platforms import current_platform

        return current_platform.stateless_init_device_torch_dist_pg(
        pg = current_platform.stateless_init_device_torch_dist_pg(
            backend=backend,
            prefix_store=prefix_store,
            group_rank=group_rank,
            group_size=group_size,
            timeout=timeout,
        )
    except NotImplementedError:
        # If platform doesn't implement stateless_init_device_torch_dist_pg, it
        # will raise a NotImplementedError. In this case, we fall back to gloo.
        return init_gloo_process_group(
            prefix_store=prefix_store,
            group_rank=group_rank,
            group_size=group_size,
            timeout=timeout,
        )

    if group_name is not None:
        from torch._C._distributed_c10d import _register_process_group

        pg._set_group_name(group_name)
        _register_process_group(group_name, pg)

    if return_store:
        return pg, store
    else:
        return pg


def stateless_destroy_torch_distributed_process_group(pg: ProcessGroup) -> None:

@@ -3,7 +3,7 @@
"""Base class for weight transfer engines."""

from abc import ABC, abstractmethod
from collections.abc import Callable
from collections.abc import Callable, Iterator
from dataclasses import KW_ONLY, dataclass, field
from typing import Any, Generic, TypeVar

@@ -156,3 +156,30 @@ class WeightTransferEngine(ABC, Generic[TInitInfo, TUpdateInfo]):
        This should be called when the worker is shutting down.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def trainer_send_weights(
        iterator: Iterator[tuple[str, torch.Tensor]],
        trainer_args: dict[str, Any] | Any,
    ) -> None:
        """
        Send weights from trainer to inference workers.

        This is a static method that can be called from the trainer process
        to send weights to all inference workers.

        Args:
            iterator: Iterator of model parameters. Returns (name, tensor) tuples.
                The tensors should be on the appropriate device for the backend.
            trainer_args: Dictionary containing backend-specific arguments needed
                to send weights. The structure depends on the backend:
                - NCCL: Contains 'group', 'src', 'packed', etc.
                - IPC: Contains 'mode' ('http' or 'ray'),
                  'llm_handle' (for Ray), 'url' (for HTTP), etc.

        Example:
            >>> param_iter = ((n, p) for n, p in model.named_parameters())
            >>> engine.trainer_send_weights(param_iter, trainer_args)
        """
        raise NotImplementedError

@@ -114,3 +114,9 @@ WeightTransferEngineFactory.register_engine(
    "vllm.distributed.weight_transfer.nccl_engine",
    "NCCLWeightTransferEngine",
)

WeightTransferEngineFactory.register_engine(
    "ipc",
    "vllm.distributed.weight_transfer.ipc_engine",
    "IPCWeightTransferEngine",
)

291
vllm/distributed/weight_transfer/ipc_engine.py
Normal file
@@ -0,0 +1,291 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""IPC-based weight transfer engine using CUDA IPC for communication."""
|
||||
|
||||
import base64
|
||||
import pickle
|
||||
from collections.abc import Callable, Iterator
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from torch.multiprocessing.reductions import reduce_tensor
|
||||
|
||||
from vllm.config.parallel import ParallelConfig
|
||||
from vllm.config.weight_transfer import WeightTransferConfig
|
||||
from vllm.distributed.weight_transfer.base import (
|
||||
WeightTransferEngine,
|
||||
WeightTransferInitInfo,
|
||||
WeightTransferUpdateInfo,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class IPCTrainerSendWeightsArgs:
|
||||
"""Arguments for IPC trainer_send_weights method."""
|
||||
|
||||
mode: str
|
||||
"""Transport mode: 'http' or 'ray'."""
|
||||
llm_handle: Any = None
|
||||
"""Ray ObjectRef to LLM handle (required for 'ray' mode)."""
|
||||
url: str | None = None
|
||||
"""Base URL for HTTP endpoint (required for 'http' mode)."""
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate that required arguments are provided for the selected mode."""
|
||||
if self.mode == "ray" and self.llm_handle is None:
|
||||
raise ValueError("llm_handle is required for 'ray' mode")
|
||||
if self.mode == "http" and self.url is None:
|
||||
raise ValueError("url is required for 'http' mode")
|
||||
if self.mode not in ("ray", "http"):
|
||||
raise ValueError(f"mode must be 'ray' or 'http', got {self.mode}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class IPCWeightTransferInitInfo(WeightTransferInitInfo):
|
||||
"""Initialization info for IPC weight transfer backend. No init needed for IPC."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class IPCWeightTransferUpdateInfo(WeightTransferUpdateInfo):
|
||||
"""Update info for IPC weight transfer backend.
|
||||
|
||||
Accepts IPC handles either directly via ``ipc_handles`` (Ray transport)
|
||||
or as a base64-encoded pickle via ``ipc_handles_pickled`` (HTTP transport).
|
||||
Exactly one of the two must be provided; if ``ipc_handles_pickled`` is set
|
||||
it is unpickled into ``ipc_handles`` during ``__post_init__``.
|
||||
"""
|
||||
|
||||
names: list[str]
|
||||
dtype_names: list[str]
|
||||
shapes: list[list[int]]
|
||||
ipc_handles: list[dict[str, tuple[Callable, tuple]]] | None = None
|
||||
"""IPC handles mapping physical GPU UUID to (func, args) tuple.
|
||||
Each handle is a dictionary mapping GPU UUID strings to IPC handle tuples."""
|
||||
ipc_handles_pickled: str | None = None
|
||||
"""Base64-encoded pickled IPC handles, used for HTTP transport."""
|
||||
|
||||
def __post_init__(self):
|
||||
if self.ipc_handles_pickled is not None:
|
||||
if self.ipc_handles is not None:
|
||||
raise ValueError(
|
||||
"Cannot specify both `ipc_handles` and `ipc_handles_pickled`"
|
||||
)
|
||||
self.ipc_handles = pickle.loads(base64.b64decode(self.ipc_handles_pickled))
|
||||
self.ipc_handles_pickled = None
|
||||
|
||||
if self.ipc_handles is None:
|
||||
raise ValueError(
|
||||
"Either `ipc_handles` or `ipc_handles_pickled` must be provided"
|
||||
)
|
||||
|
||||
num_params = len(self.names)
|
||||
if len(self.dtype_names) != num_params:
|
||||
raise ValueError(
|
||||
f"`dtype_names` should be of the same size as `names`: "
|
||||
f"got {len(self.dtype_names)} and {len(self.names)}"
|
||||
)
|
||||
if len(self.shapes) != num_params:
|
||||
raise ValueError(
|
||||
f"`shapes` should be of the same size as `names`: "
|
||||
f"got {len(self.shapes)} and {len(self.names)}"
|
||||
)
|
||||
if len(self.ipc_handles) != num_params:
|
||||
raise ValueError(
|
||||
f"`ipc_handles` should be of the same size as `names`: "
|
||||
f"got {len(self.ipc_handles)} and {len(self.names)}"
|
||||
)
|
||||
|
||||
|
||||
class IPCWeightTransferEngine(
    WeightTransferEngine[IPCWeightTransferInitInfo, IPCWeightTransferUpdateInfo]
):
    """
    Weight transfer engine using CUDA IPC for communication between trainer and workers.

    This implementation uses CUDA IPC to transfer weights from the trainer (rank 0)
    to all inference workers in a process group. IPC handles are used to share
    memory between processes on the same node.
    """

    # Define backend-specific dataclass types
    init_info_cls = IPCWeightTransferInitInfo
    update_info_cls = IPCWeightTransferUpdateInfo

    def __init__(
        self, config: WeightTransferConfig, parallel_config: ParallelConfig
    ) -> None:
        """
        Initialize the IPC weight transfer engine.

        Args:
            config: The configuration for the weight transfer engine
            parallel_config: The configuration for the parallel setup
        """
        super().__init__(config, parallel_config)

    def init_transfer_engine(self, init_info: IPCWeightTransferInitInfo) -> None:
        """
        Initialize the weight transfer mechanism.
        This is called once at the beginning of training.
        No initialization needed for IPC backend.

        Args:
            init_info: IPC initialization info (empty)
        """
        pass

    def receive_weights(
        self,
        update_info: IPCWeightTransferUpdateInfo,
        load_weights: Callable[[list[tuple[str, torch.Tensor]]], None],
    ) -> None:
        """
        Receive weights from the trainer via CUDA IPC handles.

        Args:
            update_info: IPC update info containing parameter names, dtypes, shapes,
                and IPC handles. Each IPC handle is a mapping between physical
                GPU UUID and the IPC handle tuple (func, args).
            load_weights: Callable that loads weights into the model. Called
                incrementally for each weight to avoid OOM.
        """
        assert update_info.ipc_handles is not None
        weights = []
        for name, _dtype_name, _shape, ipc_handle in zip(
            update_info.names,
            update_info.dtype_names,
            update_info.shapes,
            update_info.ipc_handles,
        ):
            device_index = torch.cuda.current_device()
            props = torch.cuda.get_device_properties(device_index)
            physical_gpu_id = str(props.uuid)

            if physical_gpu_id not in ipc_handle:
                raise ValueError(
                    f"IPC handle not found for GPU UUID {physical_gpu_id}. "
                    f"Available UUIDs: {list(ipc_handle.keys())}"
                )

            handle = ipc_handle[physical_gpu_id]

            func, args = handle
            list_args = list(args)  # type: ignore
            # Index 6 is the device_index parameter in torch's
            # IPC handle tuple (rebuild_cuda_tensor). Update it
            # to the current device since the logical index can
            # differ between sender and receiver.
            list_args[6] = device_index
            weight = func(*list_args)  # type: ignore
            weights.append((name, weight))

        load_weights(weights)

    def shutdown(self) -> None:
        """
        Shutdown the weight transfer engine.
        """
        pass

    @staticmethod
    def trainer_send_weights(
        iterator: Iterator[tuple[str, torch.Tensor]],
        trainer_args: dict[str, Any] | IPCTrainerSendWeightsArgs,
    ) -> None:
        """
        Send weights from trainer to inference workers via CUDA IPC.

        Supports two modes:
        - 'ray': Sends weights via Ray RPC to a Ray-based LLM handle
        - 'http': Sends weights via HTTP POST to a vLLM HTTP server

        Args:
            iterator: Iterator of model parameters. Returns (name, tensor) tuples.
                Tensors should be on the same GPU as the inference workers.
            trainer_args: Dictionary containing IPC-specific arguments.
                Should contain keys from IPCTrainerSendWeightsArgs:
                - mode: 'ray' or 'http'
                - llm_handle: Ray ObjectRef (for 'ray' mode)
                - url: Base URL string (for 'http' mode)

        Example (Ray mode):
            >>> from vllm.distributed.weight_transfer.ipc_engine import (
            ...     IPCWeightTransferEngine,
            ...     IPCTrainerSendWeightsArgs,
            ... )
            >>> param_iter = ((n, p) for n, p in model.named_parameters())
            >>> args = IPCTrainerSendWeightsArgs(mode="ray", llm_handle=llm_handle)
            >>> IPCWeightTransferEngine.trainer_send_weights(param_iter, asdict(args))

        Example (HTTP mode):
            >>> args = IPCTrainerSendWeightsArgs(
            ...     mode="http", url="http://localhost:8000"
            ... )
            >>> IPCWeightTransferEngine.trainer_send_weights(param_iter, asdict(args))
        """
        # Parse trainer args - accept either dict or dataclass instance
        if isinstance(trainer_args, dict):
            args = IPCTrainerSendWeightsArgs(**trainer_args)
        else:
            args = trainer_args

        # Get physical GPU UUID
        device_index = torch.cuda.current_device()
        props = torch.cuda.get_device_properties(device_index)
        gpu_uuid = str(props.uuid)

        # Collect weight metadata and create IPC handles
        names = []
        dtype_names = []
        shapes = []
        ipc_handles = []

        for name, tensor in iterator:
            names.append(name)
            dtype_names.append(str(tensor.dtype).split(".")[-1])
            shapes.append(list(tensor.shape))

            # Create IPC handle for this weight tensor
            # The tensor must remain in memory for IPC to work
            weight = tensor.detach().contiguous()
            ipc_handle = reduce_tensor(weight)
            ipc_handles.append({gpu_uuid: ipc_handle})

        # Send weights based on mode
        if args.mode == "ray":
            # Ray mode: send via Ray RPC
            import ray

            update_info = asdict(
                IPCWeightTransferUpdateInfo(
                    names=names,
                    dtype_names=dtype_names,
                    shapes=shapes,
                    ipc_handles=ipc_handles,
                )
            )
            ray.get(
                args.llm_handle.update_weights.remote(dict(update_info=update_info))
            )
        elif args.mode == "http":
            # HTTP mode: send via HTTP POST with pickled handles
            # Pickle and base64 encode IPC handles for HTTP transmission
            pickled_handles = base64.b64encode(pickle.dumps(ipc_handles)).decode(
                "utf-8"
            )

            url = f"{args.url}/update_weights"
            payload = {
                "update_info": {
                    "names": names,
                    "dtype_names": dtype_names,
                    "shapes": shapes,
                    "ipc_handles_pickled": pickled_handles,
                }
            }
            response = requests.post(url, json=payload, timeout=300)
            response.raise_for_status()

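# Illustrative sketch of the worker-side counterpart (assumptions: `engine` is an
# already-initialized IPCWeightTransferEngine, `model` is the worker's
# torch.nn.Module, and `payload` is the HTTP body shown above; none of these are
# defined in the diff itself).
def load_weights(weights: list[tuple[str, torch.Tensor]]) -> None:
    # Minimal loader: copy each reconstructed IPC tensor into the matching
    # parameter of this worker's model.
    params = dict(model.named_parameters())
    for name, tensor in weights:
        params[name].data.copy_(tensor)

update_info = IPCWeightTransferUpdateInfo(**payload["update_info"])
engine.receive_weights(update_info, load_weights)
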
@@ -35,6 +35,32 @@ class NCCLWeightTransferInitInfo(WeightTransferInitInfo):
    world_size: int


@dataclass
class NCCLTrainerSendWeightsArgs:
    """Arguments for NCCL trainer_send_weights method."""

    group: Any
    """Process group (PyNcclCommunicator) for NCCL communication."""
    src: int = 0
    """Source rank (default 0, trainer is typically rank 0)."""
    post_iter_func: Callable[[tuple[str, torch.Tensor]], torch.Tensor] | None = None
    """Optional function to apply to each (name, tensor) pair before broadcasting.
    If None, extracts just the tensor."""
    packed: bool = False
    """Whether to use packed tensor broadcasting for efficiency.
    When True, multiple tensors are batched together before broadcasting
    to reduce NCCL communication overhead."""
    stream: torch.cuda.Stream | None = None
    """CUDA stream to use for broadcasting if packed is False.
    If packed is True, new streams will be created for each buffer."""
    packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES
    """Size in bytes for each packed tensor buffer.
    Must match the value used in NCCLWeightTransferUpdateInfo."""
    packed_num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS
    """Number of buffers for double/triple buffering during packed transfer.
    Must match the value used in NCCLWeightTransferUpdateInfo."""


@dataclass
|
||||
class NCCLWeightTransferUpdateInfo(WeightTransferUpdateInfo):
|
||||
"""Update info for NCCL weight transfer backend."""
|
||||
@@ -47,7 +73,7 @@ class NCCLWeightTransferUpdateInfo(WeightTransferUpdateInfo):
|
||||
When True, multiple tensors are batched together before broadcasting
|
||||
to reduce NCCL communication overhead."""
|
||||
packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES
|
||||
"""Size in bytes for each packed tensor buffer. Default is 1GB.
|
||||
"""Size in bytes for each packed tensor buffer.
|
||||
Both producer and consumer must use the same value."""
|
||||
packed_num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS
|
||||
"""Number of buffers for double/triple buffering during packed transfer.
|
||||
@@ -186,47 +212,38 @@ class NCCLWeightTransferEngine(
|
||||
@staticmethod
|
||||
def trainer_send_weights(
|
||||
iterator: Iterator[tuple[str, torch.Tensor]],
|
||||
group: Any,
|
||||
src: int = 0,
|
||||
post_iter_func: Callable[[tuple[str, torch.Tensor]], torch.Tensor]
|
||||
| None = None,
|
||||
packed: bool = False,
|
||||
stream: torch.cuda.Stream | None = None,
|
||||
packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES,
|
||||
packed_num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS,
|
||||
trainer_args: dict[str, Any] | NCCLTrainerSendWeightsArgs,
|
||||
) -> None:
|
||||
"""Broadcast weights from trainer to vLLM workers.
|
||||
|
||||
Args:
|
||||
iterator: Iterator of model parameters. Returns (name, tensor) tuples
|
||||
group: Process group (PyNcclCommunicator)
|
||||
src: Source rank (default 0, trainer is typically rank 0)
|
||||
post_iter_func: Optional function to apply to each (name, tensor) pair
|
||||
before broadcasting. If None, extracts just the tensor.
|
||||
packed: Whether to use packed tensor broadcasting for efficiency.
|
||||
When True, multiple tensors are batched together before
|
||||
broadcasting to reduce NCCL communication overhead.
|
||||
stream: CUDA stream to use for broadcasting if packed is False.
|
||||
If packed is True, new streams will be created for each buffer.
|
||||
packed_buffer_size_bytes: Size in bytes for each packed tensor buffer.
|
||||
Must match the value used in NCCLWeightTransferUpdateInfo.
|
||||
packed_num_buffers: Number of buffers for double/triple buffering.
|
||||
Must match the value used in NCCLWeightTransferUpdateInfo.
|
||||
trainer_args: Dictionary or NCCLTrainerSendWeightsArgs instance containing
|
||||
NCCL-specific arguments. If a dict, should contain keys from
|
||||
NCCLTrainerSendWeightsArgs.
|
||||
|
||||
Example:
|
||||
>>> from vllm.distributed.weight_transfer.nccl_engine import (
|
||||
... NCCLWeightTransferEngine,
|
||||
... NCCLTrainerSendWeightsArgs,
|
||||
... )
|
||||
>>> param_iter = ((n, p) for n, p in model.named_parameters())
|
||||
>>> NCCLWeightTransferEngine.trainer_send_weights(
|
||||
... param_iter, group, packed=True
|
||||
... )
|
||||
>>> args = NCCLTrainerSendWeightsArgs(group=group, packed=True)
|
||||
>>> NCCLWeightTransferEngine.trainer_send_weights(param_iter, args)
|
||||
"""
|
||||
if post_iter_func is None:
|
||||
# Parse trainer args - accept either dict or dataclass instance
|
||||
if isinstance(trainer_args, dict):
|
||||
args = NCCLTrainerSendWeightsArgs(**trainer_args)
|
||||
else:
|
||||
args = trainer_args
|
||||
|
||||
if args.post_iter_func is None:
|
||||
# Default: extract just the tensor from (name, tensor) tuple
|
||||
post_iter_func = lambda x: x[1]
|
||||
else:
|
||||
post_iter_func = args.post_iter_func
|
||||
|
||||
if packed:
|
||||
if args.packed:
|
||||
# Use packed tensor broadcasting for efficiency
|
||||
from vllm.distributed.weight_transfer.packed_tensor import (
|
||||
packed_broadcast_producer,
|
||||
@@ -234,18 +251,20 @@ class NCCLWeightTransferEngine(
|
||||
|
||||
packed_broadcast_producer(
|
||||
iterator=iterator,
|
||||
group=group,
|
||||
src=src,
|
||||
group=args.group,
|
||||
src=args.src,
|
||||
post_iter_func=post_iter_func,
|
||||
buffer_size_bytes=packed_buffer_size_bytes,
|
||||
num_buffers=packed_num_buffers,
|
||||
buffer_size_bytes=args.packed_buffer_size_bytes,
|
||||
num_buffers=args.packed_num_buffers,
|
||||
)
|
||||
else:
|
||||
# Use simple one-by-one broadcasting
|
||||
for item in iterator:
|
||||
tensor = post_iter_func(item)
|
||||
group.broadcast(
|
||||
tensor, src=src, stream=stream or torch.cuda.current_stream()
|
||||
args.group.broadcast(
|
||||
tensor,
|
||||
src=args.src,
|
||||
stream=args.stream or torch.cuda.current_stream(),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -419,6 +419,7 @@ class EngineArgs:
|
||||
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
|
||||
moe_backend: MoEBackend = KernelConfig.moe_backend
|
||||
all2all_backend: All2AllBackend = ParallelConfig.all2all_backend
|
||||
enable_elastic_ep: bool = ParallelConfig.enable_elastic_ep
|
||||
enable_dbo: bool = ParallelConfig.enable_dbo
|
||||
ubatch_size: int = ParallelConfig.ubatch_size
|
||||
dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold
|
||||
@@ -896,6 +897,9 @@ class EngineArgs:
|
||||
"--ubatch-size",
|
||||
**parallel_kwargs["ubatch_size"],
|
||||
)
|
||||
parallel_group.add_argument(
|
||||
"--enable-elastic-ep", **parallel_kwargs["enable_elastic_ep"]
|
||||
)
|
||||
parallel_group.add_argument(
|
||||
"--dbo-decode-token-threshold",
|
||||
**parallel_kwargs["dbo_decode_token_threshold"],
|
||||
@@ -1321,6 +1325,7 @@ class EngineArgs:
|
||||
"launched vLLM.",
|
||||
self.seed,
|
||||
)
|
||||
|
||||
return ModelConfig(
|
||||
model=self.model,
|
||||
model_weights=self.model_weights,
|
||||
@@ -1697,6 +1702,7 @@ class EngineArgs:
|
||||
is_moe_model=model_config.is_moe,
|
||||
enable_expert_parallel=self.enable_expert_parallel,
|
||||
all2all_backend=self.all2all_backend,
|
||||
enable_elastic_ep=self.enable_elastic_ep,
|
||||
enable_dbo=self.enable_dbo,
|
||||
ubatch_size=self.ubatch_size,
|
||||
dbo_decode_token_threshold=self.dbo_decode_token_threshold,
|
||||
@@ -1905,6 +1911,7 @@ class EngineArgs:
|
||||
performance_mode=self.performance_mode,
|
||||
weight_transfer_config=self.weight_transfer_config,
|
||||
)
|
||||
|
||||
return config
|
||||
|
||||
def _check_feature_supported(self):
|
||||
@@ -2074,20 +2081,19 @@ class EngineArgs:
|
||||
)
|
||||
|
||||
# Disable chunked prefill and prefix caching for:
|
||||
# POWER (ppc64le)/RISCV CPUs in V1
|
||||
# RISCV CPUs in V1
|
||||
if current_platform.is_cpu() and current_platform.get_cpu_architecture() in (
|
||||
CpuArchEnum.POWERPC,
|
||||
CpuArchEnum.RISCV,
|
||||
):
|
||||
logger.info(
|
||||
"Chunked prefill is not supported for POWER, "
|
||||
"and RISC-V CPUs; "
|
||||
"Chunked prefill is not supported for"
|
||||
"RISC-V CPUs; "
|
||||
"disabling it for V1 backend."
|
||||
)
|
||||
self.enable_chunked_prefill = False
|
||||
logger.info(
|
||||
"Prefix caching is not supported for POWER, "
|
||||
"and RISC-V CPUs; "
|
||||
"Prefix caching is not supported for "
|
||||
"RISC-V CPUs; "
|
||||
"disabling it for V1 backend."
|
||||
)
|
||||
self.enable_prefix_caching = False
|
||||
@@ -2181,14 +2187,10 @@ class AsyncEngineArgs(EngineArgs):
|
||||
"--enable-log-requests",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=AsyncEngineArgs.enable_log_requests,
|
||||
help="Enable logging requests.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-log-requests",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=not AsyncEngineArgs.enable_log_requests,
|
||||
help="[DEPRECATED] Disable logging requests.",
|
||||
deprecated=True,
|
||||
help="Enable logging request information, dependant on log level:\n"
|
||||
"- INFO: Request ID, parameters and LoRA request.\n"
|
||||
"- DEBUG: Prompt inputs (e.g: text, token IDs).\n"
|
||||
"You can set the minimum log level via `VLLM_LOGGING_LEVEL`.",
|
||||
)
|
||||
current_platform.pre_register_and_update(parser)
|
||||
return parser
|
||||
|
||||
@@ -8,6 +8,8 @@ from fastapi import APIRouter, Depends, FastAPI, Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from vllm.entrypoints.anthropic.protocol import (
|
||||
AnthropicCountTokensRequest,
|
||||
AnthropicCountTokensResponse,
|
||||
AnthropicError,
|
||||
AnthropicErrorResponse,
|
||||
AnthropicMessagesRequest,
|
||||
@@ -31,6 +33,18 @@ def messages(request: Request) -> AnthropicServingMessages:
|
||||
return request.app.state.anthropic_serving_messages
|
||||
|
||||
|
||||
def translate_error_response(response: ErrorResponse) -> JSONResponse:
|
||||
anthropic_error = AnthropicErrorResponse(
|
||||
error=AnthropicError(
|
||||
type=response.error.type,
|
||||
message=response.error.message,
|
||||
)
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=response.error.code, content=anthropic_error.model_dump()
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/messages",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
@@ -44,17 +58,6 @@ def messages(request: Request) -> AnthropicServingMessages:
|
||||
@with_cancellation
|
||||
@load_aware_call
|
||||
async def create_messages(request: AnthropicMessagesRequest, raw_request: Request):
|
||||
def translate_error_response(response: ErrorResponse) -> JSONResponse:
|
||||
anthropic_error = AnthropicErrorResponse(
|
||||
error=AnthropicError(
|
||||
type=response.error.type,
|
||||
message=response.error.message,
|
||||
)
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=response.error.code, content=anthropic_error.model_dump()
|
||||
)
|
||||
|
||||
handler = messages(raw_request)
|
||||
if handler is None:
|
||||
base_server = raw_request.app.state.openai_serving_tokenization
|
||||
@@ -88,5 +91,46 @@ async def create_messages(request: AnthropicMessagesRequest, raw_request: Reques
|
||||
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/messages/count_tokens",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
HTTPStatus.OK.value: {"model": AnthropicCountTokensResponse},
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": AnthropicErrorResponse},
|
||||
HTTPStatus.NOT_FOUND.value: {"model": AnthropicErrorResponse},
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": AnthropicErrorResponse},
|
||||
},
|
||||
)
|
||||
@load_aware_call
|
||||
@with_cancellation
|
||||
async def count_tokens(request: AnthropicCountTokensRequest, raw_request: Request):
|
||||
handler = messages(raw_request)
|
||||
if handler is None:
|
||||
base_server = raw_request.app.state.openai_serving_tokenization
|
||||
error = base_server.create_error_response(
|
||||
message="The model does not support Messages API"
|
||||
)
|
||||
return translate_error_response(error)
|
||||
|
||||
try:
|
||||
response = await handler.count_tokens(request, raw_request)
|
||||
except Exception as e:
|
||||
logger.exception("Error in count_tokens: %s", e)
|
||||
return JSONResponse(
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
|
||||
content=AnthropicErrorResponse(
|
||||
error=AnthropicError(
|
||||
type="internal_error",
|
||||
message=str(e),
|
||||
)
|
||||
).model_dump(),
|
||||
)
|
||||
|
||||
if isinstance(response, ErrorResponse):
|
||||
return translate_error_response(response)
|
||||
|
||||
return JSONResponse(content=response.model_dump(exclude_none=True))
|
||||
|
||||
|
||||
def attach_router(app: FastAPI):
|
||||
app.include_router(router)
|
||||
|
||||
@@ -34,7 +34,7 @@ class AnthropicUsage(BaseModel):
|
||||
class AnthropicContentBlock(BaseModel):
|
||||
"""Content block in message"""
|
||||
|
||||
type: Literal["text", "image", "tool_use", "tool_result"]
|
||||
type: Literal["text", "image", "tool_use", "tool_result", "thinking"]
|
||||
text: str | None = None
|
||||
# For image content
|
||||
source: dict[str, Any] | None = None
|
||||
@@ -45,6 +45,9 @@ class AnthropicContentBlock(BaseModel):
|
||||
input: dict[str, Any] | None = None
|
||||
content: str | list[dict[str, Any]] | None = None
|
||||
is_error: bool | None = None
|
||||
# For thinking content
|
||||
thinking: str | None = None
|
||||
signature: str | None = None
|
||||
|
||||
|
||||
class AnthropicMessage(BaseModel):
|
||||
@@ -74,7 +77,7 @@ class AnthropicTool(BaseModel):
|
||||
class AnthropicToolChoice(BaseModel):
|
||||
"""Tool Choice definition"""
|
||||
|
||||
type: Literal["auto", "any", "tool"]
|
||||
type: Literal["auto", "any", "tool", "none"]
|
||||
name: str | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
@@ -118,9 +121,14 @@ class AnthropicMessagesRequest(BaseModel):
|
||||
class AnthropicDelta(BaseModel):
|
||||
"""Delta for streaming responses"""
|
||||
|
||||
type: Literal["text_delta", "input_json_delta"] | None = None
|
||||
type: (
|
||||
Literal["text_delta", "input_json_delta", "thinking_delta", "signature_delta"]
|
||||
| None
|
||||
) = None
|
||||
text: str | None = None
|
||||
thinking: str | None = None
|
||||
partial_json: str | None = None
|
||||
signature: str | None = None
|
||||
|
||||
# Message delta
|
||||
stop_reason: (
|
||||
@@ -167,3 +175,33 @@ class AnthropicMessagesResponse(BaseModel):
|
||||
def model_post_init(self, __context):
|
||||
if not self.id:
|
||||
self.id = f"msg_{int(time.time() * 1000)}"
|
||||
|
||||
|
||||
class AnthropicContextManagement(BaseModel):
    """Context management information for token counting."""

    original_input_tokens: int


class AnthropicCountTokensRequest(BaseModel):
    """Anthropic messages.count_tokens request"""

    model: str
    messages: list[AnthropicMessage]
    system: str | list[AnthropicContentBlock] | None = None
    tool_choice: AnthropicToolChoice | None = None
    tools: list[AnthropicTool] | None = None

    @field_validator("model")
    @classmethod
    def validate_model(cls, v):
        if not v:
            raise ValueError("Model is required")
        return v


class AnthropicCountTokensResponse(BaseModel):
    """Anthropic messages.count_tokens response"""

    input_tokens: int
    context_management: AnthropicContextManagement | None = None

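# Illustrative request against the /v1/messages/count_tokens route added in this
# commit (assumes a vLLM server reachable at http://localhost:8000 serving a model
# named "llm"; both values are placeholders, not taken from this diff).
import requests

resp = requests.post(
    "http://localhost:8000/v1/messages/count_tokens",
    json={
        "model": "llm",
        "messages": [{"role": "user", "content": "Hello, world"}],
    },
    timeout=30,
)
# Expected response shape per AnthropicCountTokensResponse above:
# {"input_tokens": <int>, "context_management": {"original_input_tokens": <int>}}
print(resp.json())
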
@@ -8,6 +8,7 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
@@ -16,6 +17,9 @@ from fastapi import Request
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.anthropic.protocol import (
|
||||
AnthropicContentBlock,
|
||||
AnthropicContextManagement,
|
||||
AnthropicCountTokensRequest,
|
||||
AnthropicCountTokensResponse,
|
||||
AnthropicDelta,
|
||||
AnthropicError,
|
||||
AnthropicMessagesRequest,
|
||||
@@ -85,14 +89,52 @@ class AnthropicServingMessages(OpenAIServingChat):
|
||||
"tool_calls": "tool_use",
|
||||
}
|
||||
|
||||
    @staticmethod
    def _convert_image_source_to_url(source: dict[str, Any]) -> str:
        """Convert an Anthropic image source to an OpenAI-compatible URL.

        Anthropic supports two image source types:
        - base64: {"type": "base64", "media_type": "image/jpeg", "data": "..."}
        - url: {"type": "url", "url": "https://..."}

        For base64 sources, this constructs a proper data URI that
        downstream processors (e.g. vLLM's media connector) can handle.
        """
        source_type = source.get("type")
        if source_type == "url":
            return source.get("url", "")
        # Default to base64 processing if type is "base64"
        # or missing, ensuring a proper data URI is always
        # constructed for non-URL sources.
        media_type = source.get("media_type", "image/jpeg")
        data = source.get("data", "")
        return f"data:{media_type};base64,{data}"

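# Small sanity check for the helper above (illustrative only; the base64 string is
# a placeholder, not real image data).
url_source = {"type": "url", "url": "https://example.com/cat.png"}
base64_source = {"type": "base64", "media_type": "image/png", "data": "iVBORw0..."}

assert AnthropicServingMessages._convert_image_source_to_url(url_source) == (
    "https://example.com/cat.png"
)
assert AnthropicServingMessages._convert_image_source_to_url(base64_source) == (
    "data:image/png;base64,iVBORw0..."
)
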
@classmethod
|
||||
def _convert_anthropic_to_openai_request(
|
||||
self, anthropic_request: AnthropicMessagesRequest
|
||||
cls, anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest
|
||||
) -> ChatCompletionRequest:
|
||||
"""Convert Anthropic message format to OpenAI format"""
|
||||
openai_messages = []
|
||||
openai_messages: list[dict[str, Any]] = []
|
||||
|
||||
cls._convert_system_message(anthropic_request, openai_messages)
|
||||
cls._convert_messages(anthropic_request.messages, openai_messages)
|
||||
req = cls._build_base_request(anthropic_request, openai_messages)
|
||||
cls._handle_streaming_options(req, anthropic_request)
|
||||
cls._convert_tool_choice(anthropic_request, req)
|
||||
cls._convert_tools(anthropic_request, req)
|
||||
return req
|
||||
|
||||
@classmethod
|
||||
def _convert_system_message(
|
||||
cls,
|
||||
anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest,
|
||||
openai_messages: list[dict[str, Any]],
|
||||
) -> None:
|
||||
"""Convert Anthropic system message to OpenAI format"""
|
||||
if not anthropic_request.system:
|
||||
return
|
||||
|
||||
# Add system message if provided
|
||||
if anthropic_request.system:
|
||||
if isinstance(anthropic_request.system, str):
|
||||
openai_messages.append(
|
||||
{"role": "system", "content": anthropic_request.system}
|
||||
@@ -104,27 +146,83 @@ class AnthropicServingMessages(OpenAIServingChat):
|
||||
system_prompt += block.text
|
||||
openai_messages.append({"role": "system", "content": system_prompt})
|
||||
|
||||
for msg in anthropic_request.messages:
|
||||
@classmethod
|
||||
def _convert_messages(
|
||||
cls, messages: list, openai_messages: list[dict[str, Any]]
|
||||
) -> None:
|
||||
"""Convert Anthropic messages to OpenAI format"""
|
||||
for msg in messages:
|
||||
openai_msg: dict[str, Any] = {"role": msg.role} # type: ignore
|
||||
|
||||
if isinstance(msg.content, str):
|
||||
openai_msg["content"] = msg.content
|
||||
else:
|
||||
# Handle complex content blocks
|
||||
cls._convert_message_content(msg, openai_msg, openai_messages)
|
||||
|
||||
openai_messages.append(openai_msg)
|
||||
|
||||
@classmethod
|
||||
def _convert_message_content(
|
||||
cls,
|
||||
msg,
|
||||
openai_msg: dict[str, Any],
|
||||
openai_messages: list[dict[str, Any]],
|
||||
) -> None:
|
||||
"""Convert complex message content blocks"""
|
||||
content_parts: list[dict[str, Any]] = []
|
||||
tool_calls: list[dict[str, Any]] = []
|
||||
reasoning_parts: list[str] = []
|
||||
|
||||
for block in msg.content:
|
||||
cls._convert_block(
|
||||
block,
|
||||
msg.role,
|
||||
content_parts,
|
||||
tool_calls,
|
||||
reasoning_parts,
|
||||
openai_messages,
|
||||
)
|
||||
|
||||
if reasoning_parts:
|
||||
openai_msg["reasoning"] = "".join(reasoning_parts)
|
||||
|
||||
if tool_calls:
|
||||
openai_msg["tool_calls"] = tool_calls # type: ignore
|
||||
|
||||
if content_parts:
|
||||
if len(content_parts) == 1 and content_parts[0]["type"] == "text":
|
||||
openai_msg["content"] = content_parts[0]["text"]
|
||||
else:
|
||||
openai_msg["content"] = content_parts # type: ignore
|
||||
elif not tool_calls and not reasoning_parts:
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def _convert_block(
|
||||
cls,
|
||||
block,
|
||||
role: str,
|
||||
content_parts: list[dict[str, Any]],
|
||||
tool_calls: list[dict[str, Any]],
|
||||
reasoning_parts: list[str],
|
||||
openai_messages: list[dict[str, Any]],
|
||||
) -> None:
|
||||
"""Convert individual content block"""
|
||||
if block.type == "text" and block.text:
|
||||
content_parts.append({"type": "text", "text": block.text})
|
||||
elif block.type == "image" and block.source:
|
||||
content_parts.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": block.source.get("data", "")},
|
||||
}
|
||||
)
|
||||
image_url = cls._convert_image_source_to_url(block.source)
|
||||
content_parts.append({"type": "image_url", "image_url": {"url": image_url}})
|
||||
elif block.type == "thinking" and block.thinking is not None:
|
||||
reasoning_parts.append(block.thinking)
|
||||
elif block.type == "tool_use":
|
||||
# Convert tool use to function call format
|
||||
cls._convert_tool_use_block(block, tool_calls)
|
||||
elif block.type == "tool_result":
|
||||
cls._convert_tool_result_block(block, role, openai_messages, content_parts)
|
||||
|
||||
@classmethod
|
||||
def _convert_tool_use_block(cls, block, tool_calls: list[dict[str, Any]]) -> None:
|
||||
"""Convert tool_use block to OpenAI function call format"""
|
||||
tool_call = {
|
||||
"id": block.id or f"call_{int(time.time())}",
|
||||
"type": "function",
|
||||
@@ -134,45 +232,82 @@ class AnthropicServingMessages(OpenAIServingChat):
|
||||
},
|
||||
}
|
||||
tool_calls.append(tool_call)
|
||||
elif block.type == "tool_result":
|
||||
if msg.role == "user":
|
||||
|
||||
@classmethod
|
||||
def _convert_tool_result_block(
|
||||
cls,
|
||||
block,
|
||||
role: str,
|
||||
openai_messages: list[dict[str, Any]],
|
||||
content_parts: list[dict[str, Any]],
|
||||
) -> None:
|
||||
"""Convert tool_result block to OpenAI format"""
|
||||
if role == "user":
|
||||
cls._convert_user_tool_result(block, openai_messages)
|
||||
else:
|
||||
tool_result_text = str(block.content) if block.content else ""
|
||||
content_parts.append(
|
||||
{"type": "text", "text": f"Tool result: {tool_result_text}"}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _convert_user_tool_result(
|
||||
cls, block, openai_messages: list[dict[str, Any]]
|
||||
) -> None:
|
||||
"""Convert user tool_result with text and image support"""
|
||||
tool_text = ""
|
||||
tool_image_urls: list[str] = []
|
||||
|
||||
if isinstance(block.content, str):
|
||||
tool_text = block.content
|
||||
elif isinstance(block.content, list):
|
||||
text_parts: list[str] = []
|
||||
for item in block.content:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
item_type = item.get("type")
|
||||
if item_type == "text":
|
||||
text_parts.append(item.get("text", ""))
|
||||
elif item_type == "image":
|
||||
source = item.get("source", {})
|
||||
url = cls._convert_image_source_to_url(source)
|
||||
if url:
|
||||
tool_image_urls.append(url)
|
||||
tool_text = "\n".join(text_parts)
|
||||
|
||||
openai_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": block.tool_use_id or "",
|
||||
"content": str(block.content)
|
||||
if block.content
|
||||
else "",
|
||||
"content": tool_text or "",
|
||||
}
|
||||
)
|
||||
else:
|
||||
# Assistant tool result becomes regular text
|
||||
tool_result_text = (
|
||||
str(block.content) if block.content else ""
|
||||
)
|
||||
content_parts.append(
|
||||
|
||||
if tool_image_urls:
|
||||
openai_messages.append(
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"Tool result: {tool_result_text}",
|
||||
"role": "user",
|
||||
"content": [ # type: ignore[dict-item]
|
||||
{"type": "image_url", "image_url": {"url": img}}
|
||||
for img in tool_image_urls
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
# Add tool calls to the message if any
|
||||
if tool_calls:
|
||||
openai_msg["tool_calls"] = tool_calls # type: ignore
|
||||
@classmethod
|
||||
def _build_base_request(
|
||||
cls,
|
||||
anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest,
|
||||
openai_messages: list[dict[str, Any]],
|
||||
) -> ChatCompletionRequest:
|
||||
"""Build base ChatCompletionRequest"""
|
||||
if isinstance(anthropic_request, AnthropicCountTokensRequest):
|
||||
return ChatCompletionRequest(
|
||||
model=anthropic_request.model,
|
||||
messages=openai_messages,
|
||||
)
|
||||
|
||||
# Add content parts if any
|
||||
if content_parts:
|
||||
if len(content_parts) == 1 and content_parts[0]["type"] == "text":
|
||||
openai_msg["content"] = content_parts[0]["text"]
|
||||
else:
|
||||
openai_msg["content"] = content_parts # type: ignore
|
||||
elif not tool_calls:
|
||||
continue
|
||||
|
||||
openai_messages.append(openai_msg)
|
||||
|
||||
req = ChatCompletionRequest(
|
||||
return ChatCompletionRequest(
|
||||
model=anthropic_request.model,
|
||||
messages=openai_messages,
|
||||
max_tokens=anthropic_request.max_tokens,
|
||||
@@ -183,19 +318,40 @@ class AnthropicServingMessages(OpenAIServingChat):
|
||||
top_k=anthropic_request.top_k,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _handle_streaming_options(
|
||||
cls,
|
||||
req: ChatCompletionRequest,
|
||||
anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest,
|
||||
) -> None:
|
||||
"""Handle streaming configuration"""
|
||||
if isinstance(anthropic_request, AnthropicCountTokensRequest):
|
||||
return
|
||||
if anthropic_request.stream:
|
||||
req.stream = anthropic_request.stream
|
||||
req.stream_options = StreamOptions.validate(
|
||||
req.stream_options = StreamOptions.model_validate(
|
||||
{"include_usage": True, "continuous_usage_stats": True}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _convert_tool_choice(
|
||||
cls,
|
||||
anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest,
|
||||
req: ChatCompletionRequest,
|
||||
) -> None:
|
||||
"""Convert Anthropic tool_choice to OpenAI format"""
|
||||
if anthropic_request.tool_choice is None:
|
||||
req.tool_choice = None
|
||||
elif anthropic_request.tool_choice.type == "auto":
|
||||
return
|
||||
|
||||
tool_choice_type = anthropic_request.tool_choice.type
|
||||
if tool_choice_type == "auto":
|
||||
req.tool_choice = "auto"
|
||||
elif anthropic_request.tool_choice.type == "any":
|
||||
elif tool_choice_type == "any":
|
||||
req.tool_choice = "required"
|
||||
elif anthropic_request.tool_choice.type == "tool":
|
||||
elif tool_choice_type == "none":
|
||||
req.tool_choice = "none"
|
||||
elif tool_choice_type == "tool":
|
||||
req.tool_choice = ChatCompletionNamedToolChoiceParam.model_validate(
|
||||
{
|
||||
"type": "function",
|
||||
@@ -203,9 +359,17 @@ class AnthropicServingMessages(OpenAIServingChat):
|
||||
}
|
||||
)
|
||||
|
||||
tools = []
|
||||
@classmethod
|
||||
def _convert_tools(
|
||||
cls,
|
||||
anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest,
|
||||
req: ChatCompletionRequest,
|
||||
) -> None:
|
||||
"""Convert Anthropic tools to OpenAI format"""
|
||||
if anthropic_request.tools is None:
|
||||
return req
|
||||
return
|
||||
|
||||
tools = []
|
||||
for tool in anthropic_request.tools:
|
||||
tools.append(
|
||||
ChatCompletionToolsParam.model_validate(
|
||||
@@ -219,10 +383,10 @@ class AnthropicServingMessages(OpenAIServingChat):
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
if req.tool_choice is None:
|
||||
req.tool_choice = "auto"
|
||||
req.tools = tools
|
||||
return req
|
||||
|
||||
async def create_messages(
|
||||
self,
|
||||
@@ -263,23 +427,32 @@ class AnthropicServingMessages(OpenAIServingChat):
|
||||
output_tokens=generator.usage.completion_tokens,
|
||||
),
|
||||
)
|
||||
if generator.choices[0].finish_reason == "stop":
|
||||
choice = generator.choices[0]
|
||||
if choice.finish_reason == "stop":
|
||||
result.stop_reason = "end_turn"
|
||||
elif generator.choices[0].finish_reason == "length":
|
||||
elif choice.finish_reason == "length":
|
||||
result.stop_reason = "max_tokens"
|
||||
elif generator.choices[0].finish_reason == "tool_calls":
|
||||
elif choice.finish_reason == "tool_calls":
|
||||
result.stop_reason = "tool_use"
|
||||
|
||||
content: list[AnthropicContentBlock] = [
|
||||
content: list[AnthropicContentBlock] = []
|
||||
if choice.message.reasoning:
|
||||
content.append(
|
||||
AnthropicContentBlock(
|
||||
type="thinking",
|
||||
thinking=choice.message.reasoning,
|
||||
signature=uuid.uuid4().hex,
|
||||
)
|
||||
)
|
||||
if choice.message.content:
|
||||
content.append(
|
||||
AnthropicContentBlock(
|
||||
type="text",
|
||||
text=generator.choices[0].message.content
|
||||
if generator.choices[0].message.content
|
||||
else "",
|
||||
text=choice.message.content,
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
for tool_call in generator.choices[0].message.tool_calls:
|
||||
for tool_call in choice.message.tool_calls:
|
||||
anthropic_tool_call = AnthropicContentBlock(
|
||||
type="tool_use",
|
||||
id=tool_call.id,
|
||||
@@ -297,10 +470,85 @@ class AnthropicServingMessages(OpenAIServingChat):
|
||||
generator: AsyncGenerator[str, None],
|
||||
) -> AsyncGenerator[str, None]:
|
||||
try:
|
||||
|
||||
class _ActiveBlockState:
|
||||
def __init__(self) -> None:
|
||||
self.content_block_index = 0
|
||||
self.block_type: str | None = None
|
||||
self.block_index: int | None = None
|
||||
self.block_signature: str | None = None
|
||||
self.signature_emitted: bool = False
|
||||
self.tool_use_id: str | None = None
|
||||
|
||||
def reset(self) -> None:
|
||||
self.block_type = None
|
||||
self.block_index = None
|
||||
self.block_signature = None
|
||||
self.signature_emitted = False
|
||||
self.tool_use_id = None
|
||||
|
||||
def start(self, block: AnthropicContentBlock) -> None:
|
||||
self.block_type = block.type
|
||||
self.block_index = self.content_block_index
|
||||
if block.type == "thinking":
|
||||
self.block_signature = uuid.uuid4().hex
|
||||
self.signature_emitted = False
|
||||
self.tool_use_id = None
|
||||
elif block.type == "tool_use":
|
||||
self.block_signature = None
|
||||
self.signature_emitted = True
|
||||
self.tool_use_id = block.id
|
||||
else:
|
||||
self.block_signature = None
|
||||
self.signature_emitted = True
|
||||
self.tool_use_id = None
|
||||
|
||||
first_item = True
|
||||
finish_reason = None
|
||||
content_block_index = 0
|
||||
content_block_started = False
|
||||
state = _ActiveBlockState()
|
||||
# Map from tool call index to tool_use_id
|
||||
tool_index_to_id: dict[int, str] = {}
|
||||
|
||||
def stop_active_block():
|
||||
events: list[str] = []
|
||||
if state.block_type is None:
|
||||
return events
|
||||
if (
|
||||
state.block_type == "thinking"
|
||||
and state.block_signature is not None
|
||||
and not state.signature_emitted
|
||||
):
|
||||
chunk = AnthropicStreamEvent(
|
||||
index=state.block_index,
|
||||
type="content_block_delta",
|
||||
delta=AnthropicDelta(
|
||||
type="signature_delta",
|
||||
signature=state.block_signature,
|
||||
),
|
||||
)
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
events.append(wrap_data_with_event(data, "content_block_delta"))
|
||||
state.signature_emitted = True
|
||||
stop_chunk = AnthropicStreamEvent(
|
||||
index=state.block_index,
|
||||
type="content_block_stop",
|
||||
)
|
||||
data = stop_chunk.model_dump_json(exclude_unset=True)
|
||||
events.append(wrap_data_with_event(data, "content_block_stop"))
|
||||
state.reset()
|
||||
state.content_block_index += 1
|
||||
return events
|
||||
|
||||
def start_block(block: AnthropicContentBlock):
|
||||
chunk = AnthropicStreamEvent(
|
||||
index=state.content_block_index,
|
||||
type="content_block_start",
|
||||
content_block=block,
|
||||
)
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
event = wrap_data_with_event(data, "content_block_start")
|
||||
state.start(block)
|
||||
return event
|
||||
|
||||
async for item in generator:
|
||||
if item.startswith("data:"):
|
||||
@@ -326,6 +574,8 @@ class AnthropicServingMessages(OpenAIServingChat):
|
||||
id=origin_chunk.id,
|
||||
content=[],
|
||||
model=origin_chunk.model,
|
||||
stop_reason=None,
|
||||
stop_sequence=None,
|
||||
usage=AnthropicUsage(
|
||||
input_tokens=origin_chunk.usage.prompt_tokens
|
||||
if origin_chunk.usage
|
||||
@@ -341,13 +591,8 @@ class AnthropicServingMessages(OpenAIServingChat):
|
||||
|
||||
# last chunk including usage info
|
||||
if len(origin_chunk.choices) == 0:
|
||||
if content_block_started:
|
||||
stop_chunk = AnthropicStreamEvent(
|
||||
index=content_block_index,
|
||||
type="content_block_stop",
|
||||
)
|
||||
data = stop_chunk.model_dump_json(exclude_unset=True)
|
||||
yield wrap_data_with_event(data, "content_block_stop")
|
||||
for event in stop_active_block():
|
||||
yield event
|
||||
stop_reason = self.stop_reason_map.get(
|
||||
finish_reason or "stop"
|
||||
)
|
||||
@@ -369,26 +614,55 @@ class AnthropicServingMessages(OpenAIServingChat):
|
||||
|
||||
if origin_chunk.choices[0].finish_reason is not None:
|
||||
finish_reason = origin_chunk.choices[0].finish_reason
|
||||
continue
|
||||
# continue
|
||||
|
||||
# content
|
||||
if origin_chunk.choices[0].delta.content is not None:
|
||||
if not content_block_started:
|
||||
# thinking / text content
|
||||
reasoning_delta = origin_chunk.choices[0].delta.reasoning
|
||||
if reasoning_delta is not None:
|
||||
if reasoning_delta == "":
|
||||
pass
|
||||
else:
|
||||
if state.block_type != "thinking":
|
||||
for event in stop_active_block():
|
||||
yield event
|
||||
start_event = start_block(
|
||||
AnthropicContentBlock(
|
||||
type="thinking", thinking=""
|
||||
)
|
||||
)
|
||||
yield start_event
|
||||
chunk = AnthropicStreamEvent(
|
||||
index=content_block_index,
|
||||
type="content_block_start",
|
||||
content_block=AnthropicContentBlock(
|
||||
type="text", text=""
|
||||
index=(
|
||||
state.block_index
|
||||
if state.block_index is not None
|
||||
else state.content_block_index
|
||||
),
|
||||
type="content_block_delta",
|
||||
delta=AnthropicDelta(
|
||||
type="thinking_delta",
|
||||
thinking=reasoning_delta,
|
||||
),
|
||||
)
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
yield wrap_data_with_event(data, "content_block_start")
|
||||
content_block_started = True
|
||||
yield wrap_data_with_event(data, "content_block_delta")
|
||||
|
||||
if origin_chunk.choices[0].delta.content is not None:
|
||||
if origin_chunk.choices[0].delta.content == "":
|
||||
continue
|
||||
pass
|
||||
else:
|
||||
if state.block_type != "text":
|
||||
for event in stop_active_block():
|
||||
yield event
|
||||
start_event = start_block(
|
||||
AnthropicContentBlock(type="text", text="")
|
||||
)
|
||||
yield start_event
|
||||
chunk = AnthropicStreamEvent(
|
||||
index=content_block_index,
|
||||
index=(
|
||||
state.block_index
|
||||
if state.block_index is not None
|
||||
else state.content_block_index
|
||||
),
|
||||
type="content_block_delta",
|
||||
delta=AnthropicDelta(
|
||||
type="text_delta",
|
||||
@@ -397,44 +671,47 @@ class AnthropicServingMessages(OpenAIServingChat):
|
||||
)
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
yield wrap_data_with_event(data, "content_block_delta")
|
||||
continue
|
||||
|
||||
# tool calls
|
||||
elif len(origin_chunk.choices[0].delta.tool_calls) > 0:
|
||||
tool_call = origin_chunk.choices[0].delta.tool_calls[0]
|
||||
# tool calls - process all tool calls in the delta
|
||||
if len(origin_chunk.choices[0].delta.tool_calls) > 0:
|
||||
for tool_call in origin_chunk.choices[0].delta.tool_calls:
|
||||
if tool_call.id is not None:
|
||||
if content_block_started:
|
||||
stop_chunk = AnthropicStreamEvent(
|
||||
index=content_block_index,
|
||||
type="content_block_stop",
|
||||
# Update mapping for incremental updates
|
||||
tool_index_to_id[tool_call.index] = tool_call.id
|
||||
# Only create new block if different tool call
|
||||
# AND has a name
|
||||
tool_name = (
|
||||
tool_call.function.name
|
||||
if tool_call.function
|
||||
else None
|
||||
)
|
||||
data = stop_chunk.model_dump_json(
|
||||
exclude_unset=True
|
||||
)
|
||||
yield wrap_data_with_event(
|
||||
data, "content_block_stop"
|
||||
)
|
||||
content_block_started = False
|
||||
content_block_index += 1
|
||||
|
||||
chunk = AnthropicStreamEvent(
|
||||
index=content_block_index,
|
||||
type="content_block_start",
|
||||
content_block=AnthropicContentBlock(
|
||||
if (
|
||||
state.tool_use_id != tool_call.id
|
||||
and tool_name is not None
|
||||
):
|
||||
for event in stop_active_block():
|
||||
yield event
|
||||
start_event = start_block(
|
||||
AnthropicContentBlock(
|
||||
type="tool_use",
|
||||
id=tool_call.id,
|
||||
name=tool_call.function.name
|
||||
if tool_call.function
|
||||
else None,
|
||||
name=tool_name,
|
||||
input={},
|
||||
),
|
||||
)
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
yield wrap_data_with_event(data, "content_block_start")
|
||||
content_block_started = True
|
||||
if tool_call.function and tool_call.function.arguments:
|
||||
)
|
||||
yield start_event
|
||||
# Handle initial arguments if present
|
||||
if (
|
||||
tool_call.function
|
||||
and tool_call.function.arguments
|
||||
and state.tool_use_id == tool_call.id
|
||||
):
|
||||
chunk = AnthropicStreamEvent(
|
||||
index=content_block_index,
|
||||
index=(
|
||||
state.block_index
|
||||
if state.block_index is not None
|
||||
else state.content_block_index
|
||||
),
|
||||
type="content_block_delta",
|
||||
delta=AnthropicDelta(
|
||||
type="input_json_delta",
|
||||
@@ -445,20 +722,31 @@ class AnthropicServingMessages(OpenAIServingChat):
|
||||
yield wrap_data_with_event(
|
||||
data, "content_block_delta"
|
||||
)
|
||||
|
||||
else:
|
||||
# Incremental update - use index to find tool_use_id
|
||||
tool_use_id = tool_index_to_id.get(tool_call.index)
|
||||
if (
|
||||
tool_use_id is not None
|
||||
and tool_call.function
|
||||
and tool_call.function.arguments
|
||||
and state.tool_use_id == tool_use_id
|
||||
):
|
||||
chunk = AnthropicStreamEvent(
|
||||
index=content_block_index,
|
||||
index=(
|
||||
state.block_index
|
||||
if state.block_index is not None
|
||||
else state.content_block_index
|
||||
),
|
||||
type="content_block_delta",
|
||||
delta=AnthropicDelta(
|
||||
type="input_json_delta",
|
||||
partial_json=tool_call.function.arguments
|
||||
if tool_call.function
|
||||
else None,
|
||||
partial_json=tool_call.function.arguments,
|
||||
),
|
||||
)
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
yield wrap_data_with_event(data, "content_block_delta")
|
||||
yield wrap_data_with_event(
|
||||
data, "content_block_delta"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
error_response = AnthropicStreamEvent(
|
||||
@@ -481,3 +769,31 @@ class AnthropicServingMessages(OpenAIServingChat):
|
||||
data = error_response.model_dump_json(exclude_unset=True)
|
||||
yield wrap_data_with_event(data, "error")
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
async def count_tokens(
|
||||
self,
|
||||
request: AnthropicCountTokensRequest,
|
||||
raw_request: Request | None = None,
|
||||
) -> AnthropicCountTokensResponse | ErrorResponse:
|
||||
"""Implements Anthropic's messages.count_tokens endpoint."""
|
||||
chat_req = self._convert_anthropic_to_openai_request(request)
|
||||
result = await self.render_chat_request(chat_req)
|
||||
if isinstance(result, ErrorResponse):
|
||||
return result
|
||||
|
||||
_, engine_prompts = result
|
||||
|
||||
input_tokens = sum( # type: ignore
|
||||
len(prompt["prompt_token_ids"]) # type: ignore[typeddict-item, misc]
|
||||
for prompt in engine_prompts
|
||||
if "prompt_token_ids" in prompt
|
||||
)
|
||||
|
||||
response = AnthropicCountTokensResponse(
|
||||
input_tokens=input_tokens,
|
||||
context_management=AnthropicContextManagement(
|
||||
original_input_tokens=input_tokens
|
||||
),
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@@ -7,6 +7,7 @@ import warnings
from abc import ABC, abstractmethod
from collections import Counter, defaultdict
from collections.abc import Awaitable, Callable, Iterable
from dataclasses import dataclass
from functools import cached_property, lru_cache, partial
from itertools import accumulate
from pathlib import Path
@@ -1024,6 +1025,13 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
        self._add_placeholder("video", placeholder)


@dataclass
class ChatTemplateConfig:
    chat_template: str | None = None
    chat_template_content_format: ChatTemplateContentFormatOption = "auto"
    trust_request_chat_template: bool = False


def validate_chat_template(chat_template: Path | str | None):
    """Raises if the provided chat template appears invalid."""
    if chat_template is None:

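# Illustrative construction of the new ChatTemplateConfig dataclass (the template
# string is a placeholder; only the field names come from the diff above).
config = ChatTemplateConfig(
    chat_template="{% for message in messages %}{{ message.content }}{% endfor %}",
    chat_template_content_format="auto",
    trust_request_chat_template=False,
)
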
@@ -1,12 +1,19 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.entrypoints.cli.benchmark.latency import BenchmarkLatencySubcommand
from vllm.entrypoints.cli.benchmark.mm_processor import (
    BenchmarkMMProcessorSubcommand,
)
from vllm.entrypoints.cli.benchmark.serve import BenchmarkServingSubcommand
from vllm.entrypoints.cli.benchmark.startup import BenchmarkStartupSubcommand
from vllm.entrypoints.cli.benchmark.sweep import BenchmarkSweepSubcommand
from vllm.entrypoints.cli.benchmark.throughput import BenchmarkThroughputSubcommand

# Keep this package init import-free.
#
# The `vllm` console script imports `vllm.entrypoints.cli.main`, which causes
# Python to import this package before loading the `main` submodule.
# Eagerly importing benchmark subcommands here makes every `vllm serve ...`
# startup depend on optional benchmark-only modules.
#
# Benchmark subcommands are loaded on demand in
# `vllm.entrypoints.cli.benchmark.main`.
__all__: list[str] = [
    "BenchmarkLatencySubcommand",
    "BenchmarkMMProcessorSubcommand",
    "BenchmarkServingSubcommand",
    "BenchmarkStartupSubcommand",
    "BenchmarkSweepSubcommand",
    "BenchmarkThroughputSubcommand",
]

@@ -2,8 +2,6 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import argparse
|
||||
import importlib
|
||||
import logging
|
||||
import typing
|
||||
|
||||
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
|
||||
@@ -15,30 +13,6 @@ if typing.TYPE_CHECKING:
|
||||
else:
|
||||
FlexibleArgumentParser = argparse.ArgumentParser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _load_benchmark_subcommands() -> None:
|
||||
modules = [
|
||||
"vllm.entrypoints.cli.benchmark.latency",
|
||||
"vllm.entrypoints.cli.benchmark.mm_processor",
|
||||
"vllm.entrypoints.cli.benchmark.serve",
|
||||
"vllm.entrypoints.cli.benchmark.startup",
|
||||
"vllm.entrypoints.cli.benchmark.sweep",
|
||||
"vllm.entrypoints.cli.benchmark.throughput",
|
||||
]
|
||||
|
||||
for module_name in modules:
|
||||
try:
|
||||
importlib.import_module(module_name)
|
||||
except ModuleNotFoundError as e:
|
||||
logger.warning(
|
||||
"Skipping benchmark subcommand module %s because an optional "
|
||||
"dependency could not be imported: %r",
|
||||
module_name,
|
||||
e,
|
||||
)
|
||||
|
||||
|
||||
class BenchmarkSubcommand(CLISubcommand):
|
||||
"""The `bench` subcommand for the vLLM CLI."""
|
||||
@@ -64,8 +38,6 @@ class BenchmarkSubcommand(CLISubcommand):
|
||||
)
|
||||
bench_subparsers = bench_parser.add_subparsers(required=True, dest="bench_type")
|
||||
|
||||
_load_benchmark_subcommands()
|
||||
|
||||
for cmd_cls in BenchmarkSubcommandBase.__subclasses__():
|
||||
cmd_subparser = bench_subparsers.add_parser(
|
||||
cmd_cls.name,
|
||||
|
||||
@@ -220,6 +220,12 @@ def run_multi_api_server(args: argparse.Namespace):
|
||||
num_api_servers: int = args.api_server_count
|
||||
assert num_api_servers > 0
|
||||
|
||||
if num_api_servers > 1 and getattr(args, "use_gpu_for_pooling_score", False):
|
||||
# TODO(wentao): remove this once well tested
|
||||
raise ValueError(
|
||||
"--use-gpu-for-pooling-score cannot be used with api_server_count > 1 now"
|
||||
)
|
||||
|
||||
if num_api_servers > 1:
|
||||
setup_multiprocess_prometheus()
|
||||
|
||||
@@ -246,8 +252,12 @@ def run_multi_api_server(args: argparse.Namespace):
|
||||
|
||||
api_server_manager: APIServerProcessManager | None = None
|
||||
|
||||
from vllm.v1.engine.utils import get_engine_zmq_addresses
|
||||
|
||||
addresses = get_engine_zmq_addresses(vllm_config, num_api_servers)
|
||||
|
||||
with launch_core_engines(
|
||||
vllm_config, executor_class, log_stats, num_api_servers
|
||||
vllm_config, executor_class, log_stats, addresses, num_api_servers
|
||||
) as (local_engine_manager, coordinator, addresses):
|
||||
# Construct common args for the APIServerProcessManager up-front.
|
||||
api_server_manager_kwargs = dict(
|
||||
|
||||
@@ -101,11 +101,15 @@ class VllmEngineServicer(vllm_engine_pb2_grpc.VllmEngineServicer):
|
||||
sampling_params = self._sampling_params_from_proto(
|
||||
request.sampling_params, stream=request.stream
|
||||
)
|
||||
tokenization_kwargs = self._tokenization_kwargs_from_proto(
|
||||
request.sampling_params
|
||||
)
|
||||
|
||||
async for output in self.async_llm.generate(
|
||||
prompt=prompt,
|
||||
sampling_params=sampling_params,
|
||||
request_id=request_id,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
):
|
||||
# Convert vLLM output to protobuf
|
||||
# For streaming, always send chunks
|
||||
@@ -308,9 +312,6 @@ class VllmEngineServicer(vllm_engine_pb2_grpc.VllmEngineServicer):
|
||||
seed=params.seed if params.HasField("seed") else None,
|
||||
include_stop_str_in_output=params.include_stop_str_in_output,
|
||||
logit_bias=dict(params.logit_bias) if params.logit_bias else None,
|
||||
truncate_prompt_tokens=params.truncate_prompt_tokens
|
||||
if params.HasField("truncate_prompt_tokens")
|
||||
else None,
|
||||
structured_outputs=structured_outputs,
|
||||
# detokenize must be True if stop strings are used
|
||||
detokenize=bool(stop),
|
||||
@@ -319,6 +320,14 @@ class VllmEngineServicer(vllm_engine_pb2_grpc.VllmEngineServicer):
|
||||
else RequestOutputKind.FINAL_ONLY,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _tokenization_kwargs_from_proto(
|
||||
params: vllm_engine_pb2.SamplingParams,
|
||||
) -> dict[str, int] | None:
|
||||
if params.HasField("truncate_prompt_tokens"):
|
||||
return {"truncate_prompt_tokens": params.truncate_prompt_tokens}
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _chunk_response(output: RequestOutput) -> vllm_engine_pb2.GenerateResponse:
|
||||
"""
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import itertools
|
||||
import warnings
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import cloudpickle
|
||||
@@ -41,8 +41,11 @@ from vllm.distributed.weight_transfer.base import (
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
ChatCompletionMessageParam,
|
||||
ChatTemplateConfig,
|
||||
ChatTemplateContentFormatOption,
|
||||
load_chat_template,
|
||||
)
|
||||
from vllm.entrypoints.pooling.io_processor_factories import init_pooling_io_processors
|
||||
from vllm.entrypoints.pooling.score.utils import (
|
||||
ScoreData,
|
||||
ScoreMultiModalParam,
|
||||
@@ -146,6 +149,7 @@ class LLM:
|
||||
a tag name, or a commit id.
|
||||
tokenizer_revision: The specific tokenizer version to use. It can be a
|
||||
branch name, a tag name, or a commit id.
|
||||
chat_template: The chat template to apply.
|
||||
seed: The seed to initialize the random number generator for sampling.
|
||||
gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
|
||||
reserve for the model weights, activations, and KV cache. Higher
|
||||
@@ -233,6 +237,7 @@ class LLM:
|
||||
quantization: QuantizationMethods | None = None,
|
||||
revision: str | None = None,
|
||||
tokenizer_revision: str | None = None,
|
||||
chat_template: Path | str | None = None,
|
||||
seed: int = 0,
|
||||
gpu_memory_utilization: float = 0.9,
|
||||
swap_space: float = 4,
|
||||
@@ -385,9 +390,16 @@ class LLM:
|
||||
|
||||
self.model_config = self.llm_engine.model_config
|
||||
self.renderer = self.llm_engine.renderer
|
||||
self.chat_template = load_chat_template(chat_template)
|
||||
self.io_processor = self.llm_engine.io_processor
|
||||
self.input_processor = self.llm_engine.input_processor
|
||||
|
||||
self.chat_template_config = ChatTemplateConfig(chat_template=self.chat_template)
|
||||
self.init_pooling_io_processors = init_pooling_io_processors(
|
||||
supported_tasks=supported_tasks,
|
||||
model_config=self.model_config,
|
||||
renderer=self.renderer,
|
||||
chat_template_config=self.chat_template_config,
|
||||
)
|
||||
# Cache for __repr__ to avoid repeated collective_rpc calls
|
||||
self._cached_repr: str | None = None
|
||||
|
||||
@@ -1030,7 +1042,6 @@ class LLM:
prompts: PromptType | Sequence[PromptType] | DataPrompt,
pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
*,
truncate_prompt_tokens: int | None = None,
use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: list[LoRARequest] | LoRARequest | None = None,
pooling_task: PoolingTask | None = None,
@@ -1088,21 +1099,7 @@ class LLM:
"pooling model."
)

if truncate_prompt_tokens is not None:
warnings.warn(
"The `truncate_prompt_tokens` parameter in `LLM.encode()` "
"is deprecated and will be removed in v0.16. "
"Please pass it via `tokenization_kwargs` instead.",
DeprecationWarning,
stacklevel=2,
)

tokenization_kwargs = merge_kwargs(
tokenization_kwargs,
dict(truncate_prompt_tokens=truncate_prompt_tokens),
)

if use_io_processor := (isinstance(prompts, dict) and "data" in prompts):
if isinstance(prompts, dict) and "data" in prompts:
if self.io_processor is None:
raise ValueError(
"No IOProcessor plugin installed. Please refer "
@@ -1136,6 +1133,31 @@ class LLM:
for p in params_seq:
if p.task is None:
p.task = "plugin"

outputs = self._run_completion(
prompts=prompts_seq,
params=params_seq,
output_type=PoolingRequestOutput,
use_tqdm=use_tqdm,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)

# get the post-processed model outputs
assert self.io_processor is not None
processed_outputs = self.io_processor.post_process(outputs)

return [
PoolingRequestOutput[Any](
request_id="",
outputs=processed_outputs,
num_cached_tokens=getattr(
processed_outputs, "num_cached_tokens", 0
),
prompt_token_ids=[],
finished=True,
)
]
else:
if pooling_params is None:
# Use default pooling params.
@@ -1153,6 +1175,28 @@ class LLM:
)
raise ValueError(msg)

if pooling_task in self.init_pooling_io_processors:
io_processor = self.init_pooling_io_processors[pooling_task]
processor_inputs = io_processor.pre_process_offline(
prompts_seq, tokenization_kwargs
)
seq_lora_requests = self._lora_request_to_seq(
lora_request, len(prompts_seq)
)
seq_priority = self._priority_to_seq(None, len(prompts))

self._render_and_add_requests(
prompts=processor_inputs,
params=params_seq,
lora_requests=seq_lora_requests,
priorities=seq_priority,
)

outputs = self._run_engine(
use_tqdm=use_tqdm, output_type=PoolingRequestOutput
)
outputs = io_processor.post_process(outputs)
else:
outputs = self._run_completion(
prompts=prompts_seq,
params=params_seq,
@@ -1161,31 +1205,12 @@ class LLM:
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)

if use_io_processor:
# get the post-processed model outputs
assert self.io_processor is not None
processed_outputs = self.io_processor.post_process(outputs)

return [
PoolingRequestOutput[Any](
request_id="",
outputs=processed_outputs,
num_cached_tokens=getattr(
processed_outputs, "num_cached_tokens", 0
),
prompt_token_ids=[],
finished=True,
)
]

return outputs

def embed(
self,
prompts: PromptType | Sequence[PromptType],
*,
truncate_prompt_tokens: int | None = None,
use_tqdm: bool | Callable[..., tqdm] = True,
pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
lora_request: list[LoRARequest] | LoRARequest | None = None,
@@ -1221,12 +1246,6 @@ class LLM:
"Try converting the model using `--convert embed`."
)

if truncate_prompt_tokens is not None:
tokenization_kwargs = merge_kwargs(
tokenization_kwargs,
dict(truncate_prompt_tokens=truncate_prompt_tokens),
)

items = self.encode(
prompts,
use_tqdm=use_tqdm,
@@ -1294,7 +1313,6 @@ class LLM:
/,
*,
pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
truncate_prompt_tokens: int | None = None,
use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: list[LoRARequest] | LoRARequest | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
@@ -1319,13 +1337,11 @@ class LLM:
A list of `PoolingRequestOutput` objects containing the
pooled hidden states in the same order as the input prompts.
"""

return self.encode(
prompts,
use_tqdm=use_tqdm,
lora_request=lora_request,
pooling_params=pooling_params,
truncate_prompt_tokens=truncate_prompt_tokens,
pooling_task="token_classify",
tokenization_kwargs=tokenization_kwargs,
)
@@ -1771,23 +1787,15 @@ class LLM:
seq_prompts = prompt_to_seq(prompts)
seq_params = self._params_to_seq(params, len(seq_prompts))
seq_lora_requests = self._lora_request_to_seq(lora_request, len(seq_prompts))
seq_tok_kwargs = [
merge_kwargs(
tokenization_kwargs,
dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
)
for param in seq_params
]
seq_priority = self._priority_to_seq(priority, len(prompts))

return self._render_and_add_requests(
prompts=(
self._preprocess_cmpl_one(prompt, tok_kwargs)
for prompt, tok_kwargs in zip(
maybe_tqdm(
seq_prompts, use_tqdm=use_tqdm, desc="Rendering prompts"
),
seq_tok_kwargs,
self._preprocess_cmpl_one(prompt, tokenization_kwargs)
for prompt in maybe_tqdm(
seq_prompts,
use_tqdm=use_tqdm,
desc="Rendering prompts",
)
),
params=seq_params,
@@ -1841,13 +1849,6 @@ class LLM:
seq_convs = conversation_to_seq(messages)
seq_params = self._params_to_seq(params, len(seq_convs))
seq_lora_requests = self._lora_request_to_seq(lora_request, len(seq_convs))
seq_tok_kwargs = [
merge_kwargs(
tokenization_kwargs,
dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
)
for param in seq_params
]

return self._render_and_run_requests(
prompts=(
@@ -1859,16 +1860,13 @@ class LLM:
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
tools=tools,
tokenization_kwargs=tok_kwargs,
tokenization_kwargs=tokenization_kwargs,
mm_processor_kwargs=mm_processor_kwargs,
)
for conversation, tok_kwargs in zip(
maybe_tqdm(
for conversation in maybe_tqdm(
seq_convs,
use_tqdm=use_tqdm,
desc="Rendering conversations",
),
seq_tok_kwargs,
)
),
params=seq_params,

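One practical consequence of the `LLM` hunks above is that per-call truncation now travels through `tokenization_kwargs` instead of the removed `truncate_prompt_tokens` keyword. A hedged usage sketch follows; the model path is a placeholder and a pooling-capable model with default settings is assumed, this is not taken from the diff itself:

```python
# Illustrative only: truncation passed via tokenization_kwargs after this change.
from vllm import LLM

llm = LLM(model="/model", enforce_eager=True, trust_remote_code=True)

# Previously: llm.encode(prompts, truncate_prompt_tokens=128)  (deprecated keyword)
# Now the limit is carried inside tokenization_kwargs:
outputs = llm.encode(
    ["an example prompt"],
    tokenization_kwargs={"truncate_prompt_tokens": 128},
)
print(len(outputs))
```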
@@ -18,6 +18,20 @@ class RequestLogger:
def __init__(self, *, max_log_len: int | None) -> None:
self.max_log_len = max_log_len

if not logger.isEnabledFor(logging.INFO):
logger.warning_once(
"`--enable-log-requests` is set but "
"the minimum log level is higher than INFO. "
"No request information will be logged."
)
elif not logger.isEnabledFor(logging.DEBUG):
logger.info_once(
"`--enable-log-requests` is set but "
"the minimum log level is higher than DEBUG. "
"Only limited information will be logged to minimize overhead. "
"To view more details, set `VLLM_LOGGING_LEVEL=DEBUG`."
)

def log_inputs(
self,
request_id: str,

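A tiny standalone sketch (plain stdlib `logging`, not vLLM's logger wrapper) of the `isEnabledFor` gate the added `RequestLogger.__init__` code relies on:

```python
import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("request-logger-demo")

# With the effective level at WARNING, INFO and DEBUG records are dropped,
# which is exactly the situation the new warning points out.
print(logger.isEnabledFor(logging.INFO))   # False
print(logger.isEnabledFor(logging.DEBUG))  # False
```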
@@ -38,6 +38,7 @@ from vllm.logprobs import Logprob
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
from vllm.sampling_params import (
BeamSearchParams,
RepetitionDetectionParams,
RequestOutputKind,
SamplingParams,
StructuredOutputsParams,
@@ -336,6 +337,16 @@ class ChatCompletionRequest(OpenAIBaseModel):
),
)

repetition_detection: RepetitionDetectionParams | None = Field(
default=None,
description="Parameters for detecting repetitive N-gram patterns "
"in output tokens. If such repetition is detected, generation will "
"be ended early. LLMs can sometimes generate repetitive, unhelpful "
"token patterns, stopping only when they hit the maximum output length "
"(e.g. 'abcdabcdabcd...' or '\emoji \emoji \emoji ...'). This feature "
"can detect such behavior and terminate early, saving time and tokens.",
)

# --8<-- [end:chat-completion-extra-params]

def build_chat_params(
@@ -490,7 +501,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA
if self.stream
else RequestOutputKind.FINAL_ONLY,
@@ -500,8 +510,37 @@ class ChatCompletionRequest(OpenAIBaseModel):
allowed_token_ids=self.allowed_token_ids,
extra_args=extra_args or None,
skip_clone=True, # Created fresh per request, safe to skip clone
repetition_detection=self.repetition_detection,
)

@model_validator(mode="before")
@classmethod
def validate_response_format(cls, data):
response_format = data.get("response_format")
if response_format is None:
return data

rf_type = (
response_format.get("type")
if isinstance(response_format, dict)
else getattr(response_format, "type", None)
)

if rf_type == "json_schema":
json_schema = (
response_format.get("json_schema")
if isinstance(response_format, dict)
else getattr(response_format, "json_schema", None)
)
if json_schema is None:
raise VLLMValidationError(
"When response_format type is 'json_schema', the "
"'json_schema' field must be provided.",
parameter="response_format",
)

return data

@model_validator(mode="before")
@classmethod
def validate_stream_options(cls, data):

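For context, a request body shaped like the following hypothetical payload would now be rejected by the `validate_response_format` validator added above, because `type` is `json_schema` but no `json_schema` object is supplied:

```python
# Hypothetical chat-completions payload; the validator above raises a
# VLLMValidationError on the response_format parameter for this shape.
payload = {
    "model": "llm",
    "messages": [{"role": "user", "content": "Give me three colors as JSON."}],
    "response_format": {"type": "json_schema"},  # missing "json_schema" field
}
```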
@@ -1249,13 +1249,23 @@ class OpenAIServingChat(OpenAIServing):
)

# get the expected call based on partial JSON
# parsing which "autocompletes" the JSON
expected_call = json.dumps(
tool_parser.prev_tool_call_arr[index].get(
# parsing which "autocompletes" the JSON.
# Tool parsers (e.g. Qwen3Coder) store
# arguments as a JSON string in
# prev_tool_call_arr. Calling json.dumps()
# on an already-serialized string would
# double-serialize it (e.g. '{"k":1}' becomes
# '"{\\"k\\":1}"'), which then causes the
# replace() below to fail and append the
# entire double-serialized string as a
# spurious final delta.
args = tool_parser.prev_tool_call_arr[index].get(
"arguments", {}
),
ensure_ascii=False,
)
if isinstance(args, str):
expected_call = args
else:
expected_call = json.dumps(args, ensure_ascii=False)

# get what we've streamed so far for arguments
# for the current tool

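A minimal, self-contained illustration (not taken from the diff) of the double-serialization problem the new comment describes; `prev_args` stands in for `tool_parser.prev_tool_call_arr[index]["arguments"]`:

```python
import json

# Some tool parsers store the arguments as an already-serialized JSON string.
prev_args = '{"k": 1}'

# Re-serializing the string escapes it a second time ...
print(json.dumps(prev_args, ensure_ascii=False))  # "{\"k\": 1}"  (no longer the arguments themselves)

# ... while checking the type first, as the new code does, keeps it intact.
expected_call = prev_args if isinstance(prev_args, str) else json.dumps(prev_args, ensure_ascii=False)
print(expected_call)  # {"k": 1}
```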
@@ -143,7 +143,8 @@ class BaseFrontendArgs:
templates and other tokenizer configuration."""
enable_log_outputs: bool = False
"""If set to True, log model outputs (generations).
Requires --enable-log-requests."""
Requires `--enable-log-requests`. As with `--enable-log-requests`,
information is only logged at INFO level at maximum."""
enable_log_deltas: bool = True
"""If set to False, output deltas will not be logged. Relevant only if
--enable-log-outputs is set.
@@ -277,6 +278,10 @@ class FrontendArgs(BaseFrontendArgs):
Enable offline FastAPI documentation for air-gapped environments.
Uses vendored static assets bundled with vLLM.
"""
use_gpu_for_pooling_score: bool = False
"""If set, run pooling score MaxSim on GPU in the API server process.
Can significantly improve late-interaction scoring performance.
https://github.com/vllm-project/vllm/pull/35330"""

@classmethod
def _customize_cli_kwargs(

@@ -26,6 +26,7 @@ from vllm.logprobs import Logprob
from vllm.renderers import TokenizeParams
from vllm.sampling_params import (
BeamSearchParams,
RepetitionDetectionParams,
RequestOutputKind,
SamplingParams,
StructuredOutputsParams,
@@ -166,6 +167,16 @@ class CompletionRequest(OpenAIBaseModel):
),
)

repetition_detection: RepetitionDetectionParams | None = Field(
default=None,
description="Parameters for detecting repetitive N-gram patterns "
"in output tokens. If such repetition is detected, generation will "
"be ended early. LLMs can sometimes generate repetitive, unhelpful "
"token patterns, stopping only when they hit the maximum output length "
"(e.g. 'abcdabcdabcd...' or '\emoji \emoji \emoji ...'). This feature "
"can detect such behavior and terminate early, saving time and tokens.",
)

# --8<-- [end:completion-extra-params]

def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
@@ -259,7 +270,7 @@ class CompletionRequest(OpenAIBaseModel):
structured_outputs_kwargs["json"] = json_schema.json_schema
elif response_format.type == "structural_tag":
structural_tag = response_format
assert structural_tag is not None and isinstance(
assert isinstance(
structural_tag,
(
LegacyStructuralTagResponseFormat,
@@ -302,7 +313,6 @@ class CompletionRequest(OpenAIBaseModel):
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA
if self.stream
else RequestOutputKind.FINAL_ONLY,
@@ -311,8 +321,37 @@ class CompletionRequest(OpenAIBaseModel):
allowed_token_ids=self.allowed_token_ids,
extra_args=extra_args or None,
skip_clone=True, # Created fresh per request, safe to skip clone
repetition_detection=self.repetition_detection,
)

@model_validator(mode="before")
@classmethod
def validate_response_format(cls, data):
response_format = data.get("response_format")
if response_format is None:
return data

rf_type = (
response_format.get("type")
if isinstance(response_format, dict)
else getattr(response_format, "type", None)
)

if rf_type == "json_schema":
json_schema = (
response_format.get("json_schema")
if isinstance(response_format, dict)
else getattr(response_format, "json_schema", None)
)
if json_schema is None:
raise VLLMValidationError(
"When response_format type is 'json_schema', the "
"'json_schema' field must be provided.",
parameter="response_format",
)

return data

@model_validator(mode="before")
@classmethod
def check_structured_outputs_count(cls, data):

@@ -62,11 +62,6 @@ from vllm.entrypoints.openai.speech_to_text.protocol import (
TranscriptionResponse,
TranslationRequest,
)
from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest,
ClassificationCompletionRequest,
ClassificationResponse,
)
from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingBytesResponse,
EmbeddingChatRequest,
@@ -161,7 +156,6 @@ CompletionLikeRequest: TypeAlias = (
| TokenizeCompletionRequest
| DetokenizeRequest
| EmbeddingCompletionRequest
| ClassificationCompletionRequest
| RerankRequest
| ScoreRequest
| PoolingCompletionRequest
@@ -171,7 +165,6 @@ ChatLikeRequest: TypeAlias = (
ChatCompletionRequest
| TokenizeChatRequest
| EmbeddingChatRequest
| ClassificationChatRequest
| PoolingChatRequest
)

@@ -194,12 +187,10 @@ AnyResponse: TypeAlias = (
| TranscriptionResponse
| TokenizeResponse
| PoolingResponse
| ClassificationResponse
| ScoreResponse
| GenerateResponse
)


RequestT = TypeVar("RequestT", bound=AnyRequest)


@@ -223,8 +214,8 @@ class ServeContext(Generic[RequestT]):

class OpenAIServing:
request_id_prefix: ClassVar[str] = """
A short string prepended to every request’s ID (e.g. "embd", "classify")
so you can easily tell “this ID came from Embedding vs Classification.”
A short string prepended to every request’s ID (e.g. "embd")
so you can easily tell “this ID came from Embedding.”
"""

def __init__(
@@ -456,7 +447,7 @@ class OpenAIServing:
) -> ErrorResponse | None:
"""
Default preprocessing hook. Subclasses may override
to prepare `ctx` (classification, embedding, etc.).
to prepare `ctx` (embedding, etc.).
"""
return None

@@ -817,7 +808,7 @@ class OpenAIServing:
token_num = len(input_ids)
max_model_len = self.model_config.max_model_len

# Note: EmbeddingRequest, ClassificationRequest,
# Note: EmbeddingRequest,
# and ScoreRequest doesn't have max_tokens
if isinstance(
request,
@@ -828,8 +819,6 @@ class OpenAIServing:
ScoreTextRequest,
ScoreQueriesDocumentsRequest,
RerankRequest,
ClassificationCompletionRequest,
ClassificationChatRequest,
),
):
# Note: input length can be up to the entire model context length
@@ -839,8 +828,6 @@ class OpenAIServing:
ScoreDataRequest: "score",
ScoreTextRequest: "score",
ScoreQueriesDocumentsRequest: "score",
ClassificationCompletionRequest: "classification",
ClassificationChatRequest: "classification",
}
operation = operations.get(type(request), "embedding generation")
raise VLLMValidationError(

@@ -328,8 +328,9 @@ class ResponsesRequest(OpenAIBaseModel):
# Also check text.format for OpenAI-style json_schema
if self.text is not None and self.text.format is not None:
if structured_outputs is not None:
raise ValueError(
"Cannot specify both structured_outputs and text.format"
raise VLLMValidationError(
"Cannot specify both structured_outputs and text.format",
parameter="structured_outputs",
)
response_format = self.text.format
if (
@@ -378,14 +379,19 @@ class ResponsesRequest(OpenAIBaseModel):
)

@model_validator(mode="before")
@classmethod
def validate_background(cls, data):
if not data.get("background"):
return data
if not data.get("store", True):
raise ValueError("background can only be used when `store` is true")
raise VLLMValidationError(
"background can only be used when `store` is true",
parameter="background",
)
return data

@model_validator(mode="before")
@classmethod
def validate_prompt(cls, data):
if data.get("prompt") is not None:
raise VLLMValidationError(
@@ -394,16 +400,19 @@ class ResponsesRequest(OpenAIBaseModel):
return data

@model_validator(mode="before")
@classmethod
def check_cache_salt_support(cls, data):
if data.get("cache_salt") is not None and (
not isinstance(data["cache_salt"], str) or not data["cache_salt"]
):
raise ValueError(
"Parameter 'cache_salt' must be a non-empty string if provided."
raise VLLMValidationError(
"Parameter 'cache_salt' must be a non-empty string if provided.",
parameter="cache_salt",
)
return data

@model_validator(mode="before")
@classmethod
def function_call_parsing(cls, data):
"""Parse function_call dictionaries into ResponseFunctionToolCall objects.
This ensures Pydantic can properly resolve union types in the input field.

@@ -85,6 +85,8 @@ from vllm.entrypoints.openai.responses.protocol import (
ResponseCreatedEvent,
ResponseInProgressEvent,
ResponseInputOutputMessage,
ResponseReasoningPartAddedEvent,
ResponseReasoningPartDoneEvent,
ResponsesRequest,
ResponsesResponse,
ResponseUsage,
@@ -1339,6 +1341,19 @@ class OpenAIServingResponses(OpenAIServing):
),
)
)
yield _increment_sequence_number_and_return(
ResponseReasoningPartAddedEvent(
type="response.reasoning_part.added",
sequence_number=-1,
output_index=current_output_index,
item_id=current_item_id,
content_index=current_content_index,
part=ResponseReasoningTextContent(
text="",
type="reasoning_text",
),
)
)
else:
yield _increment_sequence_number_and_return(
ResponseOutputItemAddedEvent(
@@ -1369,7 +1384,6 @@ class OpenAIServingResponses(OpenAIServing):
),
)
)
current_content_index += 1
first_delta_sent = True
# todo(kebe7jun) tool call support

@@ -1397,6 +1411,19 @@ class OpenAIServingResponses(OpenAIServing):
text=reason_content,
)
)
yield _increment_sequence_number_and_return(
ResponseReasoningPartDoneEvent(
type="response.reasoning_part.done",
sequence_number=-1,
item_id=current_item_id,
output_index=current_output_index,
content_index=current_content_index,
part=ResponseReasoningTextContent(
text=reason_content,
type="reasoning_text",
),
)
)
current_content_index = 0
reasoning_item = ResponseReasoningItem(
type="reasoning",
@@ -1418,6 +1445,8 @@ class OpenAIServingResponses(OpenAIServing):
item=reasoning_item,
)
)
current_output_index += 1
current_item_id = str(uuid.uuid4())
yield _increment_sequence_number_and_return(
ResponseOutputItemAddedEvent(
type="response.output_item.added",
@@ -1432,8 +1461,6 @@ class OpenAIServingResponses(OpenAIServing):
),
)
)
current_output_index += 1
current_item_id = str(uuid.uuid4())
yield _increment_sequence_number_and_return(
ResponseContentPartAddedEvent(
type="response.content_part.added",
@@ -1449,7 +1476,6 @@ class OpenAIServingResponses(OpenAIServing):
),
)
)
current_content_index += 1
# reset previous delta messages
previous_delta_messages = []

@@ -1485,7 +1511,6 @@ class OpenAIServingResponses(OpenAIServing):
),
)
)
current_content_index += 1

previous_delta_messages.append(delta_message)
if previous_delta_messages:
@@ -1505,7 +1530,19 @@ class OpenAIServingResponses(OpenAIServing):
text=reason_content,
)
)
current_content_index += 1
yield _increment_sequence_number_and_return(
ResponseReasoningPartDoneEvent(
type="response.reasoning_part.done",
sequence_number=-1,
item_id=current_item_id,
output_index=current_output_index,
content_index=current_content_index,
part=ResponseReasoningTextContent(
text=reason_content,
type="reasoning_text",
),
)
)
reasoning_item = ResponseReasoningItem(
type="reasoning",
content=[
@@ -1543,7 +1580,6 @@ class OpenAIServingResponses(OpenAIServing):
item_id=current_item_id,
)
)
current_content_index += 1
part = ResponseOutputText(
text=final_content,
type="output_text",
@@ -1559,7 +1595,6 @@ class OpenAIServingResponses(OpenAIServing):
part=part,
)
)
current_content_index += 1
item = ResponseOutputMessage(
type="message",
role="assistant",

@@ -11,6 +11,7 @@ from typing import Final, Literal, TypeAlias, TypeVar, cast

import numpy as np
from fastapi import Request
from soundfile import LibsndfileError
from transformers import PreTrainedTokenizerBase

import vllm.envs as envs
@@ -57,6 +58,14 @@ try:
except ImportError:
librosa = PlaceholderModule("librosa") # type: ignore[assignment]

# Public libsndfile error codes exposed via `soundfile.LibsndfileError.code`, soundfile
# being librosa's main backend. Used to validate if an audio loading error is due to a
# server error vs a client error (invalid audio file).
# 1 = unrecognised format (file is not a supported audio container)
# 3 = malformed file (corrupt or structurally invalid audio)
# 4 = unsupported encoding (codec not supported by this libsndfile build)
_BAD_SF_CODES = {1, 3, 4}

SpeechToTextResponse: TypeAlias = TranscriptionResponse | TranslationResponse
SpeechToTextResponseVerbose: TypeAlias = (
TranscriptionResponseVerbose | TranslationResponseVerbose
@@ -315,9 +324,15 @@ class OpenAISpeechToText(OpenAIServing):
)

with io.BytesIO(audio_data) as bytes_:
try:
# NOTE resample to model SR here for efficiency. This is also a
# pre-requisite for chunking, as it assumes Whisper SR.
y, sr = librosa.load(bytes_, sr=self.asr_config.sample_rate)
except LibsndfileError as exc:
# Distinguish client errors (invalid audio) from server errors
if exc.code in _BAD_SF_CODES:
raise ValueError("Invalid or unsupported audio file.") from exc
raise

duration = librosa.get_duration(y=y, sr=sr)
do_split_audio = (

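A hedged sketch of the client-versus-server error split this speech-to-text hunk introduces, using `soundfile` directly rather than vLLM's `librosa` loading path; it assumes `soundfile` is installed and that `LibsndfileError.code` carries the libsndfile error codes listed above:

```python
import io

import soundfile as sf
from soundfile import LibsndfileError

_BAD_SF_CODES = {1, 3, 4}  # unrecognised format, malformed file, unsupported encoding

def read_or_reject(audio_bytes: bytes):
    try:
        # soundfile accepts file-like objects, so raw upload bytes can be wrapped.
        return sf.read(io.BytesIO(audio_bytes))
    except LibsndfileError as exc:
        if exc.code in _BAD_SF_CODES:
            # Client-side problem: the upload is not a usable audio file.
            raise ValueError("Invalid or unsupported audio file.") from exc
        # Anything else is left to propagate as a server-side failure.
        raise
```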
@@ -1,12 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import warnings

warnings.warn(
"The 'vllm.entrypoints.openai.translations' module has been renamed to "
"'vllm.entrypoints.openai.speech_to_text'. Please update your imports. "
"This backward-compatible alias will be removed in version 0.17+.",
DeprecationWarning,
stacklevel=2,
)
Some files were not shown because too many files have changed in this diff.