Upgrade to vllm 0.17.0 corex v4.1 overlay

2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions

View File

@@ -1,11 +1,12 @@
FROM registry.iluvatar.com.cn:10443/customer/sz/vllm0.11.2-4.4.0-x86:v8
ARG BASE_IMAGE=registry.iluvatar.com.cn:10443/customer/sz/vllm0.17.0-4.4.0-x86:v4.1
FROM ${BASE_IMAGE}
# Keep the runtime stack from the known-good v8 image, but replace the
# installed Python package with the repository's patched 0.16.1rc0 sources.
WORKDIR /tmp
WORKDIR /home
RUN rm -rf /usr/local/lib/python3.12/dist-packages/vllm \
/usr/local/lib/python3.12/dist-packages/vllm-*.dist-info
COPY vllm /usr/local/lib/python3.12/dist-packages/vllm
COPY vllm-0.16.1rc0+corex.4.4.0.dist-info /usr/local/lib/python3.12/dist-packages/vllm-0.16.1rc0+corex.4.4.0.dist-info
COPY vllm-0.17.0+corex.20260420090923.dist-info /usr/local/lib/python3.12/dist-packages/vllm-0.17.0+corex.20260420090923.dist-info
ENTRYPOINT ["/bin/bash"]

View File

@@ -1,62 +1,37 @@
# bi_150-vllm
Build repository based on `registry.iluvatar.com.cn:10443/customer/sz/vllm0.11.2-4.4.0-x86:v8`
and `vLLM 0.16.1rc0`, used to produce a directly runnable image for the BI-V150 virtual machine environment.
This repository contains the extracted `vLLM 0.17.0+corex.20260420090923`
Python package used to overlay the vendor image
`registry.iluvatar.com.cn:10443/customer/sz/vllm0.17.0-4.4.0-x86:v4.1`.
## Change notes
This repository keeps only the minimal content needed to build the image:
## Included files
- `vllm/`
The current runtime code.
- `vllm-0.16.1rc0+corex.4.4.0.dist-info/`
The corresponding package metadata.
The Python package code copied from the image payload.
- `vllm-0.17.0+corex.20260420090923.dist-info/`
The package metadata extracted from the image.
- `Dockerfile`
Builds the final image.
Builds a new image by replacing the installed `vllm` package in the vendor base image.
Compared with the base image, the key code changes kept in this repository are as follows:
## Build
- Fixed the CUDA platform detection logic in `vllm/platforms/__init__.py`:
  - When NVML is unavailable and errors such as `NVML Shared Library Not Found`
    occur, the platform is no longer immediately treated as non-CUDA.
  - Instead, detection falls back to `torch.cuda.is_available()` and
    `torch.cuda.device_count()` to decide whether CUDA is available
    (see the sketch after the error message below).
- Adjusted the CLI initialization logic so that missing optional benchmark
  dependencies no longer block `vllm serve ...` startup.
This fix addresses the following startup failure:
```text
RuntimeError: Failed to infer device type
```
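The fallback can be pictured with a minimal sketch; the helper name below is illustrative and is not the exact code of the patched `vllm/platforms/__init__.py`:

```python
# Illustrative sketch of the NVML fallback, not the exact patched code.
import torch


def cuda_platform_detected() -> bool:
    try:
        import pynvml
        pynvml.nvmlInit()  # may fail with "NVML Shared Library Not Found" in the VM
        return True
    except Exception:
        # NVML is unavailable: fall back to torch's own CUDA probes.
        return torch.cuda.is_available() and torch.cuda.device_count() > 0
```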
## Build the image
Execute from the repository root:
Run the following command from the repository root:
```bash
docker build -t bi_150_vllm:0.16.1 .
docker build --pull=false \
--build-arg BASE_IMAGE=registry.iluvatar.com.cn:10443/customer/sz/vllm0.17.0-4.4.0-x86:v4.1 \
-t bi_150_vllm:0.17.0 \
.
```
## Start the image
## Verify
```bash
docker run -dit \
--name iluvatar_test \
-p 38047:8000 \
--privileged \
-v /lib/modules:/lib/modules \
-v /dev:/dev \
-v /usr/src:/usr/src \
-v /mnt/gpfs/leaderboard/modelHubXC/Amu/t1-1.5B:/model \
-e CUDA_VISIBLE_DEVICES=0 \
--entrypoint vllm \
bi_150_vllm:0.16.1 \
serve /model \
--port 8000 \
--served-model-name llm \
--max-model-len 2048 \
--enforce-eager \
--trust-remote-code \
-tp 1
docker run --rm -it bi_150_vllm:0.17.0 \
python3 -c "import vllm; print(vllm.__file__); print(vllm.__version__)"
```
## Notes
- This is an overlay-style repository, not the original upstream git source tree.
- The Docker image keeps the vendor runtime stack and only replaces the Python package files.
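For reference, a container start command adapted from the previous README's example; the model path and host port are placeholders, and only the image tag is taken from the build step above:

```bash
docker run -dit \
  --name iluvatar_test \
  -p 38047:8000 \
  --privileged \
  -v /lib/modules:/lib/modules \
  -v /dev:/dev \
  -v /usr/src:/usr/src \
  -v /path/to/model:/model \
  -e CUDA_VISIBLE_DEVICES=0 \
  --entrypoint vllm \
  bi_150_vllm:0.17.0 \
  serve /model \
  --port 8000 \
  --served-model-name llm \
  --max-model-len 2048 \
  --enforce-eager \
  --trust-remote-code \
  -tp 1
```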

View File

@@ -1 +0,0 @@
{"archive_info": {"hash": "sha256=f19c1c4880bc4199fc95de8d77590fcdb4912b1fad9bf939883005812f2338c1", "hashes": {"sha256": "f19c1c4880bc4199fc95de8d77590fcdb4912b1fad9bf939883005812f2338c1"}}, "url": "file:///workspace/vllm-0.16.1rc0%2Bcorex.4.4.0-py3-none-any.whl"}

View File

@@ -1,9 +1,9 @@
Metadata-Version: 2.4
Name: vllm
Version: 0.16.1rc0+corex.4.4.0
Version: 0.17.0+corex.20260420090923
Summary: A high-throughput and memory-efficient inference and serving engine for LLMs
Author: vLLM Team
License-Expression: Apache-2.0
License: Apache-2.0
Project-URL: Homepage, https://github.com/vllm-project/vllm
Project-URL: Documentation, https://docs.vllm.ai/en/latest/
Project-URL: Slack, https://slack.vllm.ai/
@@ -23,7 +23,7 @@ Requires-Dist: regex
Requires-Dist: cachetools
Requires-Dist: psutil
Requires-Dist: sentencepiece
Requires-Dist: numpy==1.26.4
Requires-Dist: numpy
Requires-Dist: requests>=2.26.0
Requires-Dist: tqdm
Requires-Dist: blake3
@@ -33,7 +33,7 @@ Requires-Dist: tokenizers>=0.21.1
Requires-Dist: protobuf!=6.30.*,!=6.31.*,!=6.32.*,!=6.33.0.*,!=6.33.1.*,!=6.33.2.*,!=6.33.3.*,!=6.33.4.*,>=5.29.6
Requires-Dist: fastapi[standard]>=0.115.0
Requires-Dist: aiohttp>=3.13.3
Requires-Dist: openai>=1.99.1
Requires-Dist: openai<2.25.0,>=1.99.1
Requires-Dist: pydantic>=2.12.0
Requires-Dist: prometheus_client>=0.18.0
Requires-Dist: pillow
@@ -52,6 +52,7 @@ Requires-Dist: pyzmq>=25.0.0
Requires-Dist: msgspec
Requires-Dist: gguf>=0.17.0
Requires-Dist: mistral_common[image]>=1.9.1
Requires-Dist: opencv-python-headless>=4.13.0
Requires-Dist: pyyaml
Requires-Dist: six>=1.16.0; python_version > "3.11"
Requires-Dist: setuptools<81.0.0,>=77.0.3; python_version > "3.11"
@@ -76,6 +77,7 @@ Requires-Dist: opentelemetry-sdk>=1.27.0
Requires-Dist: opentelemetry-api>=1.27.0
Requires-Dist: opentelemetry-exporter-otlp>=1.27.0
Requires-Dist: opentelemetry-semantic-conventions-ai>=0.4.1
Requires-Dist: kaldi-native-fbank>=1.18.7
Requires-Dist: numba==0.61.2
Requires-Dist: ray[cgraph]>=2.48.0
Provides-Extra: bench
@@ -84,6 +86,7 @@ Requires-Dist: matplotlib; extra == "bench"
Requires-Dist: seaborn; extra == "bench"
Requires-Dist: datasets; extra == "bench"
Requires-Dist: scipy; extra == "bench"
Requires-Dist: plotly; extra == "bench"
Provides-Extra: tensorizer
Requires-Dist: tensorizer==2.10.1; extra == "tensorizer"
Provides-Extra: fastsafetensors
@@ -97,6 +100,10 @@ Requires-Dist: soundfile; extra == "audio"
Requires-Dist: mistral_common[audio]; extra == "audio"
Provides-Extra: video
Provides-Extra: flashinfer
Provides-Extra: petit-kernel
Requires-Dist: petit-kernel; extra == "petit-kernel"
Provides-Extra: helion
Requires-Dist: helion; extra == "helion"
Provides-Extra: otel
Requires-Dist: opentelemetry-sdk>=1.26.0; extra == "otel"
Requires-Dist: opentelemetry-api>=1.26.0; extra == "otel"

View File

@@ -1,5 +1,5 @@
Wheel-Version: 1.0
Generator: setuptools (82.0.0)
Generator: setuptools (80.10.2)
Root-Is-Purelib: true
Tag: py3-none-any

View File

@@ -0,0 +1 @@
{"archive_info": {"hash": "sha256=844cb01bfec51cf2ec37322ff74c77b31cced2cd9253312cff724ebd06b7f740", "hashes": {"sha256": "844cb01bfec51cf2ec37322ff74c77b31cced2cd9253312cff724ebd06b7f740"}}, "url": "file:///home/poweruser/zrl/code/vllm_hub/vllm/build_pip/vllm-0.17.0%2Bcorex.20260420090923-py3-none-any.whl"}

244
vllm/.gitignore vendored
View File

@@ -1,244 +0,0 @@
# version file generated by setuptools-scm
/vllm/_version.py
# vllm-flash-attn built from source
vllm/vllm_flash_attn/*
!vllm/vllm_flash_attn/__init__.py
!vllm/vllm_flash_attn/flash_attn_interface.py
# OpenAI triton kernels copied from source
vllm/third_party/triton_kernels/*
# FlashMLA interface copied from source
vllm/third_party/flashmla/flash_mla_interface.py
# triton jit
.triton
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
cmake-build-*/
CMakeUserPresets.json
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
/.deps/
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# generated files
**/generated/**
# uv
uv.lock
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
docs/argparse
docs/examples/*
!docs/examples/README.md
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
# VSCode
.vscode/
# Claude
.claude/
# Codex
.codex/
# Cursor
.cursor/
# DS Store
.DS_Store
# Results
*.csv
# Python pickle files
*.pkl
# Sphinx documentation
_build/
# vim swap files
*.swo
*.swp
# hip files generated by PyTorch
*.hip
*_hip*
hip_compat.h
# Benchmark dataset
benchmarks/**/*.json
# Linting
actionlint
shellcheck*/
# Ignore moe/marlin_moe gen code
csrc/moe/marlin_moe_wna16/kernel_*
# Ignore ep_kernels_workspace folder
ep_kernels_workspace/
# Allow tracked library source folders under submodules (e.g., benchmarks/lib)
!vllm/benchmarks/lib/
# Generated gRPC protobuf files (compiled at build time from vllm_engine.proto)
vllm/grpc/vllm_engine_pb2.py
vllm/grpc/vllm_engine_pb2_grpc.py
vllm/grpc/vllm_engine_pb2.pyi
# Ignore generated cpu headers
csrc/cpu/cpu_attn_dispatch_generated.h

View File

@@ -14,8 +14,6 @@ import typing
import vllm.env_override # noqa: F401
MODULE_ATTRS = {
"bc_linter_skip": "._bc_linter:bc_linter_skip",
"bc_linter_include": "._bc_linter:bc_linter_include",
"AsyncEngineArgs": ".engine.arg_utils:AsyncEngineArgs",
"EngineArgs": ".engine.arg_utils:EngineArgs",
"AsyncLLMEngine": ".engine.async_llm_engine:AsyncLLMEngine",
@@ -62,8 +60,6 @@ if typing.TYPE_CHECKING:
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.v1.executor.ray_utils import initialize_ray_cluster
from ._bc_linter import bc_linter_include, bc_linter_skip
else:
def __getattr__(name: str) -> typing.Any:
@@ -79,8 +75,6 @@ else:
__all__ = [
"__version__",
"bc_linter_skip",
"bc_linter_include",
"__version_tuple__",
"LLM",
"ModelRegistry",

View File

@@ -87,6 +87,10 @@ def _rocm_aiter_fused_moe_impl(
a2_scale: torch.Tensor | None = None,
num_local_tokens: torch.Tensor | None = None,
output_dtype: torch.dtype | None = None,
hidden_pad: int = 0,
intermediate_pad: int = 0,
bias1: torch.Tensor | None = None,
bias2: torch.Tensor | None = None,
) -> torch.Tensor:
from aiter import ActivationType, QuantType
from aiter.fused_moe import fused_moe
@@ -110,6 +114,10 @@ def _rocm_aiter_fused_moe_impl(
a2_scale,
num_local_tokens=num_local_tokens,
dtype=output_dtype,
hidden_pad=hidden_pad,
intermediate_pad=intermediate_pad,
bias1=bias1,
bias2=bias2,
)
@@ -307,6 +315,28 @@ def _rocm_aiter_grouped_topk_fake(
pass
def _rocm_aiter_fused_topk_impl(
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
gate_up: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
from aiter.fused_moe import fused_topk
# fused_topk returns (topk_weights, topk_indices)
return fused_topk(x, router_logits, top_k, gate_up)
def _rocm_aiter_fused_topk_fake(
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
gate_up: bool,
) -> None:
# tuple[torch.Tensor, torch.Tensor]:
pass
# Cache whether aiter supports FP8 MLA parameters
_AITER_MLA_SUPPORTS_FP8: bool | None = None
@@ -994,6 +1024,70 @@ class rocm_aiter_ops:
cls._MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS
cls._TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM
@staticmethod
def get_aiter_activation_type(activation_str: str):
"""
Given an activation type as a string, returns the corresponding aiter ActivationType enum.
Supported activation types: "no", "none", "silu", "gelu", "swiglu".
Returns None if the mapping fails.
Args:
activation_str (str): Activation type as string.
Returns:
Aiter ActivationType enum value, or None if not found.
"""
# Import only locally, since aiter may not always be available.
try:
from aiter import ActivationType
except ImportError:
return None
if not isinstance(activation_str, str):
return None
name = activation_str.strip().lower()
mapping = {
"none": ActivationType.No,
"no": ActivationType.No,
"silu": ActivationType.Silu,
"gelu": ActivationType.Gelu,
"swiglu": ActivationType.Swiglu,
}
return mapping.get(name)
@staticmethod
def get_aiter_quant_type(quant_type_str: str):
"""
Given a quantization type as a string, returns the corresponding aiter QuantType enum.
Supported quantization types: "no", "per_tensor", "per_token", "per_1x32", "per_1x128", "per_128x128".
Returns None if the mapping fails.
Args:
quant_type_str (str): Quantization type as string.
Returns:
Aiter QuantType enum value, or None if not found.
"""
try:
from aiter import QuantType
except ImportError:
return None
if not isinstance(quant_type_str, str):
return None
name = quant_type_str.strip().lower()
mapping = {
"no": QuantType.No,
"per_tensor": QuantType.per_Tensor,
"per_token": QuantType.per_Token,
"per_1x32": QuantType.per_1x32,
"per_1x128": QuantType.per_1x128,
"per_128x128": QuantType.per_128x128,
}
return mapping.get(name)
@classmethod
@if_aiter_supported
def is_enabled(cls) -> bool:
@@ -1127,6 +1221,14 @@ class rocm_aiter_ops:
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_fused_topk",
op_func=_rocm_aiter_fused_topk_impl,
mutates_args=[],
fake_impl=_rocm_aiter_fused_topk_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_mla_decode_fwd",
op_func=_rocm_aiter_mla_decode_fwd_impl,
@@ -1360,6 +1462,10 @@ class rocm_aiter_ops:
a2_scale: torch.Tensor | None = None,
num_local_tokens: torch.Tensor | None = None,
output_dtype: torch.dtype | None = None,
hidden_pad: int = 0,
intermediate_pad: int = 0,
bias1: torch.Tensor | None = None,
bias2: torch.Tensor | None = None,
) -> torch.Tensor:
return torch.ops.vllm.rocm_aiter_fused_moe(
hidden_states,
@@ -1377,6 +1483,10 @@ class rocm_aiter_ops:
a2_scale,
num_local_tokens,
output_dtype,
hidden_pad,
intermediate_pad,
bias1,
bias2,
)
@staticmethod
@@ -1481,6 +1591,15 @@ class rocm_aiter_ops:
routed_scaling_factor,
)
@staticmethod
def fused_topk(
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
gate_up: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
return torch.ops.vllm.rocm_aiter_fused_topk(x, router_logits, top_k, gate_up)
@staticmethod
def mla_decode_fwd(
q: torch.Tensor,
@@ -1701,6 +1820,47 @@ class rocm_aiter_ops:
return shuffle_weight(tensor, layout=layout)
@staticmethod
def shuffle_weight_a16w4(
tensor: "torch.Tensor",
nLane: int,
gate_up: bool,
) -> "torch.Tensor":
"""
Shuffles the weight tensor into (A16W4) layout for AITER kernels.
Args:
tensor: The input weight tensor to be shuffled.
nLane: Number of lanes used by the shuffle layout.
gate_up: Whether the weight is for w13 (True) or w2 (False).
Returns:
torch.Tensor: The shuffled tensor.
"""
from aiter.ops.shuffle import shuffle_weight_a16w4
return shuffle_weight_a16w4(tensor, nLane, gate_up)
@staticmethod
def shuffle_scale_a16w4(
tensor: "torch.Tensor",
num_experts: int,
gate_up: bool,
) -> "torch.Tensor":
"""
Shuffles the scale tensor into (A16W4) layout for AITER kernels.
Args:
tensor: The input scale tensor to be shuffled.
num_experts: Number of experts, needed for reshaping logic.
gate_up: Whether the scale is for w13 (True) or w2 (False).
Returns:
torch.Tensor: The shuffled scale tensor.
"""
from aiter.ops.shuffle import shuffle_scale_a16w4
return shuffle_scale_a16w4(tensor, num_experts, gate_up)
@staticmethod
def shuffle_weights(
*tensors: torch.Tensor, layout: tuple[int, int] = (16, 16)

View File

@@ -1,54 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# vllm/_bc_linter.py
from collections.abc import Callable
from typing import Any, TypeVar, overload
T = TypeVar("T")
@overload
def bc_linter_skip(obj: T) -> T: ...
@overload
def bc_linter_skip(*, reason: str | None = ...) -> Callable[[T], T]: ...
def bc_linter_skip(obj: Any = None, *, reason: str | None = None):
"""
No-op decorator to mark symbols/files for BC-linter suppression.
Usage:
@bc_linter_skip
def legacy_api(...): ...
"""
def _wrap(x: T) -> T:
return x
return _wrap if obj is None else obj
@overload
def bc_linter_include(obj: T) -> T: ...
@overload
def bc_linter_include(*, reason: str | None = ...) -> Callable[[T], T]: ...
def bc_linter_include(obj: Any = None, *, reason: str | None = None):
"""
Usage:
@bc_linter_include
def public_api(...): ...
"""
def _wrap(x: T) -> T:
return x
return _wrap if obj is None else obj
__all__ = ["bc_linter_skip", "bc_linter_include"]

File diff suppressed because it is too large.

View File

@@ -31,6 +31,7 @@ from tempfile import NamedTemporaryFile
from typing import Any, cast
import numpy as np
from huggingface_hub import snapshot_download
from PIL import Image
from typing_extensions import deprecated
@@ -60,6 +61,8 @@ except ImportError:
logger = logging.getLogger(__name__)
DEFAULT_NUM_PROMPTS = 1000
# -----------------------------------------------------------------------------
# Data Classes
# -----------------------------------------------------------------------------
@@ -303,9 +306,11 @@ def process_image(image: Any) -> Mapping[str, Any]:
a JPEG in memory. - Encodes the JPEG data as a base64 string. - Returns
a dictionary with the image as a base64 data URL.
3. String input: - Treats the string as a URL or local file path. -
Prepends "file://" if the string doesn't start with "http://" or
"file://". - Returns a dictionary with the image URL.
3. String input: - Treats the string as a URL, local file path, or base64
encoded data. - If string starts with "data:image/", treats as base64.
- If string starts with "http://", "https://", or "file://", treats as URL.
- Otherwise treats as local file path and prepends "file://".
- Returns a dictionary with the image URL or base64 data.
Raises:
ValueError: If the input is not a supported type.
@@ -325,14 +330,14 @@ def process_image(image: Any) -> Mapping[str, Any]:
if isinstance(image, str):
image_url = (
image
if image.startswith(("http://", "https://", "file://"))
if image.startswith(("http://", "https://", "file://", "data:image/"))
else f"file://{image}"
)
return {"type": "image_url", "image_url": {"url": image_url}}
raise ValueError(
f"Invalid image input {image}. Must be a PIL.Image.Image"
" or str or dictionary with raw image bytes."
f"Invalid image input {image}. Must be a PIL.Image.Image, "
"str (URL, file path, or base64 data URL), or dictionary with raw image bytes."
)
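To make the new string handling concrete, here is a hedged sketch of what `process_image` is expected to return for the three string forms; the values are illustrative and the import path is assumed:

```python
# Illustrative expectations for the string branch of process_image (not a real test).
from vllm.benchmarks.datasets import process_image  # assumed import path

assert process_image("https://example.com/cat.jpg") == {
    "type": "image_url", "image_url": {"url": "https://example.com/cat.jpg"}}
assert process_image("data:image/jpeg;base64,AAAA") == {
    "type": "image_url", "image_url": {"url": "data:image/jpeg;base64,AAAA"}}
# A bare path is treated as a local file and prefixed with "file://".
assert process_image("/tmp/cat.jpg") == {
    "type": "image_url", "image_url": {"url": "file:///tmp/cat.jpg"}}
```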
@@ -1338,7 +1343,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
parser.add_argument(
"--num-prompts",
type=int,
default=1000,
default=DEFAULT_NUM_PROMPTS,
help="Number of prompts to process.",
)
parser.add_argument(
@@ -2676,6 +2681,14 @@ class MMVUDataset(HuggingFaceDataset):
+ (" ".join(f"{k}.{v}" for k, v in x["choices"].items())),
}
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
self._remote_path_root = (
f"https://huggingface.co/datasets/{self.hf_name}/resolve/main"
)
self._local_path_root = snapshot_download(self.hf_name, repo_type="dataset")
def sample(
self,
tokenizer: TokenizerLike,
@@ -2698,7 +2711,9 @@ class MMVUDataset(HuggingFaceDataset):
break
prompt = parser_fn(item)
mm_content = process_video(item["video"])
mm_content = process_video(
item["video"].replace(self._remote_path_root, self._local_path_root)
)
prompt_len = len(tokenizer.encode(prompt))
if enable_multimodal_chat:
# Note: when chat is enabled the request prompt_len is no longer

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Benchmark library utilities."""

View File

@@ -0,0 +1,802 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""The request function for API endpoints."""
import io
import json
import os
import sys
import time
import traceback
from collections.abc import Awaitable
from dataclasses import dataclass, field
from typing import Any, Literal, Protocol
import aiohttp
import regex as re
from tqdm.asyncio import tqdm
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
class StreamedResponseHandler:
"""Handles streaming HTTP responses by accumulating chunks until complete
messages are available."""
def __init__(self):
self.buffer = ""
def add_chunk(self, chunk_bytes: bytes) -> list[str]:
"""Add a chunk of bytes to the buffer and return any complete
messages."""
chunk_str = chunk_bytes.decode("utf-8")
self.buffer += chunk_str
messages = []
# Split by double newlines (SSE message separator)
while "\n\n" in self.buffer:
message, self.buffer = self.buffer.split("\n\n", 1)
message = message.strip()
if message:
messages.append(message)
# if self.buffer is not empty, check if it is a complete message
# by removing data: prefix and check if it is a valid JSON
if self.buffer.startswith("data: "):
message_content = self.buffer.removeprefix("data: ").strip()
if message_content == "[DONE]":
messages.append(self.buffer.strip())
self.buffer = ""
elif message_content:
try:
json.loads(message_content)
messages.append(self.buffer.strip())
self.buffer = ""
except json.JSONDecodeError:
# Incomplete JSON, wait for more chunks.
pass
return messages
@dataclass
class RequestFuncInput:
"""The input for the request function."""
prompt: str | list[str]
api_url: str
prompt_len: int
output_len: int
model: str
model_name: str | None = None
logprobs: int | None = None
extra_headers: dict | None = None
extra_body: dict | None = None
multi_modal_content: dict | list[dict] | None = None
ignore_eos: bool = False
language: str | None = None
request_id: str | None = None
@dataclass
class RequestFuncOutput:
"""The output of the request function including metrics."""
generated_text: str = ""
success: bool = False
latency: float = 0.0
output_tokens: int = 0
ttft: float = 0.0 # Time to first token
itl: list[float] = field(default_factory=list) # list of inter-token latencies
tpot: float = 0.0 # avg next-token latencies
prompt_len: int = 0
error: str = ""
start_time: float = 0.0
input_audio_duration: float = 0.0 # in seconds
class RequestFunc(Protocol):
def __call__(
self,
request_func_input: RequestFuncInput,
session: aiohttp.ClientSession,
pbar: tqdm | None = None,
) -> Awaitable[RequestFuncOutput]: ...
def _validate_api_url(
api_url: str,
api_name: str,
expected_suffixes: str | set[str],
) -> None:
if isinstance(expected_suffixes, str):
expected_suffixes = {expected_suffixes}
expected_suffixes = {*expected_suffixes, "profile"}
if not api_url.endswith(tuple(expected_suffixes)):
raise ValueError(f"{api_name} URL must end with one of: {expected_suffixes}.")
def _update_payload_common(
payload: dict[str, Any],
request_func_input: RequestFuncInput,
) -> None:
if request_func_input.ignore_eos:
payload["ignore_eos"] = request_func_input.ignore_eos
if request_func_input.extra_body:
payload.update(request_func_input.extra_body)
def _update_headers_common(
headers: dict[str, Any],
request_func_input: RequestFuncInput,
) -> None:
if request_func_input.extra_headers:
headers |= request_func_input.extra_headers
if request_func_input.request_id:
headers["x-request-id"] = request_func_input.request_id
def _get_headers(content_type: str | None = None) -> dict[str, str]:
headers = {}
if content_type:
headers["Content-Type"] = content_type
api_key = os.environ.get("OPENAI_API_KEY")
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
return headers
async def async_request_openai_completions(
request_func_input: RequestFuncInput,
session: aiohttp.ClientSession,
pbar: tqdm | None = None,
) -> RequestFuncOutput:
"""The async request function for the OpenAI Completions API.
Args:
request_func_input: The input for the request function.
pbar: The progress bar to display the progress.
Returns:
The output of the request function.
"""
api_url = request_func_input.api_url
_validate_api_url(api_url, "OpenAI Completions API", "completions")
payload = {
"model": request_func_input.model_name
if request_func_input.model_name
else request_func_input.model,
"prompt": request_func_input.prompt,
"repetition_penalty": 1.0,
"max_tokens": request_func_input.output_len,
"logprobs": request_func_input.logprobs,
"stream": True,
"stream_options": {
"include_usage": True,
},
}
_update_payload_common(payload, request_func_input)
headers = _get_headers()
_update_headers_common(headers, request_func_input)
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
generated_text = ""
st = time.perf_counter()
output.start_time = st
most_recent_timestamp = st
try:
async with session.post(url=api_url, json=payload, headers=headers) as response:
if response.status == 200:
first_chunk_received = False
handler = StreamedResponseHandler()
async for chunk_bytes in response.content.iter_any():
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
messages = handler.add_chunk(chunk_bytes)
for message in messages:
# NOTE: SSE comments (often used as pings) start with
# a colon. These are not JSON data payload and should
# be skipped.
if message.startswith(":"):
continue
chunk = message.removeprefix("data: ")
if chunk != "[DONE]":
data = json.loads(chunk)
# NOTE: Some completion API might have a last
# usage summary response without a token so we
# want to check a token was generated
if choices := data.get("choices"):
# Note that text could be empty here
# e.g. for special tokens
text = choices[0].get("text")
timestamp = time.perf_counter()
# First token
if not first_chunk_received:
first_chunk_received = True
ttft = time.perf_counter() - st
output.ttft = ttft
# Decoding phase
else:
output.itl.append(timestamp - most_recent_timestamp)
most_recent_timestamp = timestamp
generated_text += text or ""
elif usage := data.get("usage"):
output.output_tokens = usage.get("completion_tokens")
if first_chunk_received:
output.success = True
else:
output.success = False
output.error = (
"Never received a valid chunk to calculate TTFT."
"This response will be marked as failed!"
)
output.generated_text = generated_text
output.latency = most_recent_timestamp - st
else:
output.error = response.reason or ""
output.success = False
except Exception:
output.success = False
exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info))
if pbar:
pbar.update(1)
return output
def _get_chat_content(
request_func_input: RequestFuncInput,
mm_position: Literal["first", "last"] = "last",
) -> list[dict[str, Any]]:
text_contents = [{"type": "text", "text": request_func_input.prompt}]
mm_contents = []
if request_func_input.multi_modal_content:
mm_content = request_func_input.multi_modal_content
if isinstance(mm_content, list):
mm_contents.extend(request_func_input.multi_modal_content)
elif isinstance(mm_content, dict):
mm_contents.append(request_func_input.multi_modal_content)
else:
raise TypeError(
"multi_modal_content must be a dict or list[dict] for openai-chat"
)
if mm_position == "first":
return mm_contents + text_contents
return text_contents + mm_contents
async def async_request_openai_chat_completions(
request_func_input: RequestFuncInput,
session: aiohttp.ClientSession,
pbar: tqdm | None = None,
mm_position: Literal["first", "last"] = "last",
) -> RequestFuncOutput:
api_url = request_func_input.api_url
_validate_api_url(api_url, "OpenAI Chat Completions API", "chat/completions")
content = _get_chat_content(request_func_input, mm_position=mm_position)
payload = {
"model": request_func_input.model_name
if request_func_input.model_name
else request_func_input.model,
"messages": [
{"role": "user", "content": content},
],
"max_completion_tokens": request_func_input.output_len,
"stream": True,
"stream_options": {
"include_usage": True,
},
}
_update_payload_common(payload, request_func_input)
headers = _get_headers("application/json")
_update_headers_common(headers, request_func_input)
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
generated_text = ""
ttft = 0.0
st = time.perf_counter()
output.start_time = st
most_recent_timestamp = st
try:
async with session.post(url=api_url, json=payload, headers=headers) as response:
if response.status == 200:
handler = StreamedResponseHandler()
async for chunk_bytes in response.content.iter_any():
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
messages = handler.add_chunk(chunk_bytes)
for message in messages:
# NOTE: SSE comments (often used as pings) start with
# a colon. These are not JSON data payload and should
# be skipped.
if message.startswith(":"):
continue
chunk = message.removeprefix("data: ")
if chunk != "[DONE]":
timestamp = time.perf_counter()
data = json.loads(chunk)
if choices := data.get("choices"):
content = choices[0]["delta"].get("content")
# First token
if ttft == 0.0:
ttft = timestamp - st
output.ttft = ttft
# Decoding phase
else:
output.itl.append(timestamp - most_recent_timestamp)
generated_text += content or ""
elif usage := data.get("usage"):
output.output_tokens = usage.get("completion_tokens")
most_recent_timestamp = timestamp
output.generated_text = generated_text
output.success = True
output.latency = most_recent_timestamp - st
else:
output.error = response.reason or ""
output.success = False
except Exception:
output.success = False
exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info))
if pbar:
pbar.update(1)
return output
async def async_request_openai_audio(
request_func_input: RequestFuncInput,
session: aiohttp.ClientSession,
pbar: tqdm | None = None,
) -> RequestFuncOutput:
# Lazy import without PlaceholderModule to avoid vllm dep.
import soundfile
api_url = request_func_input.api_url
_validate_api_url(api_url, "OpenAI Audio API", {"transcriptions", "translations"})
content = [{"type": "text", "text": request_func_input.prompt}]
payload = {
"model": request_func_input.model_name
if request_func_input.model_name
else request_func_input.model,
"max_completion_tokens": request_func_input.output_len,
"stream": True,
"language": "en",
# Flattened due to multipart/form-data
"stream_include_usage": True,
"stream_continuous_usage_stats": True,
}
_update_payload_common(payload, request_func_input)
headers = _get_headers()
_update_headers_common(headers, request_func_input)
# Send audio file
def to_bytes(y, sr):
buffer = io.BytesIO()
soundfile.write(buffer, y, sr, format="WAV")
buffer.seek(0)
return buffer
mm_audio = request_func_input.multi_modal_content
if not isinstance(mm_audio, dict) or "audio" not in mm_audio:
raise TypeError("multi_modal_content must be a dict containing 'audio'")
with to_bytes(*mm_audio["audio"]) as f:
form = aiohttp.FormData()
form.add_field("file", f, content_type="audio/wav")
for key, value in payload.items():
form.add_field(key, str(value))
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
output.input_audio_duration = soundfile.info(f).duration
f.seek(0)
generated_text = ""
ttft = 0.0
st = time.perf_counter()
output.start_time = st
most_recent_timestamp = st
try:
async with session.post(
url=api_url, data=form, headers=headers
) as response:
if response.status == 200:
handler = StreamedResponseHandler()
async for chunk_bytes in response.content.iter_any():
chunk_bytes = chunk_bytes.strip()
if not chunk_bytes:
continue
messages = handler.add_chunk(chunk_bytes)
for message in messages:
if type(message) is bytes:
message = message.decode("utf-8")
chunk = message.removeprefix("data: ")
if chunk != "[DONE]":
timestamp = time.perf_counter()
data = json.loads(chunk)
if choices := data.get("choices"):
content = choices[0]["delta"].get("content")
# First token
if ttft == 0.0:
ttft = timestamp - st
output.ttft = ttft
# Decoding phase
else:
output.itl.append(
timestamp - most_recent_timestamp
)
generated_text += content or ""
elif usage := data.get("usage"):
output.output_tokens = usage.get(
"completion_tokens"
)
most_recent_timestamp = timestamp
output.generated_text = generated_text
output.success = True
output.latency = most_recent_timestamp - st
else:
output.error = response.reason or ""
output.success = False
except Exception:
output.success = False
exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info))
if pbar:
pbar.update(1)
return output
async def _run_pooling_request(
session: aiohttp.ClientSession,
api_url: str,
payload: dict[str, Any],
headers: dict[str, Any],
pbar: tqdm | None = None,
) -> RequestFuncOutput:
output = RequestFuncOutput()
st = time.perf_counter()
output.start_time = st
try:
async with session.post(url=api_url, headers=headers, json=payload) as response:
if response.status == 200:
output.ttft = output.latency = time.perf_counter() - st
if payload.get("encoding_format", "float") == "bytes":
metadata = json.loads(response.headers["metadata"])
usage = metadata.get("usage", {})
else:
data = await response.json()
usage = data.get("usage", {})
output.success = True
output.generated_text = ""
output.prompt_len = usage.get("prompt_tokens", 0)
else:
output.success = False
output.error = response.reason or ""
except Exception as e:
output.success = False
output.error = str(e)
if pbar:
pbar.update(1)
return output
async def async_request_openai_embeddings(
request_func_input: RequestFuncInput,
session: aiohttp.ClientSession,
pbar: tqdm | None = None,
) -> RequestFuncOutput:
api_url = request_func_input.api_url
_validate_api_url(api_url, "OpenAI Embeddings API", "embeddings")
payload = {
"model": request_func_input.model_name
if request_func_input.model_name
else request_func_input.model,
"input": request_func_input.prompt,
# Many embedding models have short context length,
# this is to avoid dropping some of the requests.
"truncate_prompt_tokens": -1,
}
_update_payload_common(payload, request_func_input)
headers = _get_headers("application/json")
_update_headers_common(headers, request_func_input)
return await _run_pooling_request(
session,
api_url,
payload=payload,
headers=headers,
pbar=pbar,
)
async def async_request_vllm_rerank(
request_func_input: RequestFuncInput,
session: aiohttp.ClientSession,
pbar: tqdm | None = None,
) -> RequestFuncOutput:
api_url = request_func_input.api_url
_validate_api_url(api_url, "vLLM score API", "rerank")
assert (
isinstance(request_func_input.prompt, list)
and len(request_func_input.prompt) > 1
)
payload = {
"model": request_func_input.model_name
if request_func_input.model_name
else request_func_input.model,
"query": request_func_input.prompt[0],
"documents": request_func_input.prompt[1:],
# Many reranker models have short context length,
# this is to avoid dropping some of the requests.
"truncate_prompt_tokens": -1,
}
headers = _get_headers("application/json")
_update_headers_common(headers, request_func_input)
return await _run_pooling_request(
session,
api_url,
payload=payload,
headers=headers,
pbar=pbar,
)
async def async_request_openai_embeddings_chat(
request_func_input: RequestFuncInput,
session: aiohttp.ClientSession,
pbar: tqdm | None = None,
mm_position: Literal["first", "last"] = "last",
) -> RequestFuncOutput:
api_url = request_func_input.api_url
_validate_api_url(api_url, "OpenAI Embeddings API", "embeddings")
content = _get_chat_content(request_func_input, mm_position=mm_position)
payload = {
"model": request_func_input.model_name
if request_func_input.model_name
else request_func_input.model,
"messages": [
{"role": "user", "content": content},
],
# Many embedding models have short context length,
# this is to avoid dropping some of the requests.
"truncate_prompt_tokens": -1,
}
_update_payload_common(payload, request_func_input)
headers = _get_headers("application/json")
_update_headers_common(headers, request_func_input)
return await _run_pooling_request(
session,
api_url,
payload=payload,
headers=headers,
pbar=pbar,
)
def _try_extract_request_idx(request_func_input: RequestFuncInput):
if request_func_input.request_id:
match = re.search(r"(\d+)$", request_func_input.request_id)
if match:
try:
return int(match.group(1))
except ValueError:
pass
return None
def _preprocess_clip(request_func_input: RequestFuncInput):
if request_func_input.multi_modal_content:
# Image input
request_func_input.prompt = ""
def _preprocess_vlm2vec(request_func_input: RequestFuncInput):
if request_func_input.multi_modal_content:
request_idx = _try_extract_request_idx(request_func_input)
# Adjust the ratio manually if needed.
use_image_only_prompt = request_idx is None or request_idx % 2 == 0
if use_image_only_prompt:
# Image input
request_func_input.prompt = "Represent the given image."
else:
# Text+Image input
request_func_input.prompt = (
f"Represent the given image with the following question: "
f"{request_func_input.prompt}"
)
async def async_request_openai_embeddings_clip(
request_func_input: RequestFuncInput,
session: aiohttp.ClientSession,
pbar: tqdm | None = None,
) -> RequestFuncOutput:
_preprocess_clip(request_func_input)
return await async_request_openai_embeddings_chat(
request_func_input,
session,
pbar=pbar,
)
async def async_request_openai_embeddings_vlm2vec(
request_func_input: RequestFuncInput,
session: aiohttp.ClientSession,
pbar: tqdm | None = None,
) -> RequestFuncOutput:
_preprocess_vlm2vec(request_func_input)
return await async_request_openai_embeddings_chat(
request_func_input,
session,
pbar=pbar,
mm_position="first",
)
async def async_request_infinity_embeddings(
request_func_input: RequestFuncInput,
session: aiohttp.ClientSession,
pbar: tqdm | None = None,
) -> RequestFuncOutput:
api_url = request_func_input.api_url
_validate_api_url(api_url, "Infinity Embeddings API", "embeddings")
payload = {
"model": request_func_input.model_name
if request_func_input.model_name
else request_func_input.model,
}
if request_func_input.prompt:
payload["input"] = request_func_input.prompt
else:
mm_content = request_func_input.multi_modal_content
assert isinstance(mm_content, dict)
mm_type = mm_content["type"]
payload["input"] = mm_content[mm_type]["url"]
payload["modality"] = mm_type.split("_", 1)[0]
_update_payload_common(payload, request_func_input)
headers = _get_headers("application/json")
_update_headers_common(headers, request_func_input)
return await _run_pooling_request(
session,
api_url,
payload=payload,
headers=headers,
pbar=pbar,
)
async def async_request_infinity_embeddings_clip(
request_func_input: RequestFuncInput,
session: aiohttp.ClientSession,
pbar: tqdm | None = None,
) -> RequestFuncOutput:
_preprocess_clip(request_func_input)
return await async_request_infinity_embeddings(
request_func_input,
session,
pbar=pbar,
)
async def async_request_vllm_pooling(
request_func_input: RequestFuncInput,
session: aiohttp.ClientSession,
pbar: tqdm | None = None,
) -> RequestFuncOutput:
api_url = request_func_input.api_url
_validate_api_url(api_url, "vLLM Pooling API", "pooling")
payload = {
"model": request_func_input.model_name
if request_func_input.model_name
else request_func_input.model,
"truncate_prompt_tokens": -1,
}
payload = payload | request_func_input.prompt
_update_payload_common(payload, request_func_input)
headers = _get_headers("application/json")
_update_headers_common(headers, request_func_input)
return await _run_pooling_request(
session,
api_url,
payload=payload,
headers=headers,
pbar=pbar,
)
# TODO: Add more request functions for different API protocols.
ASYNC_REQUEST_FUNCS: dict[str, RequestFunc] = {
"vllm": async_request_openai_completions,
"openai": async_request_openai_completions,
"openai-chat": async_request_openai_chat_completions,
"openai-audio": async_request_openai_audio,
"openai-embeddings": async_request_openai_embeddings,
"openai-embeddings-chat": async_request_openai_embeddings_chat,
"openai-embeddings-clip": async_request_openai_embeddings_clip,
"openai-embeddings-vlm2vec": async_request_openai_embeddings_vlm2vec,
# Infinity embedding server: https://github.com/michaelfeil/infinity
"infinity-embeddings": async_request_infinity_embeddings,
"infinity-embeddings-clip": async_request_infinity_embeddings_clip,
# (Infinity embedding server does not support vlm2vec)
"vllm-pooling": async_request_vllm_pooling,
"vllm-rerank": async_request_vllm_rerank,
}
OPENAI_COMPATIBLE_BACKENDS = [
k
for k, v in ASYNC_REQUEST_FUNCS.items()
if v in (async_request_openai_completions, async_request_openai_chat_completions)
]
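A minimal usage sketch of the request-function registry above; the server URL, model name, and token counts are placeholders, and the module path is assumed to be `vllm.benchmarks.lib.endpoint_request_func`:

```python
# Hedged usage sketch; URL/model are placeholders and the import path is assumed.
import asyncio

import aiohttp

from vllm.benchmarks.lib.endpoint_request_func import (
    AIOHTTP_TIMEOUT, ASYNC_REQUEST_FUNCS, RequestFuncInput)


async def send_one() -> None:
    request_func = ASYNC_REQUEST_FUNCS["openai"]  # streaming Completions API
    req = RequestFuncInput(
        prompt="Hello",
        api_url="http://localhost:8000/v1/completions",
        prompt_len=1,
        output_len=16,
        model="llm",
    )
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        out = await request_func(request_func_input=req, session=session)
        print(out.success, out.ttft, out.generated_text)


asyncio.run(send_one())
```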

View File

@@ -0,0 +1,79 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utilities for checking endpoint readiness."""
import asyncio
import time
import aiohttp
from tqdm.asyncio import tqdm
from vllm.logger import init_logger
from .endpoint_request_func import RequestFunc, RequestFuncInput, RequestFuncOutput
logger = init_logger(__name__)
async def wait_for_endpoint(
request_func: RequestFunc,
test_input: RequestFuncInput,
session: aiohttp.ClientSession,
timeout_seconds: int = 600,
retry_interval: int = 5,
) -> RequestFuncOutput:
"""
Wait for an endpoint to become available before starting benchmarks.
Args:
request_func: The async request function to call
test_input: The RequestFuncInput to test with
timeout_seconds: Maximum time to wait in seconds (default: 10 minutes)
retry_interval: Time between retries in seconds (default: 5 seconds)
Returns:
RequestFuncOutput: The successful response
Raises:
ValueError: If the endpoint doesn't become available within the timeout
"""
deadline = time.perf_counter() + timeout_seconds
output = RequestFuncOutput(success=False)
print(f"Waiting for endpoint to become up in {timeout_seconds} seconds")
with tqdm(
total=timeout_seconds,
bar_format="{desc} |{bar}| {elapsed} elapsed, {remaining} remaining",
unit="s",
) as pbar:
while True:
# update progress bar
remaining = deadline - time.perf_counter()
elapsed = timeout_seconds - remaining
update_amount = min(elapsed - pbar.n, timeout_seconds - pbar.n)
pbar.update(update_amount)
pbar.refresh()
if remaining <= 0:
pbar.close()
break
# ping the endpoint using request_func
try:
output = await request_func(
request_func_input=test_input, session=session
)
if output.success:
pbar.close()
return output
else:
err_last_line = str(output.error).rstrip().rsplit("\n", 1)[-1]
logger.warning("Endpoint is not ready. Error='%s'", err_last_line)
except aiohttp.ClientConnectorError:
pass
# retry after a delay
sleep_duration = min(retry_interval, remaining)
if sleep_duration > 0:
await asyncio.sleep(sleep_duration)
return output
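A hedged sketch of wiring `wait_for_endpoint` to one of the request functions from the previous file; the endpoint details are the same placeholders as above and the import paths are assumed:

```python
# Hedged sketch; assumes the same placeholder endpoint and import paths as above.
import asyncio

import aiohttp

from vllm.benchmarks.lib.endpoint_request_func import (
    ASYNC_REQUEST_FUNCS, RequestFuncInput)
from vllm.benchmarks.lib.ready_checker import wait_for_endpoint  # assumed path


async def wait_then_report() -> None:
    test_input = RequestFuncInput(
        prompt="ping",
        api_url="http://localhost:8000/v1/completions",
        prompt_len=1,
        output_len=1,
        model="llm",
    )
    async with aiohttp.ClientSession() as session:
        out = await wait_for_endpoint(
            ASYNC_REQUEST_FUNCS["openai"], test_input, session,
            timeout_seconds=120, retry_interval=5,
        )
        print("endpoint ready:", out.success)


asyncio.run(wait_then_report())
```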

View File

@@ -0,0 +1,131 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import json
import math
import os
from contextlib import contextmanager
from typing import Any
def extract_field(
args: argparse.Namespace, extra_info: dict[str, Any], field_name: str
) -> str:
if field_name in extra_info:
return extra_info[field_name]
v = args
# For example, args.compilation_config.mode
for nested_field in field_name.split("."):
if not hasattr(v, nested_field):
return ""
v = getattr(v, nested_field)
return v
def use_compile(args: argparse.Namespace, extra_info: dict[str, Any]) -> bool:
"""
Check if the benchmark is run with torch.compile
"""
return not (
extract_field(args, extra_info, "compilation_config.mode") == "0"
or "eager" in getattr(args, "output_json", "")
or "eager" in getattr(args, "result_filename", "")
)
def convert_to_pytorch_benchmark_format(
args: argparse.Namespace, metrics: dict[str, list], extra_info: dict[str, Any]
) -> list:
"""
Save the benchmark results in the format used by PyTorch OSS benchmark with
one metric per record
https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database
"""
records = []
if not os.environ.get("SAVE_TO_PYTORCH_BENCHMARK_FORMAT", False):
return records
for name, benchmark_values in metrics.items():
if not isinstance(benchmark_values, list):
raise TypeError(
f"benchmark_values for metric '{name}' must be a list, "
f"but got {type(benchmark_values).__name__}"
)
record = {
"benchmark": {
"name": "vLLM benchmark",
"extra_info": {
"args": vars(args),
"compilation_config.mode": extract_field(
args, extra_info, "compilation_config.mode"
),
"optimization_level": extract_field(
args, extra_info, "optimization_level"
),
# A boolean field used by vLLM benchmark HUD dashboard
"use_compile": use_compile(args, extra_info),
},
},
"model": {
"name": args.model,
},
"metric": {
"name": name,
"benchmark_values": benchmark_values,
"extra_info": extra_info,
},
}
tp = record["benchmark"]["extra_info"]["args"].get("tensor_parallel_size")
# Save tensor_parallel_size parameter if it's part of the metadata
if not tp and "tensor_parallel_size" in extra_info:
record["benchmark"]["extra_info"]["args"]["tensor_parallel_size"] = (
extra_info["tensor_parallel_size"]
)
records.append(record)
return records
class InfEncoder(json.JSONEncoder):
def clear_inf(self, o: Any):
if isinstance(o, dict):
return {
str(k)
if not isinstance(k, (str, int, float, bool, type(None)))
else k: self.clear_inf(v)
for k, v in o.items()
}
elif isinstance(o, list):
return [self.clear_inf(v) for v in o]
elif isinstance(o, float) and math.isinf(o):
return "inf"
return o
def iterencode(self, o: Any, *args, **kwargs) -> Any:
return super().iterencode(self.clear_inf(o), *args, **kwargs)
def write_to_json(filename: str, records: list) -> None:
with open(filename, "w") as f:
json.dump(
records,
f,
cls=InfEncoder,
default=lambda o: f"<{type(o).__name__} is not JSON serializable>",
)
@contextmanager
def default_vllm_config():
"""Set a default VllmConfig for cases that directly test CustomOps or pathways
that use get_current_vllm_config() outside of a full engine context.
"""
from vllm.config import VllmConfig, set_current_vllm_config
with set_current_vllm_config(VllmConfig()):
yield
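A hedged example of feeding benchmark output through the PyTorch OSS record helpers above; the namespace fields and metric values are invented, and records are only produced when `SAVE_TO_PYTORCH_BENCHMARK_FORMAT` is set:

```python
# Hedged sketch; namespace fields and metric values are illustrative only.
import argparse
import os

from vllm.benchmarks.lib.utils import (  # assumed module path
    convert_to_pytorch_benchmark_format, write_to_json)

os.environ["SAVE_TO_PYTORCH_BENCHMARK_FORMAT"] = "1"
args = argparse.Namespace(model="facebook/opt-125m", output_json="results.json")
records = convert_to_pytorch_benchmark_format(
    args,
    metrics={"request_throughput": [12.3]},
    extra_info={"tensor_parallel_size": 1},
)
write_to_json("results.pt_records.json", records)
```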

316
vllm/benchmarks/plot.py Normal file
View File

@@ -0,0 +1,316 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Generate plots for benchmark results."""
from pathlib import Path
from typing import Any
from vllm.utils.import_utils import PlaceholderModule
try:
import plotly.express as px
import plotly.io as pio
except ImportError:
_plotly = PlaceholderModule("plotly")
px = _plotly.placeholder_attr("express")
pio = _plotly.placeholder_attr("io")
try:
import matplotlib.pyplot as plt
except ImportError:
_matplotlib = PlaceholderModule("matplotlib")
plt = _matplotlib.placeholder_attr("pyplot")
def generate_timeline_plot(
results: list[dict[str, Any]],
output_path: Path,
colors: list[str] | None = None,
itl_thresholds: list[float] | None = None,
labels: list[str] | None = None,
) -> None:
"""
Generate an HTML timeline plot from benchmark results.
Args:
results: List of per-request result dictionaries containing:
- start_time: Request start time (seconds)
- ttft: Time to first token (seconds)
- itl: List of inter-token latencies (seconds)
- latency: Total request latency (seconds)
- prompt_len: Number of prompt tokens
- output_tokens: Number of output tokens
output_path: Path where the HTML file will be saved
colors: List of colors for ITL categories (default: green, orange, red)
itl_thresholds: ITL thresholds in seconds (default: [0.025, 0.050])
labels: Labels for ITL categories (default based on thresholds)
"""
# Set defaults
if colors is None:
colors = ["#109618", "#FF7F0E", "#D62728"]
if itl_thresholds is None:
itl_thresholds = [0.025, 0.050]
if labels is None:
labels = [
f"ITL < {itl_thresholds[0] * 1000:.0f}ms",
f"{itl_thresholds[0] * 1000:.0f}ms ≤ ITL < {itl_thresholds[1] * 1000:.0f}ms", # noqa
f"ITL ≥ {itl_thresholds[1] * 1000:.0f}ms",
]
labels_colors = {"TTFT": "#636EFA", **dict(zip(labels, colors))}
labels_order = ["TTFT"] + labels
timeline_data = construct_timeline_data(results, itl_thresholds, labels)
if not timeline_data:
print("No timeline data to plot")
return
# Create the plot
fig = px.timeline(
timeline_data,
x_start="start",
x_end="end",
y="request_id",
color="type",
color_discrete_map=labels_colors,
category_orders={"type": labels_order},
hover_data=[
"prompt_tokens",
"output_tokens",
"req_start_time",
"req_finish_time",
"segment_start",
"segment_end",
"duration",
],
)
# Customize hover template to show only time without date
fig.update_traces(
hovertemplate="<b>%{y}</b><br>"
"Type: %{fullData.name}<br>"
"Start: %{customdata[4]}<br>"
"End: %{customdata[5]}<br>"
"Duration: %{customdata[6]}<br>"
"Prompt Tokens: %{customdata[0]}<br>"
"Output Tokens: %{customdata[1]}<br>"
"Request Start Time: %{customdata[2]}<br>"
"Request End Time: %{customdata[3]}<br>"
"<extra></extra>"
)
fig.update_yaxes(autorange="reversed")
fig.update_layout(
xaxis_title="Time",
yaxis_title="Request ID",
showlegend=True,
)
# Save to HTML
pio.write_html(fig, str(output_path))
print(f"Timeline plot saved to: {output_path}")
def construct_timeline_data(
requests_data: list[dict[str, Any]],
itl_thresholds: list[float],
labels: list[str],
) -> list[dict[str, Any]]:
"""
Construct timeline data from request results.
Args:
requests_data: List of per-request result dictionaries
itl_thresholds: ITL thresholds in seconds
labels: Labels for ITL categories
Returns:
List of timeline segments for plotting
"""
def tostr(sec_time: float) -> str:
"""Convert seconds to HH:MM:SS.mmm format."""
h = int(sec_time // 3600)
assert h < 100, "time seems to last more than 100 hours"
m = int((sec_time % 3600) // 60)
s = sec_time % 60
return f"{h:02d}:{m:02d}:{s:06.3f}"
def itl_type(itl: float) -> str:
"""Categorize ITL based on thresholds."""
if itl < itl_thresholds[0]:
return labels[0]
elif itl < itl_thresholds[1]:
return labels[1]
else:
return labels[2]
# Find the earliest start time to use as t0
t0 = None
for request in requests_data:
start_time = request.get("start_time")
if start_time is not None and (t0 is None or start_time < t0):
t0 = start_time
if t0 is None:
return []
timeline_data = []
for i, request in enumerate(requests_data):
start_time = request.get("start_time")
ttft = request.get("ttft")
itl = request.get("itl", [])
latency = request.get("latency")
prompt_len = request.get("prompt_len", 0)
output_tokens = request.get("output_tokens", 0)
# Skip requests without required data
if start_time is None or ttft is None or latency is None:
continue
# Normalize start time
start_time = start_time - t0
start_time_str = tostr(start_time)
# TTFT segment
ttft_end = start_time + ttft
ttft_end_str = tostr(ttft_end)
timeline_data.append(
{
"request_id": f"Req {i}",
"start": start_time_str,
"end": ttft_end_str,
"type": "TTFT",
"prompt_tokens": prompt_len,
"output_tokens": output_tokens,
"req_start_time": tostr(start_time),
"req_finish_time": tostr(start_time + latency),
"segment_start": start_time_str,
"segment_end": ttft_end_str,
"duration": f"{ttft:.3f}s",
}
)
# ITL segments
prev_time = ttft_end
prev_time_str = ttft_end_str
for itl_value in itl:
itl_end = prev_time + itl_value
itl_end_str = tostr(itl_end)
timeline_data.append(
{
"request_id": f"Req {i}",
"start": prev_time_str,
"end": itl_end_str,
"type": itl_type(itl_value),
"prompt_tokens": prompt_len,
"output_tokens": output_tokens,
"req_start_time": tostr(start_time),
"req_finish_time": tostr(start_time + latency),
"segment_start": prev_time_str,
"segment_end": itl_end_str,
"duration": f"{itl_value:.3f}s",
}
)
prev_time = itl_end
prev_time_str = itl_end_str
return timeline_data
def generate_dataset_stats_plot(
results: list[dict[str, Any]],
output_path: Path,
) -> None:
"""
Generate a matplotlib figure with dataset statistics.
Creates a figure with 4 subplots:
- Top-left: Prompt tokens distribution (histogram)
- Top-right: Output tokens distribution (histogram)
- Bottom-left: Prompt+output tokens distribution (histogram)
- Bottom-right: Stacked bar chart (request_id vs tokens)
Args:
results: List of per-request result dictionaries containing:
- prompt_len: Number of prompt tokens
- output_tokens: Number of output tokens
output_path: Path where the figure will be saved
"""
# Extract data
prompt_tokens = []
output_tokens = []
total_tokens = []
for request in results:
prompt_len = request.get("prompt_len", 0)
output_len = request.get("output_tokens", 0)
prompt_tokens.append(prompt_len)
output_tokens.append(output_len)
total_tokens.append(prompt_len + output_len)
if not prompt_tokens:
print("No data available for dataset statistics plot")
return
# Create figure with 4 subplots
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(14, 10))
# Top-left: Prompt tokens distribution
ax1.hist(prompt_tokens, bins=30, color="steelblue", edgecolor="black", alpha=0.7)
ax1.set_xlabel("Prompt Tokens")
ax1.set_ylabel("Frequency")
ax1.set_title("Prompt Tokens Distribution")
ax1.grid(True, alpha=0.3)
# Top-right: Output tokens distribution
ax2.hist(output_tokens, bins=30, color="coral", edgecolor="black", alpha=0.7)
ax2.set_xlabel("Output Tokens")
ax2.set_ylabel("Frequency")
ax2.set_title("Output Tokens Distribution")
ax2.grid(True, alpha=0.3)
# Bottom-left: Prompt+output tokens distribution
ax3.hist(
total_tokens, bins=30, color="mediumseagreen", edgecolor="black", alpha=0.7
)
ax3.set_xlabel("Total Tokens (Prompt + Output)")
ax3.set_ylabel("Frequency")
ax3.set_title("Total Tokens Distribution")
ax3.grid(True, alpha=0.3)
# Bottom-right: Stacked bar chart
request_ids = list(range(len(prompt_tokens)))
ax4.bar(
request_ids, prompt_tokens, label="Prompt Tokens", color="steelblue", alpha=0.7
)
ax4.bar(
request_ids,
output_tokens,
bottom=prompt_tokens,
label="Output Tokens",
color="coral",
alpha=0.7,
)
ax4.set_xlabel("Request ID")
ax4.set_ylabel("Tokens")
ax4.set_title("Tokens per Request (Stacked)")
ax4.legend()
ax4.grid(True, alpha=0.3, axis="y")
# Adjust layout to prevent overlap
plt.tight_layout()
# Save figure
plt.savefig(str(output_path), dpi=150, bbox_inches="tight")
plt.close(fig)
print(f"Dataset statistics plot saved to: {output_path}")

View File

@@ -34,6 +34,7 @@ from collections.abc import AsyncGenerator, Iterable
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Any, Literal
import aiohttp
@@ -1183,6 +1184,49 @@ def save_to_pytorch_benchmark_format(
write_to_json(pt_file, pt_records)
def compute_result_filename(
args: argparse.Namespace,
model_id: str,
label: str,
current_dt: str,
) -> str | None:
"""Compute the result filename based on benchmark configuration.
Args:
args: Command line arguments containing result configuration
model_id: The model identifier
label: The benchmark label
current_dt: Current datetime string
Returns:
The computed filename path or None if no result saving is requested
"""
if not (args.plot_timeline or args.save_result or args.append_result):
return None
base_model_id = model_id.split("/")[-1]
max_concurrency_str = (
f"-concurrency{args.max_concurrency}"
if args.max_concurrency is not None
else ""
)
label = label or args.backend
if args.ramp_up_strategy is not None:
file_name = f"{label}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
else:
file_name = f"{label}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
if args.result_filename:
file_name = args.result_filename
if args.result_dir:
os.makedirs(args.result_dir, exist_ok=True)
file_name = os.path.join(args.result_dir, file_name)
return file_name
def add_cli_args(parser: argparse.ArgumentParser):
add_dataset_parser(parser)
parser.add_argument(
@@ -1277,6 +1321,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
- "slow" will always use the slow tokenizer.\n
- "mistral" will always use the tokenizer from `mistral_common`.\n
- "deepseek_v32" will always use the tokenizer from `deepseek_v32`.\n
- "qwen_vl" will always use the tokenizer from `qwen_vl`.\n
- Other custom values can be supported via plugins.""",
)
parser.add_argument("--use-beam-search", action="store_true")
@@ -1535,6 +1580,30 @@ def add_cli_args(parser: argparse.ArgumentParser):
"connecting to servers with self-signed certificates.",
)
parser.add_argument(
"--plot-timeline",
action="store_true",
help="Generate an HTML timeline plot showing request execution. "
"The plot will be saved alongside the results JSON file.",
)
parser.add_argument(
"--timeline-itl-thresholds",
type=float,
nargs=2,
default=[25.0, 50.0],
metavar=("THRESHOLD1", "THRESHOLD2"),
help="ITL thresholds in milliseconds for timeline plot coloring. "
"Specify two values to categorize inter-token latencies into three groups: "
"below first threshold (green), between thresholds (orange), "
"and above second threshold (red). Default: 25 50 (milliseconds).",
)
parser.add_argument(
"--plot-dataset-stats",
action="store_true",
help="Generate a matplotlib figure with dataset statistics showing "
"prompt tokens, output tokens, and combined token distributions.",
)
def main(args: argparse.Namespace) -> dict[str, Any]:
return asyncio.run(main_async(args))
@@ -1770,6 +1839,86 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
# Merge with benchmark result
result_json = {**result_json, **benchmark_result}
# Compute file_name once before using it for plots or saving results
file_name = compute_result_filename(args, model_id, label, current_dt)
# Generate timeline plot if requested
if args.plot_timeline:
try:
from vllm.benchmarks.plot import generate_timeline_plot
# Prepare per-request data for timeline
per_request_data = []
start_times = benchmark_result.get("start_times", [])
ttfts = benchmark_result.get("ttfts", [])
itls = benchmark_result.get("itls", [])
input_lens = benchmark_result.get("input_lens", [])
output_lens = benchmark_result.get("output_lens", [])
if start_times and ttfts and itls:
for i in range(len(start_times)):
# Calculate latency as ttft + sum of all itls
latency = ttfts[i] + sum(itls[i]) if itls[i] else ttfts[i]
per_request_data.append(
{
"start_time": start_times[i],
"ttft": ttfts[i],
"itl": itls[i],
"latency": latency,
"prompt_len": input_lens[i],
"output_tokens": output_lens[i],
}
)
timeline_path = Path(file_name).with_suffix(".timeline.html")
# Convert thresholds from milliseconds to seconds
itl_thresholds_sec = [t / 1000.0 for t in args.timeline_itl_thresholds]
generate_timeline_plot(
per_request_data, timeline_path, itl_thresholds=itl_thresholds_sec
)
else:
warnings.warn(
"Timeline plot requires detailed metrics. "
"Ensure the benchmark completed successfully.",
stacklevel=2,
)
except Exception as e:
warnings.warn(f"Failed to generate timeline plot: {e}", stacklevel=2)
# Generate dataset statistics plot if requested
if args.plot_dataset_stats:
try:
from vllm.benchmarks.plot import generate_dataset_stats_plot
# Prepare per-request data for dataset stats
per_request_data = []
input_lens = benchmark_result.get("input_lens", [])
output_lens = benchmark_result.get("output_lens", [])
if input_lens and output_lens:
for req_input_len, req_output_len in zip(input_lens, output_lens):
per_request_data.append(
{
"prompt_len": req_input_len,
"output_tokens": req_output_len,
}
)
stats_path = Path(file_name).with_suffix(".dataset_stats.png")
generate_dataset_stats_plot(per_request_data, stats_path)
else:
warnings.warn(
"Dataset statistics plot requires input and "
"output length data. Ensure the benchmark completed "
"successfully.",
stacklevel=2,
)
except Exception as e:
warnings.warn(
f"Failed to generate dataset statistics plot: {e}", stacklevel=2
)
if not args.save_detailed:
# Remove fields with too many data points
for field in [
@@ -1788,22 +1937,6 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
# Save to file
if args.save_result or args.append_result:
base_model_id = model_id.split("/")[-1]
max_concurrency_str = (
f"-concurrency{args.max_concurrency}"
if args.max_concurrency is not None
else ""
)
label = label or args.backend
if args.ramp_up_strategy is not None:
file_name = f"{label}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
else:
file_name = f"{label}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
if args.result_filename:
file_name = args.result_filename
if args.result_dir:
os.makedirs(args.result_dir, exist_ok=True)
file_name = os.path.join(args.result_dir, file_name)
with open(
file_name, mode="a+" if args.append_result else "w", encoding="utf-8"
) as outfile:

View File

@@ -10,14 +10,14 @@ from .plot_pareto import SweepPlotParetoArgs
from .plot_pareto import main as plot_pareto_main
from .serve import SweepServeArgs
from .serve import main as serve_main
from .serve_sla import SweepServeSLAArgs
from .serve_sla import main as serve_sla_main
from .serve_workload import SweepServeWorkloadArgs
from .serve_workload import main as serve_workload_main
from .startup import SweepStartupArgs
from .startup import main as startup_main
SUBCOMMANDS = (
(SweepServeArgs, serve_main),
(SweepServeSLAArgs, serve_sla_main),
(SweepServeWorkloadArgs, serve_workload_main),
(SweepStartupArgs, startup_main),
(SweepPlotArgs, plot_main),
(SweepPlotParetoArgs, plot_pareto_main),

View File

@@ -324,6 +324,11 @@ def _plot_fig(
df = filter_by.apply(df)
df = bin_by.apply(df)
if len(df) == 0:
print(f"No data to plot. Filters: {filter_by}")
print("[END FIGURE]")
return
# Sort by curve_by columns alphabetically for consistent legend ordering
if curve_by:
df = df.sort_values(by=curve_by)
@@ -494,7 +499,7 @@ class SweepPlotArgs:
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
output_dir = Path(args.OUTPUT_DIR)
output_dir = Path(args.EXPERIMENT_DIR)
if not output_dir.exists():
raise ValueError(f"No parameter sweep results under {output_dir}")
@@ -526,11 +531,9 @@ class SweepPlotArgs:
@classmethod
def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser.add_argument(
"OUTPUT_DIR",
"EXPERIMENT_DIR",
type=str,
default="results",
help="The directory containing the results to plot, "
"i.e., the `--output-dir` argument to the parameter sweep script.",
help="The directory containing the sweep results to plot.",
)
parser.add_argument(
"--fig-dir",
@@ -570,13 +573,13 @@ class SweepPlotArgs:
parser.add_argument(
"--var-x",
type=str,
default="request_throughput",
default="total_token_throughput",
help="The variable for the x-axis.",
)
parser.add_argument(
"--var-y",
type=str,
default="p99_ttft_ms",
default="median_ttft_ms",
help="The variable for the y-axis",
)
parser.add_argument(

View File

@@ -325,7 +325,7 @@ class SweepPlotParetoArgs:
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
output_dir = Path(args.OUTPUT_DIR)
output_dir = Path(args.EXPERIMENT_DIR)
if not output_dir.exists():
raise ValueError(f"No parameter sweep results under {output_dir}")
@@ -342,9 +342,8 @@ class SweepPlotParetoArgs:
@classmethod
def add_cli_args(cls, parser: argparse.ArgumentParser):
parser.add_argument(
"OUTPUT_DIR",
"EXPERIMENT_DIR",
type=str,
default="results",
help="The directory containing the sweep results to plot.",
)
parser.add_argument(

View File

@@ -4,6 +4,7 @@ import argparse
import contextlib
import json
import shlex
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
@@ -135,17 +136,21 @@ def run_benchmark(
def _get_comb_base_path(
output_dir: Path,
experiment_dir: Path,
serve_comb: ParameterSweepItem,
bench_comb: ParameterSweepItem,
*,
extra_parts: tuple[str, ...] = (),
):
parts = list[str]()
if serve_comb:
parts.extend(("SERVE-", serve_comb.name))
if bench_comb:
parts.extend(("BENCH-", bench_comb.name))
if extra_parts:
parts.extend(extra_parts)
return output_dir / sanitize_filename("-".join(parts))
return experiment_dir / sanitize_filename("-".join(parts))
def _get_comb_run_path(base_path: Path, run_number: int | None):
@@ -158,10 +163,10 @@ def _get_comb_run_path(base_path: Path, run_number: int | None):
def _comb_needs_server(
serve_comb: ParameterSweepItem,
bench_combs: ParameterSweep,
output_dir: Path,
experiment_dir: Path,
):
for bench_comb in bench_combs:
base_path = _get_comb_base_path(output_dir, serve_comb, bench_comb)
base_path = _get_comb_base_path(experiment_dir, serve_comb, bench_comb)
if not _get_comb_run_path(base_path, run_number=None).exists():
return True
@@ -175,11 +180,11 @@ def server_ctx(
show_stdout: bool,
serve_comb: ParameterSweepItem,
bench_params: ParameterSweep,
output_dir: Path,
experiment_dir: Path,
dry_run: bool,
server_ready_timeout: int = 300,
):
if not _comb_needs_server(serve_comb, bench_params, output_dir):
if not _comb_needs_server(serve_comb, bench_params, experiment_dir):
return contextlib.nullcontext()
return run_server(
@@ -211,10 +216,10 @@ def run_comb(
*,
serve_comb: ParameterSweepItem,
bench_comb: ParameterSweepItem,
link_vars: list[tuple[str, str]],
base_path: Path,
num_runs: int,
dry_run: bool,
link_vars: list[tuple[str, str]],
):
if not _comb_is_valid(serve_comb, bench_comb, link_vars):
return None
@@ -253,10 +258,10 @@ def run_combs(
server_ready_timeout: int,
serve_params: ParameterSweep,
bench_params: ParameterSweep,
output_dir: Path,
link_vars: list[tuple[str, str]],
experiment_dir: Path,
num_runs: int,
dry_run: bool,
link_vars: list[tuple[str, str]],
):
all_data = list[dict[str, object]]()
for serve_comb in serve_params:
@@ -266,22 +271,22 @@ def run_combs(
show_stdout=show_stdout,
serve_comb=serve_comb,
bench_params=bench_params,
output_dir=output_dir,
experiment_dir=experiment_dir,
dry_run=dry_run,
server_ready_timeout=server_ready_timeout,
) as server:
for bench_comb in bench_params:
base_path = _get_comb_base_path(output_dir, serve_comb, bench_comb)
base_path = _get_comb_base_path(experiment_dir, serve_comb, bench_comb)
comb_data = run_comb(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb,
link_vars=link_vars,
base_path=base_path,
num_runs=num_runs,
dry_run=dry_run,
link_vars=link_vars,
)
if comb_data is not None:
@@ -291,7 +296,7 @@ def run_combs(
return None
combined_df = pd.DataFrame.from_records(all_data)
combined_df.to_csv(output_dir / "summary.csv")
combined_df.to_csv(experiment_dir / "summary.csv")
return combined_df
@@ -305,11 +310,12 @@ class SweepServeArgs:
server_ready_timeout: int
serve_params: ParameterSweep
bench_params: ParameterSweep
link_vars: list[tuple[str, str]]
output_dir: Path
experiment_name: str
num_runs: int
dry_run: bool
resume: str | None
link_vars: list[tuple[str, str]]
resume: bool
parser_name: ClassVar[str] = "serve"
parser_help: ClassVar[str] = "Run vLLM server benchmark under multiple settings."
@@ -336,6 +342,11 @@ class SweepServeArgs:
link_vars = cls.parse_link_vars(args.link_vars)
if args.experiment_name:
experiment_name = args.experiment_name
else:
experiment_name = datetime.now().strftime("%Y%m%d_%H%M%S")
num_runs = args.num_runs
if num_runs < 1:
raise ValueError("`num_runs` should be at least 1.")
@@ -347,11 +358,12 @@ class SweepServeArgs:
show_stdout=args.show_stdout,
serve_params=serve_params,
bench_params=bench_params,
link_vars=link_vars,
output_dir=Path(args.output_dir),
experiment_name=experiment_name,
num_runs=num_runs,
dry_run=args.dry_run,
resume=args.resume,
link_vars=link_vars,
server_ready_timeout=args.server_ready_timeout,
)
@@ -388,6 +400,7 @@ class SweepServeArgs:
default=300,
help="Timeout in seconds to wait for the server to become ready.",
)
parser.add_argument(
"--serve-params",
type=str,
@@ -398,6 +411,16 @@ class SweepServeArgs:
"If both `serve_params` and `bench_params` are given, "
"this script will iterate over their Cartesian product.",
)
parser.add_argument(
"--link-vars",
type=str,
default="",
help=(
"Comma-separated list of linked variables between serve and bench, "
"e.g. max_num_seqs=max_concurrency,max_model_len=random_input_len"
),
)
parser.add_argument(
"--bench-params",
type=str,
@@ -413,7 +436,15 @@ class SweepServeArgs:
"--output-dir",
type=str,
default="results",
help="The directory to which results are written.",
help="The main directory to which results are written.",
)
parser.add_argument(
"-e",
"--experiment-name",
type=str,
default=None,
help="The name of this experiment (defaults to current timestamp). "
"Results will be stored under `output_dir/experiment_name`.",
)
parser.add_argument(
"--num-runs",
@@ -429,21 +460,10 @@ class SweepServeArgs:
)
parser.add_argument(
"--resume",
type=str,
default=None,
help="Set this to the name of a directory under `output_dir` (which is a "
"timestamp) to resume a previous execution of this script, i.e., only run "
"parameter combinations for which there are still no output files.",
)
parser.add_argument(
"--link-vars",
type=str,
default="",
help=(
"Comma-separated list of linked variables between serve and bench, "
"e.g. max_num_seqs=max_concurrency,max_model_len=random_input_len"
),
action="store_true",
help="Resume a previous execution of this script, i.e., only run "
"parameter combinations for which there are still no output files "
"under `output_dir/experiment_name`.",
)
return parser
@@ -458,33 +478,52 @@ class SweepServeArgs:
pairs.append((a.strip(), b.strip()))
return pairs
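# Assumed behavior of the parsing above (illustrative):
#   "max_num_seqs=max_concurrency,max_model_len=random_input_len"
#   -> [("max_num_seqs", "max_concurrency"), ("max_model_len", "random_input_len")]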
def resolve_experiment_dir(self) -> Path:
experiment_dir = self.output_dir / self.experiment_name
def run_main(args: SweepServeArgs):
timestamp = args.resume or datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = args.output_dir / timestamp
if self.resume:
if not experiment_dir.exists():
raise ValueError(f"Cannot resume from non-existent {experiment_dir=}")
else:
if experiment_dir.exists():
raise ValueError(f"Cannot overwrite existing {experiment_dir=}")
if args.resume and not output_dir.exists():
raise ValueError(f"Cannot resume from non-existent directory ({output_dir})")
return experiment_dir
@contextmanager
def run_ctx(self, experiment_dir: Path):
if self.dry_run:
yield
print(f"Experiment will be saved at: {experiment_dir}")
return
try:
yield
print(f"Experiment has been saved at: {experiment_dir}")
except BaseException as exc:
raise RuntimeError(
"The script was terminated early. Use `--resume` "
"to continue the script from its last checkpoint."
) from exc
def run_main(args: SweepServeArgs):
experiment_dir = args.resolve_experiment_dir()
with args.run_ctx(experiment_dir):
return run_combs(
serve_cmd=args.serve_cmd,
bench_cmd=args.bench_cmd,
link_vars=args.link_vars,
after_bench_cmd=args.after_bench_cmd,
show_stdout=args.show_stdout,
server_ready_timeout=args.server_ready_timeout,
serve_params=args.serve_params,
bench_params=args.bench_params,
output_dir=output_dir,
experiment_dir=experiment_dir,
num_runs=args.num_runs,
dry_run=args.dry_run,
link_vars=args.link_vars,
)
except BaseException as exc:
raise RuntimeError(
f"The script was terminated early. Use `--resume {timestamp}` "
f"to continue the script from its last checkpoint."
) from exc
def main(args: argparse.Namespace):

View File

@@ -1,305 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import math
from dataclasses import asdict, dataclass
from datetime import datetime
from pathlib import Path
from typing import ClassVar, Literal, get_args
import numpy as np
from typing_extensions import assert_never
from vllm.utils.import_utils import PlaceholderModule
from .param_sweep import ParameterSweep, ParameterSweepItem
from .serve import (
SweepServeArgs,
_get_comb_base_path,
run_comb,
server_ctx,
)
from .server import ServerProcess
try:
import pandas as pd
except ImportError:
pd = PlaceholderModule("pandas")
SLAVariable = Literal["request_rate", "max_concurrency"]
def _estimate_sla_value(run_data: dict[str, object], sla_variable: SLAVariable):
request_throughput = float(run_data["request_throughput"]) # type: ignore
if sla_variable == "request_rate":
return request_throughput
if sla_variable == "max_concurrency":
mean_latency_ms = float(run_data["mean_e2el_ms"]) # type: ignore
return request_throughput * mean_latency_ms / 1000
assert_never(sla_variable)
def _estimate_sla_avg(runs: list[dict[str, object]], sla_variable: SLAVariable):
return sum(_estimate_sla_value(run, sla_variable) for run in runs) / len(runs)
def run_comb_sla(
server: ServerProcess | None,
bench_cmd: list[str],
*,
serve_comb: ParameterSweepItem,
bench_comb: ParameterSweepItem,
output_dir: Path,
num_runs: int,
dry_run: bool,
link_vars: list[tuple[str, str]],
sla_variable: SLAVariable,
sla_value: int,
) -> list[dict[str, object]] | None:
bench_comb_sla = bench_comb | {sla_variable: sla_value}
return run_comb(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb_sla,
base_path=_get_comb_base_path(output_dir, serve_comb, bench_comb_sla),
num_runs=num_runs,
dry_run=dry_run,
link_vars=link_vars,
)
def explore_sla(
server: ServerProcess | None,
bench_cmd: list[str],
*,
serve_comb: ParameterSweepItem,
bench_comb: ParameterSweepItem,
sla_variable: SLAVariable,
sla_iters: int,
output_dir: Path,
num_runs: int,
dry_run: bool,
link_vars: list[tuple[str, str]],
):
print("[SLA START]")
print(f"Serve parameters: {serve_comb.as_text() or '(None)'}")
print(f"Bench parameters: {bench_comb.as_text() or '(None)'}")
print(f"Number of SLA iterations: {sla_iters}")
if sla_iters < 2:
raise ValueError("`sla_iters` should be at least 2")
serial_comb_data = run_comb_sla(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb,
output_dir=output_dir,
num_runs=num_runs,
dry_run=dry_run,
link_vars=link_vars,
sla_variable=sla_variable,
sla_value=1,
)
batch_comb_data = run_comb_sla(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb,
output_dir=output_dir,
num_runs=num_runs,
dry_run=dry_run,
link_vars=link_vars,
sla_variable=sla_variable,
sla_value=int(bench_comb.get("num_prompts", 1000)), # type: ignore
)
if serial_comb_data is None or batch_comb_data is None:
if dry_run:
print("Omitting intermediate SLA iterations.")
print("[SLA END]")
return
serial_sla_value = math.ceil(_estimate_sla_avg(serial_comb_data, sla_variable))
print(f"Serial inference: {sla_variable}={serial_sla_value}")
batch_sla_value = math.floor(_estimate_sla_avg(batch_comb_data, sla_variable))
print(f"Batch inference: {sla_variable}={batch_sla_value}")
# Avoid duplicated runs for intermediate values if the range between
# `serial_sla_value` and `batch_sla_value` is small
inter_sla_values = np.linspace(serial_sla_value, batch_sla_value, sla_iters)[1:-1]
inter_sla_values = sorted(set(map(round, inter_sla_values)))
inter_combs_data: list[dict[str, object]] = []
for inter_sla_value in inter_sla_values:
print(f"Exploring: {sla_variable}={inter_sla_value}")
inter_comb_data = run_comb_sla(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb,
output_dir=output_dir,
num_runs=num_runs,
dry_run=dry_run,
link_vars=link_vars,
sla_variable=sla_variable,
sla_value=inter_sla_value,
)
if inter_comb_data is not None:
inter_combs_data.extend(inter_comb_data)
print("[SLA END]")
return serial_comb_data + inter_combs_data + batch_comb_data
def run_slas(
serve_cmd: list[str],
bench_cmd: list[str],
after_bench_cmd: list[str],
*,
show_stdout: bool,
server_ready_timeout: int,
serve_params: ParameterSweep,
bench_params: ParameterSweep,
sla_variable: SLAVariable,
sla_iters: int,
output_dir: Path,
num_runs: int,
dry_run: bool,
link_vars: list[tuple[str, str]],
):
if any(bench_comb.has_param(sla_variable) for bench_comb in bench_params):
raise ValueError(
f"You should not override `{sla_variable}` in `bench_params` in SLA mode, "
"since it is supposed to be determined automatically."
)
all_data = list[dict[str, object]]()
for serve_comb in serve_params:
with server_ctx(
serve_cmd,
after_bench_cmd,
show_stdout=show_stdout,
server_ready_timeout=server_ready_timeout,
serve_comb=serve_comb,
bench_params=bench_params,
output_dir=output_dir,
dry_run=dry_run,
) as server:
for bench_comb in bench_params:
comb_data = explore_sla(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb,
sla_variable=sla_variable,
sla_iters=sla_iters,
output_dir=output_dir,
num_runs=num_runs,
dry_run=dry_run,
link_vars=link_vars,
)
if comb_data is not None:
all_data.extend(comb_data)
if dry_run:
return None
combined_df = pd.DataFrame.from_records(all_data)
combined_df.to_csv(output_dir / "summary.csv")
return combined_df
@dataclass
class SweepServeSLAArgs(SweepServeArgs):
sla_variable: SLAVariable
sla_iters: int
parser_name: ClassVar[str] = "serve_sla"
parser_help: ClassVar[str] = (
"Explore the latency-throughput space for determining SLAs."
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
# NOTE: Don't use super() as `from_cli_args` calls `cls()`
base_args = SweepServeArgs.from_cli_args(args)
return cls(
**asdict(base_args),
sla_variable=args.sla_variable,
sla_iters=args.sla_iters,
)
@classmethod
def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser = super().add_cli_args(parser)
sla_group = parser.add_argument_group("sla options")
sla_group.add_argument(
"--sla-variable",
type=str,
choices=get_args(SLAVariable),
default="request_rate",
help="The variable to adjust in each iteration.",
)
sla_group.add_argument(
"--sla-iters",
type=int,
default=10,
help="Number of iterations used to explore the latency-throughput space. "
"This includes the first two iterations used to interpolate the value of "
"`sla_variable` for remaining iterations.",
)
return parser
def run_main(args: SweepServeSLAArgs):
timestamp = args.resume or datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = args.output_dir / timestamp
if args.resume and not output_dir.exists():
raise ValueError(f"Cannot resume from non-existent directory ({output_dir})")
try:
return run_slas(
serve_cmd=args.serve_cmd,
bench_cmd=args.bench_cmd,
after_bench_cmd=args.after_bench_cmd,
show_stdout=args.show_stdout,
server_ready_timeout=args.server_ready_timeout,
serve_params=args.serve_params,
bench_params=args.bench_params,
sla_variable=args.sla_variable,
sla_iters=args.sla_iters,
output_dir=output_dir,
num_runs=args.num_runs,
dry_run=args.dry_run,
link_vars=args.link_vars,
)
except BaseException as exc:
raise RuntimeError(
f"The script was terminated early. Use `--resume {timestamp}` "
f"to continue the script from its last checkpoint."
) from exc
def main(args: argparse.Namespace):
run_main(SweepServeSLAArgs.from_cli_args(args))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description=SweepServeSLAArgs.parser_help)
SweepServeSLAArgs.add_cli_args(parser)
main(parser.parse_args())

View File

@@ -0,0 +1,328 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import math
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import ClassVar, Literal, get_args
import numpy as np
from typing_extensions import assert_never
from vllm.benchmarks.datasets import DEFAULT_NUM_PROMPTS
from vllm.utils.import_utils import PlaceholderModule
from .param_sweep import ParameterSweep, ParameterSweepItem
from .serve import (
SweepServeArgs,
_get_comb_base_path,
run_comb,
server_ctx,
)
from .server import ServerProcess
try:
import pandas as pd
except ImportError:
pd = PlaceholderModule("pandas")
WorkloadVariable = Literal["request_rate", "max_concurrency"]
def _estimate_workload_value(
run_data: dict[str, object],
workload_var: WorkloadVariable,
):
request_throughput = float(run_data["request_throughput"]) # type: ignore
if workload_var == "request_rate":
return request_throughput
if workload_var == "max_concurrency":
mean_latency_ms = float(run_data["mean_e2el_ms"]) # type: ignore
return request_throughput * mean_latency_ms / 1000
assert_never(workload_var)
def _estimate_workload_avg(
runs: list[dict[str, object]],
workload_var: WorkloadVariable,
):
total = sum(_estimate_workload_value(run, workload_var) for run in runs)
return total / len(runs)
def run_comb_workload(
server: ServerProcess | None,
bench_cmd: list[str],
*,
serve_comb: ParameterSweepItem,
bench_comb: ParameterSweepItem,
link_vars: list[tuple[str, str]],
experiment_dir: Path,
num_runs: int,
dry_run: bool,
workload_var: WorkloadVariable,
workload_value: int,
) -> list[dict[str, object]] | None:
bench_comb_workload = bench_comb | {workload_var: workload_value}
return run_comb(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb_workload,
link_vars=link_vars,
base_path=_get_comb_base_path(
experiment_dir,
serve_comb,
bench_comb,
extra_parts=("WL-", f"{workload_var}={workload_value}"),
),
num_runs=num_runs,
dry_run=dry_run,
)
def explore_comb_workloads(
server: ServerProcess | None,
bench_cmd: list[str],
*,
serve_comb: ParameterSweepItem,
bench_comb: ParameterSweepItem,
link_vars: list[tuple[str, str]],
workload_var: WorkloadVariable,
workload_iters: int,
experiment_dir: Path,
num_runs: int,
dry_run: bool,
):
print("[WL START]")
print(f"Serve parameters: {serve_comb.as_text() or '(None)'}")
print(f"Bench parameters: {bench_comb.as_text() or '(None)'}")
print(f"Number of workload iterations: {workload_iters}")
if workload_iters < 2:
raise ValueError("`workload_iters` should be at least 2")
dataset_size = DEFAULT_NUM_PROMPTS
if "num_prompts" in bench_comb:
dataset_size = int(bench_comb["num_prompts"]) # type: ignore
else:
for i, arg in enumerate(bench_cmd):
if arg == "--num-prompts" and i + 1 < len(bench_cmd):
dataset_size = int(bench_cmd[i + 1])
break
elif arg.startswith("--num-prompts="):
dataset_size = int(arg.split("=", 1)[1])
break
print(f"Dataset size: {dataset_size}")
serial_workload_data = run_comb_workload(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb | {"max_concurrency": 1},
link_vars=link_vars,
experiment_dir=experiment_dir,
num_runs=num_runs,
dry_run=dry_run,
workload_var=workload_var,
workload_value=1,
)
batch_workload_data = run_comb_workload(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb | {"max_concurrency": dataset_size},
link_vars=link_vars,
experiment_dir=experiment_dir,
num_runs=num_runs,
dry_run=dry_run,
workload_var=workload_var,
workload_value=dataset_size,
)
if serial_workload_data is None or batch_workload_data is None:
if dry_run:
print("Omitting intermediate Workload iterations.")
print("[WL END]")
return
serial_workload_value = math.ceil(
_estimate_workload_avg(serial_workload_data, workload_var)
)
print(f"Serial inference: {workload_var}={serial_workload_value}")
batch_workload_value = math.floor(
_estimate_workload_avg(batch_workload_data, workload_var)
)
print(f"Batch inference: {workload_var}={batch_workload_value}")
# Avoid duplicated runs for intermediate values if the range between
# `serial_workload_value` and `batch_workload_value` is small
inter_workload_values = np.linspace(
serial_workload_value, batch_workload_value, workload_iters
)[1:-1]
inter_workload_values = sorted(set(map(round, inter_workload_values)))
inter_workloads_data: list[dict[str, object]] = []
for inter_workload_value in inter_workload_values:
print(f"Exploring: {workload_var}={inter_workload_value}")
inter_workload_data = run_comb_workload(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb,
link_vars=link_vars,
experiment_dir=experiment_dir,
num_runs=num_runs,
dry_run=dry_run,
workload_var=workload_var,
workload_value=inter_workload_value,
)
if inter_workload_data is not None:
inter_workloads_data.extend(inter_workload_data)
print("[WL END]")
return serial_workload_data + inter_workloads_data + batch_workload_data
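# A small numeric sketch of the interpolation step above (endpoint values made up):
_serial, _batch, _iters = 3, 120, 6
_inter = np.linspace(_serial, _batch, _iters)[1:-1]  # [26.4, 49.8, 73.2, 96.6]
assert sorted({round(float(v)) for v in _inter}) == [26, 50, 73, 97]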
def explore_combs_workloads(
serve_cmd: list[str],
bench_cmd: list[str],
after_bench_cmd: list[str],
*,
show_stdout: bool,
server_ready_timeout: int,
serve_params: ParameterSweep,
bench_params: ParameterSweep,
link_vars: list[tuple[str, str]],
workload_var: WorkloadVariable,
workload_iters: int,
experiment_dir: Path,
num_runs: int,
dry_run: bool,
):
if any(bench_comb.has_param(workload_var) for bench_comb in bench_params):
raise ValueError(
f"You should not override `{workload_var}` in `bench_params` "
"since it is supposed to be explored automatically."
)
all_data = list[dict[str, object]]()
for serve_comb in serve_params:
with server_ctx(
serve_cmd,
after_bench_cmd,
show_stdout=show_stdout,
server_ready_timeout=server_ready_timeout,
serve_comb=serve_comb,
bench_params=bench_params,
experiment_dir=experiment_dir,
dry_run=dry_run,
) as server:
for bench_comb in bench_params:
comb_data = explore_comb_workloads(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb,
link_vars=link_vars,
workload_var=workload_var,
workload_iters=workload_iters,
experiment_dir=experiment_dir,
num_runs=num_runs,
dry_run=dry_run,
)
if comb_data is not None:
all_data.extend(comb_data)
if dry_run:
return None
combined_df = pd.DataFrame.from_records(all_data)
combined_df.to_csv(experiment_dir / "summary.csv")
return combined_df
@dataclass
class SweepServeWorkloadArgs(SweepServeArgs):
workload_var: WorkloadVariable
workload_iters: int
parser_name: ClassVar[str] = "serve_workload"
parser_help: ClassVar[str] = (
"Explore the latency-throughput tradeoff for different workload levels."
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
# NOTE: Don't use super() as `from_cli_args` calls `cls()`
base_args = SweepServeArgs.from_cli_args(args)
return cls(
**asdict(base_args),
workload_var=args.workload_var,
workload_iters=args.workload_iters,
)
@classmethod
def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser = super().add_cli_args(parser)
workload_group = parser.add_argument_group("workload options")
workload_group.add_argument(
"--workload-var",
type=str,
choices=get_args(WorkloadVariable),
default="request_rate",
help="The variable to adjust in each iteration.",
)
workload_group.add_argument(
"--workload-iters",
type=int,
default=10,
help="Number of workload levels to explore. "
"This includes the first two iterations used to interpolate the value of "
"`workload_var` for remaining iterations.",
)
return parser
def run_main(args: SweepServeWorkloadArgs):
experiment_dir = args.resolve_experiment_dir()
with args.run_ctx(experiment_dir):
return explore_combs_workloads(
serve_cmd=args.serve_cmd,
bench_cmd=args.bench_cmd,
after_bench_cmd=args.after_bench_cmd,
show_stdout=args.show_stdout,
server_ready_timeout=args.server_ready_timeout,
serve_params=args.serve_params,
bench_params=args.bench_params,
link_vars=args.link_vars,
workload_var=args.workload_var,
workload_iters=args.workload_iters,
experiment_dir=experiment_dir,
num_runs=args.num_runs,
dry_run=args.dry_run,
)
def main(args: argparse.Namespace):
run_main(SweepServeWorkloadArgs.from_cli_args(args))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description=SweepServeWorkloadArgs.parser_help)
SweepServeWorkloadArgs.add_cli_args(parser)
main(parser.parse_args())

View File

@@ -4,6 +4,7 @@ import argparse
import json
import shlex
import subprocess
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
from functools import lru_cache
@@ -111,7 +112,7 @@ def _apply_output_json(cmd: list[str], output_path: Path) -> list[str]:
def _get_comb_base_path(
output_dir: Path,
experiment_dir: Path,
serve_comb: ParameterSweepItem,
startup_comb: ParameterSweepItem,
) -> Path:
@@ -120,7 +121,8 @@ def _get_comb_base_path(
parts.extend(("SERVE-", serve_comb.name))
if startup_comb:
parts.extend(("STARTUP-", startup_comb.name))
return output_dir / sanitize_filename("-".join(parts))
return experiment_dir / sanitize_filename("-".join(parts))
def _get_comb_run_path(base_path: Path, run_number: int | None) -> Path:
@@ -225,7 +227,7 @@ def run_combs(
*,
serve_params: ParameterSweep,
startup_params: ParameterSweep,
output_dir: Path,
experiment_dir: Path,
num_runs: int,
show_stdout: bool,
dry_run: bool,
@@ -233,7 +235,7 @@ def run_combs(
all_data = list[dict[str, object]]()
for serve_comb in serve_params:
for startup_comb in startup_params:
base_path = _get_comb_base_path(output_dir, serve_comb, startup_comb)
base_path = _get_comb_base_path(experiment_dir, serve_comb, startup_comb)
comb_data = run_comb(
startup_cmd,
serve_comb=serve_comb,
@@ -250,7 +252,7 @@ def run_combs(
return None
combined_df = pd.DataFrame.from_records(all_data)
combined_df.to_csv(output_dir / "summary.csv")
combined_df.to_csv(experiment_dir / "summary.csv")
return combined_df
@@ -260,11 +262,11 @@ class SweepStartupArgs:
serve_params: ParameterSweep
startup_params: ParameterSweep
output_dir: Path
experiment_name: str
num_runs: int
show_stdout: bool
dry_run: bool
resume: str | None
strict_params: bool
resume: bool
parser_name: ClassVar[str] = "startup"
parser_help: ClassVar[str] = (
@@ -286,13 +288,19 @@ class SweepStartupArgs:
startup_params = ParameterSweep.from_records([{}])
supported = _get_supported_startup_keys()
strict_params = args.strict_params
serve_params = _filter_params(
serve_params, supported=supported, strict=args.strict_params
serve_params, supported=supported, strict=strict_params
)
startup_params = _filter_params(
startup_params, supported=supported, strict=args.strict_params
startup_params, supported=supported, strict=strict_params
)
if args.experiment_name:
experiment_name = args.experiment_name
else:
experiment_name = datetime.now().strftime("%Y%m%d_%H%M%S")
if args.num_runs < 1:
raise ValueError("`num_runs` should be at least 1.")
@@ -301,11 +309,11 @@ class SweepStartupArgs:
serve_params=serve_params,
startup_params=startup_params,
output_dir=Path(args.output_dir),
experiment_name=experiment_name,
num_runs=args.num_runs,
show_stdout=args.show_stdout,
dry_run=args.dry_run,
resume=args.resume,
strict_params=args.strict_params,
)
@classmethod
@@ -316,6 +324,7 @@ class SweepStartupArgs:
default="vllm bench startup",
help="The command used to run the startup benchmark.",
)
parser.add_argument(
"--serve-params",
type=str,
@@ -331,12 +340,27 @@ class SweepStartupArgs:
help="Path to JSON file containing parameter combinations "
"for the `vllm bench startup` command.",
)
parser.add_argument(
"--strict-params",
action="store_true",
help="If set, unknown parameters in sweep files raise an error "
"instead of being ignored.",
)
parser.add_argument(
"-o",
"--output-dir",
type=str,
default="results",
help="The directory to which results are written.",
help="The main directory to which results are written.",
)
parser.add_argument(
"-e",
"--experiment-name",
type=str,
default=None,
help="The name of this experiment (defaults to current timestamp). "
"Results will be stored under `output_dir/experiment_name`.",
)
parser.add_argument(
"--num-runs",
@@ -357,43 +381,56 @@ class SweepStartupArgs:
)
parser.add_argument(
"--resume",
type=str,
default=None,
help="Set this to the name of a directory under `output_dir` (which is a "
"timestamp) to resume a previous execution of this script, i.e., only run "
"parameter combinations for which there are still no output files.",
)
parser.add_argument(
"--strict-params",
action="store_true",
help="If set, unknown parameters in sweep files raise an error "
"instead of being ignored.",
help="Resume a previous execution of this script, i.e., only run "
"parameter combinations for which there are still no output files "
"under `output_dir/experiment_name`.",
)
return parser
def resolve_experiment_dir(self) -> Path:
experiment_dir = self.output_dir / self.experiment_name
if self.resume:
if not experiment_dir.exists():
raise ValueError(f"Cannot resume from non-existent {experiment_dir=}")
else:
if experiment_dir.exists():
raise ValueError(f"Cannot overwrite existing {experiment_dir=}")
return experiment_dir
@contextmanager
def run_ctx(self, experiment_dir: Path):
if self.dry_run:
yield
print(f"Experiment will be saved at: {experiment_dir}")
return
try:
yield
print(f"Experiment has been saved at: {experiment_dir}")
except BaseException as exc:
raise RuntimeError(
"The script was terminated early. Use `--resume` "
"to continue the script from its last checkpoint."
) from exc
def run_main(args: SweepStartupArgs):
timestamp = args.resume or datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = args.output_dir / timestamp
experiment_dir = args.resolve_experiment_dir()
if args.resume and not output_dir.exists():
raise ValueError(f"Cannot resume from non-existent directory ({output_dir})")
try:
with args.run_ctx(experiment_dir):
return run_combs(
startup_cmd=args.startup_cmd,
serve_params=args.serve_params,
startup_params=args.startup_params,
output_dir=output_dir,
experiment_dir=experiment_dir,
num_runs=args.num_runs,
show_stdout=args.show_stdout,
dry_run=args.dry_run,
)
except BaseException as exc:
raise RuntimeError(
f"The script was terminated early. Use `--resume {timestamp}` "
f"to continue the script from its last checkpoint."
) from exc
def main(args: argparse.Namespace):

View File

@@ -282,49 +282,6 @@ class CompilerManager:
maybe_key += f"{compile_range.start}_{compile_range.end}"
maybe_key += f"_subgraph_{graph_index}"
with self.compile_context(compile_range):
# There is a compilation time optimization here.
#
# If the (input metadata, graph, compiler config) are the same, then
# we want to avoid compiling the same artifact again. If we didn't
# do this optimization, the backend compilation (InductorAdaptor or
# InductorStandaloneAdaptor)
# is able to cache hit and produce an artifact faster if it was
# already created, but it is still a duplicate artifact that
# requires unnecessary things e.g. disk IO.
#
# The optimization is: If the backend compilation cache hits,
# then do an early return from the backend compilation and look up
# which of the previous in-memory artifacts we created to reuse.
#
# We implemented this by monkey-patching torch (torch does not
# easily expose the cache_key function), but in the future torch
# should expose the cache_key function that we can just call
# directly before invoking backend compilation.
cache_key = None
orig = torch._functorch._aot_autograd.autograd_cache.autograd_cache_key
def autograd_cache_key(*args, **kwargs):
result = orig(*args, **kwargs)
if result is None:
return None
nonlocal cache_key
cache_key = result[0]
if cache_key in self.loaded_artifacts:
raise StopCompiling()
return result
from unittest.mock import patch
with (
# Graphs that are isomorphic (different node names but same
# structure) should be treated as the same.
torch._functorch.config.patch(autograd_cache_normalize_inputs=True),
patch(
"torch._functorch._aot_autograd.autograd_cache.autograd_cache_key",
autograd_cache_key,
),
):
try:
compiled_graph, handle = self.compiler.compile(
graph,
example_inputs,
@@ -332,11 +289,6 @@ class CompilerManager:
compile_range,
maybe_key,
)
except StopCompiling:
assert cache_key is not None
return self.loaded_artifacts[cache_key]
if cache_key is not None and compiled_graph is not None:
self.loaded_artifacts[cache_key] = compiled_graph
assert compiled_graph is not None, "Failed to compile the graph"
@@ -497,7 +449,7 @@ def wrap_with_cudagraph_if_needed(
# it from the FULL cudagraph runtime mode, no matter whether it
# is wrapped on a full or piecewise fx graph.
return static_graph_wrapper_class(
runnable=piecewise_backend,
runnable=piecewise_backend.graph.forward,
vllm_config=vllm_config,
runtime_mode=CUDAGraphMode.PIECEWISE,
cudagraph_options=CUDAGraphOptions(
@@ -780,7 +732,7 @@ class VllmBackend:
return standalone_compile_artifacts, sym_shape_indices_map, returns_tuple_map
def configure_post_pass(self) -> None:
# self.pass_manager.configure(self.vllm_config)
self.pass_manager.configure(self.vllm_config)
# Post-grad custom passes are run using the post_grad_custom_post_pass
# hook. If a pass for that hook exists, add it to the pass manager.
@@ -846,7 +798,7 @@ class VllmBackend:
),
)
def __call__(self, graph: fx.GraphModule, example_inputs: Sequence[Any]) -> Any:
def __call__(self, graph: fx.GraphModule, example_inputs: Sequence[Any], **kwargs) -> Any:
from .caching import (
VllmSerializableFunction,
)
@@ -988,7 +940,7 @@ class VllmBackend:
assert not self._called, "VllmBackend can only be called once"
self.graph = graph
self.configure_post_pass()
# self.configure_post_pass()
if self.compilation_config.use_inductor_graph_partition:
# Let Inductor decide partitioning; avoid FX-level pre-splitting.

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import hashlib
import inspect
import os
@@ -144,6 +145,18 @@ class StandaloneCompiledArtifacts:
self.loaded_submodule_store = {}
@contextlib.contextmanager
def patch_pytree_map_over_slice():
pytree._private_register_pytree_node(
slice, lambda x: ([x.start, x.stop, x.step], None), lambda x, c: slice(*x)
)
try:
yield
finally:
pytree._deregister_pytree_node(slice)
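# A minimal round-trip sketch of what the registration above enables (assumes
# `pytree` refers to torch.utils._pytree, as used in the helper):
with patch_pytree_map_over_slice():
    _leaves, _spec = pytree.tree_flatten({"window": slice(0, 128, 1)})
    assert _leaves == [0, 128, 1]
    assert pytree.tree_unflatten(_leaves, _spec) == {"window": slice(0, 128, 1)}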
class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
"""
A wrapper around a compiled function by vllm. It will forward the tensor
@@ -235,7 +248,10 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
lambda inp: torch.empty_like(inp, device="meta"),
state["example_inputs"],
)
with patch.object(GraphPickler, "reducer_override", _graph_reducer_override):
with (
patch.object(GraphPickler, "reducer_override", _graph_reducer_override),
patch_pytree_map_over_slice(),
):
state["graph_module"] = GraphPickler.dumps(
state["graph_module"], Options(ops_filter=None)
)
@@ -261,6 +277,7 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
state = pickle.loads(data)
fake_mode = FakeTensorMode(shape_env=ShapeEnv())
with patch_pytree_map_over_slice():
state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode)
state["graph_module"].recompile()
state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode)

View File

@@ -184,6 +184,47 @@ def is_compile_cache_enabled(
)
def _patch_standalone_compile_atomic_save() -> None:
"""Backport of pytorch/pytorch#162432 for torch < 2.10.0.
Patches CompiledArtifact.save() to use write_atomic for binary format,
preventing corrupt cache files when multiple processes compile
concurrently.
"""
from torch._inductor.codecache import write_atomic
from torch._inductor.standalone_compile import CompiledArtifact as cls
if getattr(cls.save, "_vllm_patched", False):
return
original_save = cls.save
def _save(
self: Any, *, path: str, format: Literal["binary", "unpacked"] = "binary"
) -> None:
if format != "binary":
return original_save(self, path=path, format=format)
from torch._dynamo.utils import dynamo_timed
from torch._inductor.codecache import torch_key
from torch.utils._appending_byte_serializer import BytesWriter
with dynamo_timed("CompiledArtifact.save"):
assert self._artifacts is not None
artifact_bytes, cache_info = self._artifacts
assert len(cache_info.aot_autograd_artifacts) == 1, cache_info
key = cache_info.aot_autograd_artifacts[0]
assert not os.path.isdir(path)
writer = BytesWriter()
writer.write_bytes(torch_key())
writer.write_str(key)
writer.write_bytes(artifact_bytes)
write_atomic(path, writer.to_bytes())
_save._vllm_patched = True # type: ignore[attr-defined]
cls.save = _save # type: ignore[assignment]
logger.debug("Patched %s.save for atomic writes (torch < 2.10)", cls.__name__)
class InductorStandaloneAdaptor(CompilerInterface):
"""
The adaptor for the Inductor compiler.
@@ -197,6 +238,8 @@ class InductorStandaloneAdaptor(CompilerInterface):
name = "inductor_standalone"
def __init__(self, save_format: Literal["binary", "unpacked"]) -> None:
if not is_torch_equal_or_newer("2.10.0"):
_patch_standalone_compile_atomic_save()
self.save_format = save_format
def compute_hash(self, vllm_config: VllmConfig) -> str:
@@ -224,7 +267,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
if compiler_config is not None:
current_config.update(compiler_config)
set_inductor_config(current_config, compile_range)
set_functorch_config()
# set_functorch_config()
if compile_range.is_single_size():
dynamic_shapes = "from_example_inputs"
@@ -325,6 +368,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
inductor_compiled_graph = torch._inductor.CompiledArtifact.load(
path=path, format=self.save_format
)
compilation_counter.num_compiled_artifacts_loaded += 1
from torch._inductor.compile_fx import graph_returns_tuple
returns_tuple = graph_returns_tuple(graph)
@@ -395,7 +439,7 @@ class InductorAdaptor(CompilerInterface):
current_config["fx_graph_remote_cache"] = False
set_inductor_config(current_config, compile_range)
set_functorch_config()
# set_functorch_config()
# inductor can inplace modify the graph, so we need to copy it
# see https://github.com/pytorch/pytorch/issues/138980

View File

@@ -29,6 +29,8 @@ class CompilationCounter:
num_cache_entries_updated: int = 0
# The number of standalone_compile compiled artifacts saved
num_compiled_artifacts_saved: int = 0
# The number of standalone_compile compiled artifacts loaded from cache
num_compiled_artifacts_loaded: int = 0
# Number of times a model was loaded with CompilationMode.STOCK_TORCH_COMPILE
stock_torch_compile_count: int = 0

View File

@@ -21,6 +21,7 @@ from vllm.model_executor.offloader.base import get_offloader
from vllm.platforms import current_platform
from vllm.utils.torch_utils import current_stream, weak_ref_tensors
from vllm.sequence import IntermediateTensors
logger = init_logger(__name__)
@@ -298,12 +299,10 @@ class CUDAGraphWrapper:
# the last graph in piecewise cudagraph mode, because
# the output of the last graph will not be used by
# any other cuda graph.
# output = weak_ref_tensors(output)
output = self.weak_ref_tensors_with_intermediate(output)
# here we always use weak ref for the output
# to save memory
# entry.output = weak_ref_tensors(output)
entry.output = self.weak_ref_tensors_with_intermediate(output)
entry.cudagraph = cudagraph

View File

@@ -53,7 +53,7 @@ class GEMMReduceScatterPattern(BasePattern):
gemm_rs = torch.ops.symm_mem.fused_matmul_reduce_scatter(
mul,
mm_weight,
"avg",
"sum",
scatter_dim=0,
group_name=self.tp.device_group.group_name,
)
@@ -150,7 +150,7 @@ class ScaledMMReduceScatterPattern(BasePattern):
mat2,
scale_a,
scale_b,
"avg",
"sum",
scatter_dim, # orig_scatter_dim
scatter_dim, # scatter_dim_after_maybe_reshape
self.tp.device_group.group_name,
@@ -285,7 +285,7 @@ class CutlassScaledMMReduceScatterPattern(BasePattern):
mat2,
scale_a,
scale_b,
"avg",
"sum",
scatter_dim, # orig_scatter_dim
scatter_dim, # scatter_dim_after_maybe_reshape
self.tp.device_group.group_name,

View File

@@ -5,7 +5,6 @@ import torch
import torch._inductor.pattern_matcher as pm
from torch import fx
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._ops import OpOverload
import vllm.model_executor.layers.quantization.utils.fp8_utils # noqa: F401
from vllm._aiter_ops import rocm_aiter_ops
@@ -15,6 +14,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
QuantKey,
ScaleDesc,
kFp8Dynamic128Sym,
)
from vllm.platforms import current_platform
@@ -312,7 +312,9 @@ class RocmAiterRMSNormQuantFusionPass(VllmPatternMatcherPass):
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
logger.debug("Replaced %s patterns", self.matched_count)
logger.debug(
"%s Replaced %s patterns", self.__class__.__name__, self.matched_count
)
def uuid(self) -> str:
fusion_patterns = [
@@ -332,9 +334,11 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
FUSED_SILU_MUL_QUANT_OP = rocm_aiter_ops.get_act_mul_fused_fp8_group_quant_op()
def __init__(self, quant_op: OpOverload) -> None:
def __init__(self) -> None:
self.silu_and_mul_matcher = MatcherSiluAndMul()
self.quant_op = quant_op
self.quant_matcher = MatcherQuantFP8(
quant_key=kFp8Dynamic128Sym, match_rocm_aiter=True
)
def get_inputs(self) -> list[torch.Tensor]:
return [
@@ -346,7 +350,7 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
input: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
at1 = self.silu_and_mul_matcher(input)
at2 = self.quant_op(at1, 128)
at2 = self.quant_matcher(at1)
return at2[0], at2[1]
def replacement(
@@ -370,11 +374,6 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
"""
AITER_GROUP_FP8_QUANT_OP = rocm_aiter_ops.get_group_quant_op()
TRITON_GROUP_FP8_QUANT_OP = torch.ops.vllm.triton_per_token_group_quant_fp8.default
QUANT_OPS = [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]
@enable_fake_mode
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
@@ -383,8 +382,7 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
pass_name="rocm_aiter_silu_mul_fp8_group_quant_fusion_pass"
)
for quant_op in self.QUANT_OPS:
AiterSiluMulFp8GroupQuantPattern(quant_op).register(self.patterns)
AiterSiluMulFp8GroupQuantPattern().register(self.patterns)
self.dump_patterns(config, self.patterns)

View File

@@ -18,7 +18,6 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
from ..inductor_pass import enable_fake_mode
from ..utility.noop_elimination import NoOpEliminationPass
@@ -215,9 +214,6 @@ class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
)
FP8_DTYPE = current_platform.fp8_dtype()
class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
def __init__(
self,

View File

@@ -37,6 +37,14 @@ class FixFunctionalizationPass(VllmInductorPass):
self.nodes_to_remove: list[torch.fx.Node] = []
count = 0
rope_targets = [torch.ops._C.rotary_embedding.default]
if hasattr(torch.ops.vllm, "rocm_aiter_triton_rotary_embedding"):
rope_targets.append(
torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default
)
for node in graph.nodes:
if not is_func(node, auto_functionalized):
continue # Avoid deep if-elif nesting
@@ -44,7 +52,7 @@ class FixFunctionalizationPass(VllmInductorPass):
kwargs = node.kwargs
at_target = node.args[0]
if at_target == torch.ops._C.rotary_embedding.default:
if at_target in rope_targets:
query = kwargs["query"]
key = kwargs["key"]
getitem_nodes = self.getitem_users(node)

View File

@@ -298,9 +298,9 @@ class PiecewiseBackend:
else list(args)
)
with (
torch._functorch.config.patch("bundled_autograd_cache", True),
):
# with (
# torch._functorch.config.patch("bundled_autograd_cache", True),
# ):
range_entry.runnable = self.vllm_backend.compiler_manager.compile(
self.graph,
args_list,

View File

@@ -319,3 +319,52 @@ class TorchCompileWithNoGuardsWrapper:
yield
finally:
self.__class__.forward.__code__ = original
def reset_compile_wrapper(model: torch.nn.Module) -> None:
"""
Clean up compiled model and captured CUDA graphs for elastic EP.
"""
if not isinstance(model, TorchCompileWithNoGuardsWrapper) and hasattr(
model, "model"
):
model = model.model
if not isinstance(model, TorchCompileWithNoGuardsWrapper):
return
# model.do_not_compile is set by the @support_torch_compile decorator
if hasattr(model, "do_not_compile") and model.do_not_compile:
return
from vllm.compilation.counter import compilation_counter
# reset the compilation counter
compilation_counter.num_models_seen = 0
compilation_counter.num_graphs_seen = 0
compilation_counter.num_piecewise_graphs_seen = 0
compilation_counter.num_piecewise_capturable_graphs_seen = 0
compilation_counter.num_backend_compilations = 0
compilation_counter.num_gpu_runner_capture_triggers = 0
compilation_counter.num_cudagraph_captured = 0
compilation_counter.num_inductor_compiles = 0
compilation_counter.num_eager_compiles = 0
compilation_counter.num_cache_entries_updated = 0
compilation_counter.num_compiled_artifacts_saved = 0
compilation_counter.stock_torch_compile_count = 0
# Clear the AOT compiled function so the model is forced to
# recompile on the next call. Without this, decorators.py
# __call__ uses the stale aot_compiled_fn whose torchinductor
# kernels have old parameters (expert_map size for example)
# baked in as compile-time constants.
if hasattr(model, "aot_compiled_fn"):
model.aot_compiled_fn = None
if hasattr(model, "was_aot_compile_fn_loaded_from_disk"):
model.was_aot_compile_fn_loaded_from_disk = False
# Reset the cache_dir so VllmBackend recomputes the hash
# (data_parallel_size changed, so the config hash differs).
compilation_config = model.vllm_config.compilation_config
compilation_config.cache_dir = ""
compilation_config.local_cache_dir = ""
model.__class__.forward.__code__ = model.original_code_object()
TorchCompileWithNoGuardsWrapper.__init__(model)

View File

@@ -16,8 +16,8 @@ class AttentionConfig:
backend: AttentionBackendEnum | None = None
"""Attention backend to use. If None, will be selected automatically."""
flash_attn_version: Literal[2, 3] | None = None
"""Force vllm to use a specific flash-attention version (2 or 3).
flash_attn_version: Literal[2, 3, 4] | None = None
"""Force vllm to use a specific flash-attention version (2, 3, or 4).
Only valid when using the flash-attention backend."""
use_prefill_decode_attention: bool = False

View File

@@ -87,8 +87,15 @@ class CUDAGraphMode(enum.Enum):
def separate_routine(self) -> bool:
return isinstance(self.value, tuple)
def valid_runtime_modes(self) -> bool:
return self in [CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL]
def decode_use_graph(self) -> bool:
return self.decode_mode() == CUDAGraphMode.FULL
@classmethod
def valid_runtime_modes(cls) -> frozenset["CUDAGraphMode"]:
return frozenset({cls.NONE, cls.PIECEWISE, cls.FULL})
def is_valid_runtime_mode(self) -> bool:
return self in CUDAGraphMode.valid_runtime_modes()
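# e.g. CUDAGraphMode.PIECEWISE.is_valid_runtime_mode() is True, while a composite
# mode such as FULL_DECODE_ONLY is not itself a valid runtime mode and is resolved
# via decode_mode() / separate_routine().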
def __str__(self) -> str:
return self.name
@@ -385,7 +392,7 @@ class CompilationConfig:
Please use mode. Currently all levels are mapped to mode.
"""
# Top-level Compilation control
mode: CompilationMode = Field(default=None)
mode: CompilationMode = Field(default=CompilationMode.NONE)
"""The compilation approach used for torch.compile-based compilation of the
model.
@@ -503,7 +510,7 @@ class CompilationConfig:
constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`."""
# CudaGraph compilation
cudagraph_mode: CUDAGraphMode = Field(default=None)
cudagraph_mode: CUDAGraphMode = Field(default=CUDAGraphMode.FULL_DECODE_ONLY)
"""
The mode of the cudagraph:
@@ -1003,6 +1010,7 @@ class CompilationConfig:
# https://github.com/vllm-project/vllm/issues/33267
if not self.use_inductor_graph_partition:
self.splitting_ops.append("vllm::unified_kv_cache_update")
self.splitting_ops.append("vllm::unified_mla_kv_cache_update")
elif len(self.splitting_ops) == 0:
if (
@@ -1045,7 +1053,7 @@ class CompilationConfig:
"are optimized for prefill and are incompatible with CUDA Graphs. "
"In order to use CUDA Graphs for decode-optimized workloads, "
"use --all2all-backend with another option, such as "
"deepep_low_latency, pplx, or allgather_reducescatter."
"deepep_low_latency or allgather_reducescatter."
)
self.cudagraph_mode = CUDAGraphMode.NONE

View File

@@ -50,8 +50,6 @@ from vllm.transformers_utils.utils import maybe_model_redirect
from vllm.utils.import_utils import LazyLoader
from vllm.v1.attention.backends.registry import AttentionBackendEnum
import os
if TYPE_CHECKING:
from transformers import PretrainedConfig
@@ -128,6 +126,7 @@ class ModelConfig:
- "slow" will always use the slow tokenizer.\n
- "mistral" will always use the tokenizer from `mistral_common`.\n
- "deepseek_v32" will always use the tokenizer from `deepseek_v32`.\n
- "qwen_vl" will always use the tokenizer from `qwen_vl`.\n
- Other custom values can be supported via plugins."""
trust_remote_code: bool = False
"""Trust remote code (e.g., from HuggingFace) when downloading the model
@@ -463,8 +462,6 @@ class ModelConfig:
self.maybe_pull_model_tokenizer_for_runai(self.model, self.tokenizer)
from vllm.platforms import current_platform
if self.override_attention_dtype is not None and not current_platform.is_rocm():
warnings.warn(
"override-attention-dtype is set but not using ROCm platform",
@@ -474,9 +471,8 @@ class ModelConfig:
if self.enable_sleep_mode and not current_platform.is_sleep_mode_available():
raise ValueError("Sleep mode is not supported on current platform.")
temp_hf_config_path = os.environ.get("CUSTOM_QUANT_CONFIG", None)
hf_config = get_config(
temp_hf_config_path or self.hf_config_path or self.model,
self.hf_config_path or self.model,
self.trust_remote_code,
self.revision,
self.code_revision,
@@ -622,6 +618,16 @@ class ModelConfig:
self._try_verify_and_update_model_config()
self._verify_quantization()
self._verify_cuda_graph()
import os
enforce_cuda_graph = os.environ.get("VLLM_ENFORCE_CUDA_GRAPH", None)
if enforce_cuda_graph is not None and enforce_cuda_graph in ["1", "y", "Y"]:
self.enforce_eager = False
else:
self.enforce_eager = True
logger.warning_once(
"Please export VLLM_ENFORCE_CUDA_GRAPH=1 to enable cuda graph. "
"For now, cuda graph is not used and --enforce-eager is enabled; "
"we are working toward making cuda graph the default mode.")
self._verify_bnb_config()
def get_model_arch_config(
@@ -886,6 +892,7 @@ class ModelConfig:
"modelopt",
"modelopt_fp4",
"modelopt_mxfp8",
"modelopt_mixed",
"petit_nvfp4",
# Ensure heavy backends are probed last to avoid unnecessary
# imports during override detection (e.g., MXFP4 imports Triton)
@@ -942,8 +949,6 @@ class ModelConfig:
f"Unknown quantization method: {self.quantization}. Must "
f"be one of {supported_quantization}."
)
from vllm.platforms import current_platform
current_platform.verify_quantization(self.quantization)
if self.quantization in me_quant.DEPRECATED_QUANTIZATION_METHODS:
@@ -1813,8 +1818,6 @@ def _resolve_auto_dtype(
*,
is_pooling_model: bool,
):
from vllm.platforms import current_platform
supported_dtypes = [
dtype
for dtype in current_platform.supported_dtypes

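As a usage sketch for the `VLLM_ENFORCE_CUDA_GRAPH` override introduced in the hunk above: the overlay keeps eager execution unless the variable is explicitly set to an accepted value. This is a hypothetical standalone helper, not the ModelConfig code itself:

```python
import os


def resolve_enforce_eager() -> bool:
    # Mirrors the overlay's opt-in logic: CUDA graphs are only enabled when
    # VLLM_ENFORCE_CUDA_GRAPH is explicitly set to an accepted value.
    flag = os.environ.get("VLLM_ENFORCE_CUDA_GRAPH")
    return flag not in ("1", "y", "Y")


if __name__ == "__main__":
    os.environ.pop("VLLM_ENFORCE_CUDA_GRAPH", None)
    assert resolve_enforce_eager() is True      # default: eager mode
    os.environ["VLLM_ENFORCE_CUDA_GRAPH"] = "1"
    assert resolve_enforce_eager() is False     # CUDA graphs allowed
```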
View File

@@ -152,7 +152,6 @@ class ParallelConfig:
- "naive": Naive all2all implementation using broadcasts\n
- "allgather_reducescatter": All2all based on allgather and reducescatter\n
- "pplx": Use pplx kernels\n
- "deepep_high_throughput": Use deepep high-throughput kernels\n
- "deepep_low_latency": Use deepep low-latency kernels\n
- "mori": Use mori kernels\n
@@ -166,6 +165,9 @@ class ParallelConfig:
disable_custom_all_reduce: bool = False
"""Disable the custom all-reduce kernel and fall back to NCCL."""
enable_elastic_ep: bool = False
"""Enable elastic expert parallelism with stateless NCCL groups for DP/EP."""
enable_dbo: bool = False
"""Enable dual batch overlap for the model executor."""
ubatch_size: int = 0
@@ -245,6 +247,34 @@ class ParallelConfig:
Set to be private as it's not intended to be configured by users.
"""
_stateless_dp_group_port_list: list[list[int]] = Field(default_factory=list)
"""List of open ports for stateless DP groups when enable_elastic_ep is True.
Set to be private as it's not intended to be configured by users.
It is a list of list[int], with each inner list contains a set of 3 ports
to be used for setting up the stateless CPU/device/TCPStore groups
in StatelessGroupCoordinator. The number of inner lists is equal to
the number of DP groups,
i.e., len(self._stateless_dp_group_port_list) == world_size_across_dp // dp_size,
and len(self._stateless_dp_group_port_list[i]) == 3 for all i.
"""
_stateless_ep_group_port_list: list[list[int]] = Field(default_factory=list)
"""List of open ports for stateless EP groups when enable_elastic_ep is True.
Set to be private as it's not intended to be configured by users.
len(self._stateless_ep_group_port_list) == world_size_across_dp // ep_size,
"""
_stateless_eplb_group_port_list: list[list[int]] = Field(default_factory=list)
"""List of open ports for stateless EPLB groups when enable_elastic_ep is True.
Same topology as EP but separate NCCL communicator to avoid deadlocks.
"""
_stateless_world_group_port_list: list[list[int]] = Field(default_factory=list)
"""List of open ports for stateless world group when enable_elastic_ep is True.
Set to be private as it's not intended to be configured by users.
len(self._stateless_world_group_port_list) == 1,
"""
decode_context_parallel_size: int = 1
"""Number of decode context parallel groups, because the world size does
not change by dcp, it simply reuse the GPUs of TP group, and tp_size
@@ -310,6 +340,13 @@ class ParallelConfig:
f"but found: {self._api_process_rank}"
)
if self.all2all_backend == "pplx":
logger.warning(
"The 'pplx' all2all backend has been removed. "
"Falling back to 'allgather_reducescatter'."
)
self.all2all_backend = "allgather_reducescatter"
if self.data_parallel_size_local > self.data_parallel_size:
raise ValueError(
f"data_parallel_size_local ({self.data_parallel_size_local}) "
@@ -396,7 +433,67 @@ class ParallelConfig:
return answer
def stateless_init_dp_group(self) -> ProcessGroup:
def allocate_elastic_ep_ports(self) -> None:
"""Allocate all ports for elastic EP (stateless groups + DP master).
Must be called AFTER ray.init() so that ports claimed by Ray's
idle worker pool are already in use and won't be returned by
get_open_ports_list().
"""
if not self.enable_elastic_ep:
return
if self._stateless_world_group_port_list:
return
num_world_groups = 1
dp_size = self.data_parallel_size
ep_size = self.data_parallel_size * self.world_size_across_dp
num_dp_groups = max(1, self.world_size_across_dp // dp_size)
num_ep_groups = max(1, self.world_size_across_dp // ep_size)
num_eplb_groups = num_ep_groups
total_stateless_ports = (
num_world_groups + num_dp_groups + num_ep_groups + num_eplb_groups
) * 3
num_dp_master_ports = 5
all_ports = get_open_ports_list(total_stateless_ports + num_dp_master_ports)
self._data_parallel_master_port_list = all_ports[-num_dp_master_ports:]
self.data_parallel_master_port = self._data_parallel_master_port_list.pop()
all_ports = all_ports[:-num_dp_master_ports]
self._stateless_world_group_port_list = [
all_ports[i : i + 3] for i in range(0, num_world_groups * 3, 3)
]
start_idx = num_world_groups * 3
self._stateless_dp_group_port_list = [
all_ports[i : i + 3]
for i in range(start_idx, start_idx + num_dp_groups * 3, 3)
]
start_idx += num_dp_groups * 3
self._stateless_ep_group_port_list = [
all_ports[i : i + 3]
for i in range(start_idx, start_idx + num_ep_groups * 3, 3)
]
start_idx += num_ep_groups * 3
self._stateless_eplb_group_port_list = [
all_ports[i : i + 3]
for i in range(start_idx, start_idx + num_eplb_groups * 3, 3)
]
def get_next_stateless_world_group_port(self) -> list[int]:
return self._stateless_world_group_port_list.pop()
def get_next_stateless_dp_group_port(self) -> list[int]:
return self._stateless_dp_group_port_list.pop()
def get_next_stateless_ep_group_port(self) -> list[int]:
return self._stateless_ep_group_port_list.pop()
def get_next_stateless_eplb_group_port(self) -> list[int]:
return self._stateless_eplb_group_port_list.pop()
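`allocate_elastic_ep_ports()` above slices one flat list of open ports into triples for the stateless world/DP/EP/EPLB groups and reserves five DP master ports at the tail. A worked sketch of the slicing arithmetic with fake port numbers (no `get_open_ports_list()` call):

```python
def slice_ports(ports: list[int], num_groups: int, start: int) -> tuple[list[list[int]], int]:
    """Cut `num_groups` triples out of `ports` starting at index `start`."""
    groups = [ports[i:i + 3] for i in range(start, start + num_groups * 3, 3)]
    return groups, start + num_groups * 3


# Example: 1 world group, 2 DP groups, 1 EP group, 1 EPLB group -> 15 ports,
# plus 5 DP master ports at the tail (fake numbers stand in for open ports).
num_world, num_dp, num_ep, num_eplb = 1, 2, 1, 1
total = (num_world + num_dp + num_ep + num_eplb) * 3 + 5
ports = list(range(20000, 20000 + total))

master_ports = ports[-5:]
stateless = ports[:-5]

world_groups, idx = slice_ports(stateless, num_world, 0)
dp_groups, idx = slice_ports(stateless, num_dp, idx)
ep_groups, idx = slice_ports(stateless, num_ep, idx)
eplb_groups, idx = slice_ports(stateless, num_eplb, idx)

assert world_groups == [[20000, 20001, 20002]]
assert dp_groups == [[20003, 20004, 20005], [20006, 20007, 20008]]
assert len(ep_groups) == 1 and len(eplb_groups) == 1 and len(master_ports) == 5
```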
def stateless_init_dp_group(self, return_store: bool = False) -> ProcessGroup:
# NOTE: In high-concurrency scenarios multiple processes
# can pick the same (currently free) port through a race
# condition when calling `get_open_port()`. When the first
@@ -420,7 +517,8 @@ class ParallelConfig:
self.get_next_dp_init_port(),
self.data_parallel_rank,
self.data_parallel_size,
backend=current_platform.dist_backend,
backend="gloo",
return_store=return_store,
)
except DistNetworkError as e:
# We only want to retry when the root cause is EADDRINUSE.
@@ -442,7 +540,6 @@ class ParallelConfig:
# In this case, ensure the input to the experts is sequence parallel
# to avoid the excess work.
#
# Not needed for pplx-kernels as it can handle duplicate input tokens.
@property
def use_sequence_parallel_moe(self) -> bool:
return (
@@ -556,6 +653,21 @@ class ParallelConfig:
logger.info("Using external launcher for distributed inference.")
self.world_size *= self.data_parallel_size
if self.enable_elastic_ep:
if not self.enable_eplb:
raise ValueError("Elastic EP is only supported with enable_eplb=True.")
if self.pipeline_parallel_size > 1:
raise ValueError(
"Elastic EP is not supported with pipeline parallelism "
f"(pipeline_parallel_size={self.pipeline_parallel_size})."
)
if self.data_parallel_external_lb or self.data_parallel_hybrid_lb:
raise NotImplementedError(
"Elastic EP is not compatible with data_parallel_external_lb "
"or data_parallel_hybrid_lb. Elastic EP relies on a single API "
"server and core client to coordinate scale up/down."
)
if self.data_parallel_size > 1 or self.data_parallel_size_local == 0:
# Data parallel was specified in the engine args.
if self.distributed_executor_backend == "external_launcher":
@@ -568,9 +680,12 @@ class ParallelConfig:
"Set data_parallel_rank to %d automatically.",
self.data_parallel_rank,
)
if not self.enable_elastic_ep:
if not self._data_parallel_master_port_list:
self._data_parallel_master_port_list = get_open_ports_list(5)
self.data_parallel_master_port = self._data_parallel_master_port_list.pop()
self.data_parallel_master_port = (
self._data_parallel_master_port_list.pop()
)
if not (0 <= self.data_parallel_rank < self.data_parallel_size):
raise ValueError(
@@ -597,7 +712,7 @@ class ParallelConfig:
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
logger.info("Disabling V1 multiprocessing for external launcher.")
if self.distributed_executor_backend is None and self.world_size > 1:
if self.distributed_executor_backend is None and self.world_size_across_dp > 1:
# We use multiprocessing by default if world_size fits on the
# current node and we aren't in a ray placement group.
@@ -659,6 +774,17 @@ class ParallelConfig:
"backend is mp, uni or external_launcher."
)
if (
self.all2all_backend in ("allgather_reducescatter", "naive")
and self.eplb_config.use_async
):
logger.warning(
"Async EPLB causes hangs with the '%s' all2all backend. "
"Forcing synchronous EPLB.",
self.all2all_backend,
)
self.eplb_config.use_async = False
@property
def use_ray(self) -> bool:
return self.distributed_executor_backend == "ray" or (

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import copy
from typing import TYPE_CHECKING, Any, Literal, get_args
from pydantic import Field, SkipValidation, model_validator
@@ -45,7 +46,7 @@ MTPModelTypes = Literal[
"pangu_ultra_moe_mtp",
"step3p5_mtp",
]
EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes]
EagleModelTypes = Literal["eagle", "eagle3", "extract_hidden_states", MTPModelTypes]
SpeculativeMethod = Literal[
"ngram",
"medusa",
@@ -77,12 +78,24 @@ class SpeculativeConfig:
If using `ngram` method, the related configuration `prompt_lookup_max` and
`prompt_lookup_min` should be considered."""
enable_multi_layers_mtp: bool = False
"""If set to True, the MTP method will run multiple layers of MTP
speculator. If set to False, it will run only one layer of MTP speculator.
This is only effective when the method is set to `mtp`."""
draft_tensor_parallel_size: int | None = Field(default=None, ge=1)
"""The degree of the tensor parallelism for the draft model. Can only be 1
or the same as the target model's tensor parallel size."""
draft_pipeline_parallel_size: int | None = Field(default=None, ge=1)
"""The degree of pipeline parallelism for the draft model.
Defaults to the target model's pipeline parallel size. Set this to 1 to
run the drafter locally on the last target PP stage."""
tensor_parallel_size: int | None = None
"""Users should pass "draft_tensor_parallel_size". This parameter's purpose is to
warn users when they mistakenly provide the wrong argument."""
pipeline_parallel_size: int | None = None
"""Users should pass "draft_pipeline_parallel_size". This parameter's
purpose is to warn users when they mistakenly provide the wrong argument."""
# Draft model configuration
quantization: me_quant.QuantizationMethods | None = None
@@ -181,9 +194,22 @@ class SpeculativeConfig:
the final hidden states.
"""
factors: list[Any] = []
# Eagle3 affects the computation graph because it returns intermediate
# hidden states in addition to the final hidden state.
factors.append(self.method == "eagle3")
# Eagle3 and extract_hidden_states affect the computation graph because
# they return intermediate hidden states in addition to the final hidden state.
uses_aux_hidden_states = self.method in ("eagle3", "extract_hidden_states")
factors.append(uses_aux_hidden_states)
# The specific layers used also affect the computation graph
if uses_aux_hidden_states and self.draft_model_config is not None:
layer_ids = getattr(
self.draft_model_config.hf_config,
"eagle_aux_hidden_state_layer_ids",
None,
)
if layer_ids is not None:
# Convert to tuple to make it hashable
factors.append(tuple(layer_ids))
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
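The factor list above feeds the compilation cache key, and the aux-hidden-state layer ids are converted to a tuple before being appended. A tiny standalone sketch of that hashing pattern, using `hashlib` in place of vLLM's `safe_hash` (hypothetical helper name):

```python
import hashlib


def cache_key(method: str, layer_ids: list[int] | None) -> str:
    factors: list = []
    # Methods that return intermediate hidden states change the compiled graph.
    uses_aux_hidden_states = method in ("eagle3", "extract_hidden_states")
    factors.append(uses_aux_hidden_states)
    if uses_aux_hidden_states and layer_ids is not None:
        # Convert to tuple so the factor is a stable, hashable value.
        factors.append(tuple(layer_ids))
    return hashlib.sha256(str(factors).encode()).hexdigest()


assert cache_key("eagle3", [2, 17, 29]) != cache_key("eagle3", [2, 17, 30])
assert cache_key("ngram", None) == cache_key("ngram", None)
```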
@@ -352,6 +378,8 @@ class SpeculativeConfig:
self.model = "ngram"
elif self.method == "suffix":
self.model = "suffix"
elif self.method == "extract_hidden_states":
self.model = "extract_hidden_states"
else:
raise ValueError(
"num_speculative_tokens was provided but without speculative model."
@@ -394,6 +422,34 @@ class SpeculativeConfig:
self.draft_parallel_config = self.target_parallel_config
elif self.method == "suffix":
self._validate_suffix_decoding()
elif self.method == "extract_hidden_states":
from vllm.transformers_utils.configs.extract_hidden_states import (
ExtractHiddenStatesConfig,
)
# ExtractHiddenStatesModel is instantiated manually in load_model()
# We just need to store the target model config for KV cache shape info
self.model = "extract_hidden_states"
self.prompt_lookup_max = 0
self.prompt_lookup_min = 0
if hasattr(self.draft_model_config, "hf_config"):
hf_config = self.draft_model_config.hf_config.to_dict()
elif (
isinstance(self.draft_model_config, dict)
and "hf_config" in self.draft_model_config
):
hf_config = self.draft_model_config["hf_config"]
else:
hf_config = {}
self.draft_model_config = copy.copy(self.target_model_config)
self.draft_model_config.hf_config = ExtractHiddenStatesConfig(
self.draft_model_config.hf_config, **hf_config
)
self.update_arch_()
self.draft_parallel_config = self.target_parallel_config
else:
self.prompt_lookup_max = 0
self.prompt_lookup_min = 0
@@ -439,7 +495,10 @@ class SpeculativeConfig:
MTPModelTypes
):
self.method = "mtp"
if self.num_speculative_tokens > 1:
if (
self.enable_multi_layers_mtp is False
and self.num_speculative_tokens > 1
):
logger.warning(
"Enabling num_speculative_tokens > 1 will run "
"multiple times of forward on same MTP layer"
@@ -478,23 +537,8 @@ class SpeculativeConfig:
method=self.method,
model_type="eagle",
)
# EAGLEConfig primarily updates architectures, so update
# all architectures-related fields in draft_model_config
self.draft_model_config.hf_config = eagle_config
self.draft_model_config.hf_text_config = get_hf_text_config(
self.draft_model_config.hf_config
)
self.draft_model_config.model_arch_config = (
self.draft_model_config.get_model_arch_config()
)
model_info, arch = (
self.draft_model_config.registry.inspect_model_cls(
self.draft_model_config.architectures,
self.draft_model_config,
)
)
self.draft_model_config._model_info = model_info
self.draft_model_config._architecture = arch
self.update_arch_()
if self.num_speculative_tokens is not None and hasattr(
self.draft_model_config.hf_config, "num_lookahead_tokens"
@@ -510,6 +554,17 @@ class SpeculativeConfig:
if self.num_speculative_tokens is None:
# Default to max value defined in draft model config.
self.num_speculative_tokens = n_predict
elif (
self.method == "mtp"
and self.enable_multi_layers_mtp
and self.num_speculative_tokens > n_predict
):
logger.warning_once(
"For multi_layer_eagle, num_speculative_tokens "
"is greater than the layer_num, adjusting to "
"layer_num"
)
self.num_speculative_tokens = n_predict
elif (
self.num_speculative_tokens > n_predict
and self.num_speculative_tokens % n_predict != 0
@@ -555,9 +610,17 @@ class SpeculativeConfig:
)
)
self.draft_pipeline_parallel_size = (
SpeculativeConfig._verify_and_get_draft_pp(
self.target_parallel_config,
self.draft_pipeline_parallel_size,
)
)
self.draft_parallel_config = (
SpeculativeConfig.create_draft_parallel_config(
self.target_parallel_config, self.draft_tensor_parallel_size
self.target_parallel_config,
self.draft_tensor_parallel_size,
self.draft_pipeline_parallel_size,
)
)
return self
@@ -671,17 +734,61 @@ class SpeculativeConfig:
)
return speculative_draft_tensor_parallel_size
@staticmethod
def _verify_and_get_draft_pp(
target_parallel_config: ParallelConfig,
speculative_draft_pipeline_parallel_size: int | None,
) -> int:
"""
Verifies and adjusts the pipeline parallel size for a draft model
specified using speculative_draft_pipeline_parallel_size.
"""
if speculative_draft_pipeline_parallel_size is None:
return target_parallel_config.pipeline_parallel_size
if speculative_draft_pipeline_parallel_size not in (
1,
target_parallel_config.pipeline_parallel_size,
):
raise ValueError(
f"{speculative_draft_pipeline_parallel_size=} cannot be "
"other value than 1 or target model "
f"pipeline_parallel_size="
f"{target_parallel_config.pipeline_parallel_size}"
)
return speculative_draft_pipeline_parallel_size
def update_arch_(self):
"""
EagleConfig and ExtractHiddenStatesConfig update architectures, so update all
architectures-related fields in self.draft_model_config
"""
self.draft_model_config.hf_text_config = get_hf_text_config(
self.draft_model_config.hf_config
)
self.draft_model_config.model_arch_config = (
self.draft_model_config.get_model_arch_config()
)
model_info, arch = self.draft_model_config.registry.inspect_model_cls(
self.draft_model_config.architectures,
self.draft_model_config,
)
self.draft_model_config._model_info = model_info
self.draft_model_config._architecture = arch
@staticmethod
def create_draft_parallel_config(
target_parallel_config: ParallelConfig,
speculative_draft_tensor_parallel_size: int,
speculative_draft_pipeline_parallel_size: int,
) -> ParallelConfig:
"""Create a parallel config for use by the draft worker.
This is mostly a copy of the target parallel config, except the tp_size.
This is mostly a copy of the target parallel config, except the tp/pp
sizes used by the draft model.
"""
draft_parallel_config = ParallelConfig(
pipeline_parallel_size=target_parallel_config.pipeline_parallel_size,
pipeline_parallel_size=speculative_draft_pipeline_parallel_size,
tensor_parallel_size=speculative_draft_tensor_parallel_size,
distributed_executor_backend=target_parallel_config.distributed_executor_backend,
max_parallel_loading_workers=target_parallel_config.max_parallel_loading_workers,
@@ -699,6 +806,12 @@ class SpeculativeConfig:
"'tensor_parallel_size' is not a valid argument in the "
"speculative_config. Please pass 'draft_tensor_parallel_size' instead."
)
if self.pipeline_parallel_size is not None:
raise ValueError(
"'pipeline_parallel_size' is not a valid argument in the "
"speculative_config. Please pass "
"'draft_pipeline_parallel_size' instead."
)
if self.num_speculative_tokens is None:
raise ValueError(
@@ -718,7 +831,7 @@ class SpeculativeConfig:
self.draft_parallel_config
)
eagle3_target_supported = [
aux_hidden_states_supported = [
"llama",
"qwen",
"minicpm",
@@ -729,16 +842,16 @@ class SpeculativeConfig:
"nemotron_h",
]
if (
self.method == "eagle3"
self.method in ("eagle3", "extract_hidden_states")
and self.target_model_config
and not any(
supported_model in self.target_model_config.hf_text_config.model_type
for supported_model in eagle3_target_supported
for supported_model in aux_hidden_states_supported
)
):
raise ValueError(
f"Eagle3 is only supported for {eagle3_target_supported} models. " # noqa: E501
f"Got {self.target_model_config.hf_text_config.model_type=}"
f"{self.method} is only supported for {aux_hidden_states_supported}"
f" models. Got {self.target_model_config.hf_text_config.model_type=}"
)
self.verify_equal_vocab_size_if_draft_model()
return self
@@ -782,8 +895,65 @@ class SpeculativeConfig:
def uses_draft_model(self) -> bool:
return self.method == "draft_model"
def uses_extract_hidden_states(self) -> bool:
return self.method == "extract_hidden_states"
def needs_partial_pp_draft_remap(
self, target_parallel_config: ParallelConfig
) -> bool:
"""Whether draft PP is smaller than target PP and needs rank remap."""
if self.draft_parallel_config is None:
return False
return (
target_parallel_config.pipeline_parallel_size
> self.draft_parallel_config.pipeline_parallel_size
)
def resolve_partial_pp_draft_rank(
self, target_parallel_config: ParallelConfig
) -> int:
"""Map a target rank to the local draft rank for partial-PP drafting.
Currently this only supports running the draft model with `draft_pp=1`
on the last target PP stage.
"""
if not self.needs_partial_pp_draft_remap(target_parallel_config):
return target_parallel_config.rank
assert self.draft_parallel_config is not None
draft_pp = self.draft_parallel_config.pipeline_parallel_size
if draft_pp != 1:
raise ValueError(
"Partial pp drafter rank remapping only supports "
"draft_pipeline_parallel_size=1 when target PP is larger."
)
target_tp = target_parallel_config.tensor_parallel_size
draft_tp = self.draft_parallel_config.tensor_parallel_size
if draft_tp != target_tp:
raise ValueError(
"Partial pp drafter rank remapping requires "
"draft_tensor_parallel_size to equal target tensor_parallel_size. "
f"Got draft_tp={draft_tp}, target_tp={target_tp}."
)
target_pp = target_parallel_config.pipeline_parallel_size
target_rank = target_parallel_config.rank
target_pp_rank = target_rank // target_tp
target_tp_rank = target_rank % target_tp
if target_pp_rank != target_pp - 1:
raise ValueError(
"Partial pp drafter should only run on the last "
f"pipeline stage, but got pp rank {target_pp_rank} / {target_pp}"
)
return target_tp_rank
def __repr__(self) -> str:
method = self.method
model = None if method in ("ngram", "suffix") else self.draft_model_config.model
model = (
None
if method in ("ngram", "suffix", "extract_hidden_states")
else self.draft_model_config.model
)
num_spec_tokens = self.num_speculative_tokens
return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})"

View File

@@ -126,6 +126,9 @@ def enable_allreduce_rms_fusion(cfg: "VllmConfig") -> bool:
# tp-dp combination broken:
# https://github.com/vllm-project/vllm/issues/34458
and cfg.parallel_config.data_parallel_size == 1
# tp-pp combination broken:
# https://github.com/vllm-project/vllm/issues/35426
and cfg.parallel_config.pipeline_parallel_size == 1
)
@@ -857,7 +860,7 @@ class VllmConfig:
self.compilation_config.pass_config.fuse_gemm_comms = False
else:
# Compute SP threshold early; disable if None (model too
# small) before +rms_norm gets forced into custom_ops.
# small for SP to be beneficial).
pass_config = self.compilation_config.pass_config
if pass_config.sp_min_token_num is None:
from vllm.compilation.passes.fusion.sequence_parallelism import (
@@ -880,15 +883,13 @@ class VllmConfig:
self.compilation_config.pass_config.enable_sp = False
self.compilation_config.pass_config.fuse_gemm_comms = False
if self.compilation_config.pass_config.enable_sp:
if "-rms_norm" in self.compilation_config.custom_ops:
logger.warning(
"RMS norm force disabled, sequence parallelism might break"
)
else:
self.compilation_config.custom_ops.append("+rms_norm")
from vllm.utils.torch_utils import HAS_OPAQUE_TYPE
if self.compilation_config.fast_moe_cold_start is None:
if HAS_OPAQUE_TYPE:
# On torch >= 2.11 the hoisted OpaqueObject approach supersedes
# fast_moe_cold_start, so force it off.
self.compilation_config.fast_moe_cold_start = False
elif self.compilation_config.fast_moe_cold_start is None:
# resolve default behavior: try to be as safe as possible
# this config is unsafe if any spec decoding draft model has a MOE.
# We'll conservatively turn it off if we see spec decoding.
@@ -907,9 +908,9 @@ class VllmConfig:
):
logger.warning_once(
"Pooling models do not support full cudagraphs. "
"Overriding cudagraph_mode to PIECEWISE."
"Overriding cudagraph_mode to NONE."
)
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
elif (
model_config.is_encoder_decoder
and self.compilation_config.cudagraph_mode
@@ -924,6 +925,33 @@ class VllmConfig:
CUDAGraphMode.FULL_DECODE_ONLY
)
# Check if KV connector requires PIECEWISE mode for CUDA graphs
if (
self.kv_transfer_config is not None
and self.kv_transfer_config.is_kv_transfer_instance
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
):
# Lazy import to avoid circular dependencies
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory,
)
connector_cls = KVConnectorFactory.get_connector_class(
self.kv_transfer_config
)
if connector_cls.requires_piecewise_for_cudagraph(
self.kv_transfer_config.kv_connector_extra_config
):
logger.warning_once(
"KV connector %s requires PIECEWISE CUDA graph mode "
"due to layerwise async operations that cannot be "
"captured in CUDA graphs. "
"Overriding cudagraph_mode from %s to PIECEWISE.",
connector_cls.__name__,
self.compilation_config.cudagraph_mode.name,
)
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
# disable cudagraph when enforce eager execution
if self.model_config is not None and self.model_config.enforce_eager:
logger.info("Cudagraph is disabled under eager mode")
@@ -1114,6 +1142,20 @@ class VllmConfig:
if not self.instance_id:
self.instance_id = random_uuid()[:5]
def is_ixserver_connector(kv_transfer_config) -> bool:
if kv_transfer_config is not None and hasattr(
kv_transfer_config, "kv_connector"
):
connector = kv_transfer_config.kv_connector
if isinstance(connector, str):
connector_name = connector
else:
connector_name = getattr(
type(connector), "__name__", str(connector)
)
return "IxServer" in connector_name
return False
# Hybrid KV cache manager (HMA) runtime rules:
# - Explicit enable (--no-disable-kv-cache-manager): error if runtime
# disables it
@@ -1154,7 +1196,10 @@ class VllmConfig:
if self.scheduler_config.disable_hybrid_kv_cache_manager is None:
# Default to disable HMA, but only if the user didn't express a preference.
if self.kv_transfer_config is not None:
if is_ixserver_connector(self.kv_transfer_config):
pass
# NOTE(Kuntai): turn HMA off for connector unless specifically enabled.
else:
need_disable_hybrid_kv_cache_manager = True
logger.warning(
"Turning off hybrid kv cache manager because "
@@ -1169,6 +1214,11 @@ class VllmConfig:
self.scheduler_config.disable_hybrid_kv_cache_manager = (
need_disable_hybrid_kv_cache_manager
)
else:
self.scheduler_config.disable_hybrid_kv_cache_manager = (
need_disable_hybrid_kv_cache_manager
)
elif (
self.scheduler_config.disable_hybrid_kv_cache_manager is False
and need_disable_hybrid_kv_cache_manager
@@ -1466,22 +1516,22 @@ class VllmConfig:
if compile_range_end is not None:
computed_compile_ranges_split_points.append(compile_range_end)
# # Add the compile ranges for flashinfer
# if compilation_config.pass_config.fuse_allreduce_rms:
# tp_size = self.parallel_config.tensor_parallel_size
# max_size = compilation_config.pass_config.flashinfer_max_size(tp_size)
# if max_size is not None:
# max_token_num = max_size // (
# self.model_config.get_hidden_size()
# * self.model_config.dtype.itemsize
# )
# if compile_range_end is not None and max_token_num < compile_range_end:
# computed_compile_ranges_split_points.append(max_token_num)
# else:
# logger.debug(
# "Max num batched tokens below allreduce-rms fusion threshold, "
# "allreduce-rms fusion will be enabled for all num_tokens."
# )
# Add the compile ranges for flashinfer
if compilation_config.pass_config.fuse_allreduce_rms:
tp_size = self.parallel_config.tensor_parallel_size
max_size = compilation_config.pass_config.flashinfer_max_size(tp_size)
if max_size is not None:
max_token_num = max_size // (
self.model_config.get_hidden_size()
* self.model_config.dtype.itemsize
)
if compile_range_end is not None and max_token_num < compile_range_end:
computed_compile_ranges_split_points.append(max_token_num)
else:
logger.debug(
"Max num batched tokens below allreduce-rms fusion threshold, "
"allreduce-rms fusion will be enabled for all num_tokens."
)
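The re-enabled block above adds a compile-range split point at `max_size // (hidden_size * dtype.itemsize)`. A short worked example with assumed values (hidden size 4096, bf16, a hypothetical 64 MiB flashinfer limit):

```python
import torch

# Assumed values for illustration only.
hidden_size = 4096
dtype = torch.bfloat16          # itemsize == 2 bytes
max_size = 64 * 1024 * 1024     # hypothetical flashinfer max buffer size (bytes)

# Largest batch (in tokens) whose hidden states still fit under max_size.
max_token_num = max_size // (hidden_size * dtype.itemsize)
assert max_token_num == 8192

# If the configured compile range ends beyond this, the extra split point is
# added so fused and unfused token ranges are compiled separately.
compile_range_end = 16384
split_points = [compile_range_end]
if max_token_num < compile_range_end:
    split_points.append(max_token_num)
assert split_points == [16384, 8192]
```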
# Add the compile ranges for sequence parallelism
if compilation_config.pass_config.enable_sp:
@@ -1618,6 +1668,7 @@ class VllmConfig:
f"pipeline_parallel_size={self.parallel_config.pipeline_parallel_size}, " # noqa
f"data_parallel_size={self.parallel_config.data_parallel_size}, " # noqa
f"disable_custom_all_reduce={self.parallel_config.disable_custom_all_reduce}, " # noqa
f"quantization={self.model_config.quantization}, "
f"enforce_eager={self.model_config.enforce_eager}, "
f"enable_return_routed_experts={self.model_config.enable_return_routed_experts}, " # noqa
f"kv_cache_dtype={self.cache_config.cache_dtype}, "

View File

@@ -9,5 +9,5 @@ from vllm.config.utils import config
class WeightTransferConfig:
"""Configuration for weight transfer during RL training."""
backend: Literal["nccl"] = "nccl"
backend: Literal["nccl", "ipc"] = "nccl"
"""The backend to use for weight transfer."""

View File

@@ -11,7 +11,7 @@
import dataclasses
import gc
import os
from collections.abc import Callable
from collections.abc import Callable, Iterator
from contextlib import contextmanager
from typing import Any
@@ -25,6 +25,7 @@ logger = init_logger(__name__)
cumem_available = False
libcudart: Any = None
try:
from vllm.cumem_allocator import (
init_module,
@@ -41,9 +42,7 @@ except ModuleNotFoundError:
init_module = None
python_create_and_map = None
python_unmap_and_release = None
CudaRTLibrary = None
lib_name = None
libcudart = None
# py_device, py_alignedSize, py_d_mem, py_p_memHandle
HandleType = tuple[int, int, int, int]
@@ -65,7 +64,8 @@ def unmap_and_release(allocation_handle: HandleType) -> None:
def get_pluggable_allocator(
python_malloc_fn: Callable[[int], int], python_free_func: Callable[[int, int], None]
python_malloc_fn: Callable[[HandleType], None],
python_free_func: Callable[[int], HandleType],
) -> torch.cuda.memory.CUDAPluggableAllocator:
init_module(python_malloc_fn, python_free_func)
new_alloc = torch.cuda.memory.CUDAPluggableAllocator(
@@ -76,8 +76,11 @@ def get_pluggable_allocator(
@contextmanager
def use_memory_pool_with_allocator(
python_malloc_fn: Callable[[int], int], python_free_func: Callable[[int, int], None]
) -> None:
python_malloc_fn: Callable[[HandleType], None],
python_free_func: Callable[[int], HandleType],
) -> Iterator[
tuple[torch.cuda.memory.MemPool, torch.cuda.memory.CUDAPluggableAllocator]
]:
new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func)
mem_pool = torch.cuda.memory.MemPool(new_alloc._allocator)
with torch.cuda.memory.use_mem_pool(mem_pool):
@@ -109,7 +112,7 @@ class CuMemAllocator:
not work as expected.
"""
instance: "CuMemAllocator" = None
instance: "CuMemAllocator | None" = None
default_tag: str = "default"
@staticmethod

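The signature fix above annotates `use_memory_pool_with_allocator` as returning `Iterator[tuple[MemPool, CUDAPluggableAllocator]]`, which is the correct annotation for a generator-based `@contextmanager` that yields a pair. A minimal standalone sketch of the typing pattern with placeholder classes:

```python
from collections.abc import Iterator
from contextlib import contextmanager


class Pool:
    ...


class Allocator:
    ...


@contextmanager
def use_pool_with_allocator() -> Iterator[tuple[Pool, Allocator]]:
    # The decorated function is a generator, so the annotation is an Iterator
    # of the yielded type; the `with` statement receives the yielded tuple.
    pool, alloc = Pool(), Allocator()
    try:
        yield pool, alloc
    finally:
        pass  # release resources here


with use_pool_with_allocator() as (pool, alloc):
    assert isinstance(pool, Pool) and isinstance(alloc, Allocator)
```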
View File

@@ -3,14 +3,13 @@
from typing import Any
import torch
import torch.distributed as dist
import vllm.envs as envs
from vllm.distributed import get_dp_group, get_ep_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.utils.flashinfer import has_flashinfer_all2all
from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx
from vllm.utils.import_utils import has_deep_ep, has_mori
from .base_device_communicator import All2AllManagerBase, Cache
@@ -32,8 +31,8 @@ class NaiveAll2AllManager(All2AllManagerBase):
debugging.
"""
def __init__(self, cpu_group):
super().__init__(cpu_group)
def __init__(self, cpu_group, tcp_store_group=None):
super().__init__(cpu_group, tcp_store_group)
def naive_multicast(
self,
@@ -139,8 +138,8 @@ class AgRsAll2AllManager(All2AllManagerBase):
all-gather (dispatch) and reduce-scatter (combine).
"""
def __init__(self, cpu_group):
super().__init__(cpu_group)
def __init__(self, cpu_group, tcp_store_group=None):
super().__init__(cpu_group, tcp_store_group)
def dispatch_router_logits(
self,
@@ -235,107 +234,17 @@ class AgRsAll2AllManager(All2AllManagerBase):
pass
class PPLXAll2AllManager(All2AllManagerBase):
"""
All2All communication based on PPLX kernels.
"""
def __init__(self, cpu_group):
assert has_pplx(), (
"pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md"
" to install pplx_kernels."
)
super().__init__(cpu_group)
if self.internode:
# inter-node communication needs nvshmem,
# intra-node communication uses p2p mapping directly
from pplx_kernels.nvshmem import ( # type: ignore[import-not-found]
nvshmem_alloc_empty_unique_id,
nvshmem_get_unique_id,
nvshmem_init,
)
logger.debug(
"Initialize NVSHMEM for pplx_kernels: rank=%d, world size=%d",
self.rank,
self.world_size,
)
uid = (
nvshmem_get_unique_id()
if self.rank == 0
else nvshmem_alloc_empty_unique_id()
)
dist.broadcast(
uid,
src=dist.get_process_group_ranks(self.cpu_group)[0],
group=self.cpu_group,
)
logger.debug("PPLX NVSHMEM UID = %s", uid)
nvshmem_init(uid, self.rank, self.world_size)
self.handle_cache = Cache()
def get_handle(self, kwargs):
import pplx_kernels as pplx # type: ignore[import-not-found]
return self.handle_cache.get_or_create(
kwargs,
pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode,
)
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
def dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
raise NotImplementedError
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
raise NotImplementedError
def destroy(self):
with self.handle_cache._lock:
for _, handle in self.handle_cache._cache.items():
handle.destroy()
if self.internode:
from pplx_kernels.nvshmem import (
nvshmem_finalize, # type: ignore[import-not-found]
)
logger.debug("PPLX NVSHMEM finalize")
nvshmem_finalize()
class DeepEPAll2AllManagerBase(All2AllManagerBase):
"""
All2All communication based on DeepEP High-Throughput kernels.
"""
def __init__(self, cpu_group):
def __init__(self, cpu_group, tcp_store_group=None):
assert has_deep_ep(), (
"DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md"
" to install DeepEP kernels."
) # noqa
super().__init__(cpu_group)
super().__init__(cpu_group, tcp_store_group)
self.handle_cache = Cache()
# This is the DeepEP default. Stick to it till we can establish
@@ -373,7 +282,10 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
raise NotImplementedError
def destroy(self):
pass
with self.handle_cache._lock:
for _, handle in self.handle_cache._cache.items():
handle.destroy()
self.handle_cache._cache.clear()
class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
@@ -381,8 +293,8 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
All2All communication based on DeepEP High-Throughput kernels.
"""
def __init__(self, cpu_group):
super().__init__(cpu_group)
def __init__(self, cpu_group, tcp_store_group=None):
super().__init__(cpu_group, tcp_store_group)
def _make_all2all_kwargs(self) -> dict[Any, Any]:
# Defaults for internode and intranode are taken from DeepEP tests.
@@ -405,6 +317,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=False,
num_qps_per_rank=num_qps_per_rank,
explicitly_destroy=True,
)
def get_handle(self, kwargs):
@@ -438,8 +351,8 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
All2All communication based on DeepEP Low-Latency kernels.
"""
def __init__(self, cpu_group):
super().__init__(cpu_group)
def __init__(self, cpu_group, tcp_store_group=None):
super().__init__(cpu_group, tcp_store_group)
def _make_all2all_kwargs(
self,
@@ -476,8 +389,9 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=True,
num_qps_per_rank=num_qps_per_rank,
allow_nvlink_for_low_latency_mode=True,
allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL,
# allow_nvlink_for_low_latency_mode=True,
# allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL,
explicitly_destroy=True,
)
def get_handle(self, kwargs):
@@ -509,11 +423,11 @@ class FlashInferAllToAllManager(All2AllManagerBase):
rank: int
world_size: int
def __init__(self, cpu_group):
def __init__(self, cpu_group, tcp_store_group=None):
assert has_flashinfer_all2all(), (
"flashinfer all2all module not found. Please install/check flashinfer"
) # noqa
super().__init__(cpu_group)
super().__init__(cpu_group, tcp_store_group)
logger.debug(
"Initialize for flashinfer All2All rank=%d, world size=%d",
self.rank,

View File

@@ -27,6 +27,7 @@ from vllm.utils.torch_utils import cuda_device_count_stateless
logger = init_logger(__name__)
KiB = 1024
MiB = 1024 * 1024
# Max size for each world size in case symmetric memory is available
# For different SM architectures
@@ -60,17 +61,44 @@ SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
},
}
# NCCL symmetric memory allreduce configuration based on H100 and GB200 benchmarks.
# PyNCCL-symm outperforms custom_AR for small and large tensor sizes,
# while custom_AR wins for mid-range sizes.
#
# Benchmark results (8 GPUs):
# 2K - 16K: PyNCCL-symm wins (1.35x - 1.48x faster)
# 32K - 64K: custom_AR wins
# 128K - 1G: PyNCCL-symm wins (1.12x - 6.14x faster)
#
# Benchmark results (4 GPUs):
# 2K - 16K: PyNCCL-symm wins (1.21x - 1.30x faster)
# 32K - 256K: custom_AR wins (1.07x - 1.35x faster)
# 512K - 1G: PyNCCL-symm wins (1.10x - 2.32x faster)
#
# The config defines ranges where custom_AR is preferred (symm_mem disabled).
NCCL_SYMM_MEM_ALL_REDUCE_CONFIG: dict[str, Any] = {
"min_world_size": 4,
"thresholds": {
4: 2 * MiB, # 2 MB
8: 1 * MiB, # 1 MB
# Ranges where custom_AR outperforms NCCL symm_mem: (lower_bound, upper_bound)
# NCCL symm_mem will NOT be used for sizes in range: lower < size < upper
"custom_ar_preferred_ranges": {
4: (16 * KiB, 512 * KiB), # custom_AR wins for 32K-256K
8: (16 * KiB, 128 * KiB), # custom_AR wins for 32K-64K
},
"always_use_above_world_size": 8, # Always use symm mem for world_size > 8
}
def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor) -> bool:
"""
Determine if NCCL symmetric memory allreduce should be used.
Based on H100 and GB200 benchmarks, NCCL symm_mem is preferred for:
- Small tensors (≤16K): Lower latency than custom_AR
- Large tensors (≥128K for 8 GPUs, ≥512K for 4 GPUs): Better bandwidth
Custom_AR is preferred for mid-range sizes where its P2P approach
has lower overhead than the symm_mem copy-in/copy-out pattern.
"""
from vllm.distributed.device_communicators.pynccl_allocator import (
is_symmetric_memory_enabled,
)
@@ -80,11 +108,20 @@ def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor)
if not is_symmetric_memory_enabled():
return False
if world_size < NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["min_world_size"]:
return False
threshold = NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["thresholds"].get(world_size)
if threshold is not None and input_tensor.nbytes >= threshold:
return True
tensor_size = input_tensor.nbytes
custom_ar_range = NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["custom_ar_preferred_ranges"].get(
world_size
)
if custom_ar_range is not None:
lower_bound, upper_bound = custom_ar_range
# Use symm_mem for small sizes (≤ lower_bound) and large sizes (≥ upper_bound)
# Use custom_AR (not symm_mem) for mid-range sizes
return tensor_size <= lower_bound or tensor_size >= upper_bound
return world_size > NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["always_use_above_world_size"]

View File

@@ -30,8 +30,9 @@ class All2AllManagerBase:
rank: int
world_size: int
def __init__(self, cpu_group):
def __init__(self, cpu_group, tcp_store_group=None):
self.cpu_group = cpu_group
self.tcp_store_group = tcp_store_group
# compute some common properties
from vllm.distributed.parallel_state import (
@@ -48,12 +49,17 @@ class All2AllManagerBase:
# when we create this object
self.dp_rank = self.dp_group.rank_in_group
self.dp_world_size = self.dp_group.world_size
self.rank = dist.get_rank(cpu_group)
self.world_size = dist.get_world_size(cpu_group)
self.rank = cpu_group.rank()
self.world_size = cpu_group.size()
# all2all communication often has separate implementations for
# intra-node and inter-node communication
if tcp_store_group is None:
self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0))
else:
self.internode = not all(
in_the_same_node_as(tcp_store_group, source_rank=0)
)
def get_handle(self, kwargs):
# get a handle for the all2all communication,
@@ -122,11 +128,30 @@ class DeviceCommunicatorBase:
device: torch.device | None = None,
device_group: ProcessGroup | None = None,
unique_name: str = "",
global_ranks: list[int] | None = None,
global_world_size: int | None = None,
):
self.device = device or torch.device("cpu")
self.cpu_group = cpu_group
self.device_group = device_group
self.unique_name = unique_name
# Check if this is a stateless process group
from torch.distributed.distributed_c10d import _world
is_stateless = _world.pg_map.get(cpu_group, None) is None
if is_stateless:
# For stateless groups, we can't use torch.distributed methods
self.rank = cpu_group.rank()
self.world_size = cpu_group.size()
assert global_ranks is not None
assert global_world_size is not None
self.ranks = global_ranks
self.global_rank = self.ranks[self.rank]
self.global_world_size = global_world_size
self.rank_in_group = self.rank
else:
self.rank = dist.get_rank(cpu_group)
self.world_size = dist.get_world_size(cpu_group)
self.ranks = dist.get_process_group_ranks(cpu_group)
@@ -146,7 +171,7 @@ class DeviceCommunicatorBase:
use_ep = config.parallel_config.data_parallel_size > 1
all2all_backend = config.parallel_config.all2all_backend
self.is_ep_communicator = "ep" in unique_name
self.is_ep_communicator = unique_name.split(":")[0] == "ep"
self.use_all2all = self.is_ep_communicator and use_ep
self.all2all_backend = all2all_backend
self.all2all_manager: All2AllManagerBase | None = None
@@ -175,9 +200,7 @@ class DeviceCommunicatorBase:
group=self.device_group,
async_op=True)
else:
torch.distributed.all_gather_into_tensor(output_tensor,
input_,
group=self.device_group)
dist.all_gather_into_tensor(output_tensor, input_, group=self.device_group)
# Reshape
output_tensor = output_tensor.reshape((self.world_size,) + input_size)
output_tensor = output_tensor.movedim(0, dim)
@@ -263,10 +286,9 @@ class DeviceCommunicatorBase:
group=self.device_group,
async_op=True)
else:
torch.distributed.gather(input_,
gather_list,
dst=self.ranks[dst],
group=self.device_group)
torch.distributed.gather(
input_, gather_list, dst=self.ranks[dst], group=self.device_group
)
if self.rank_in_group == dst:
output_tensor = torch.cat(gather_list, dim=dim)
else:
@@ -292,6 +314,13 @@ class DeviceCommunicatorBase:
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
return tensor
def broadcast(self, tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
"""Broadcast a tensor from source rank to all ranks."""
if self.world_size == 1:
return tensor
torch.distributed.broadcast(tensor, self.ranks[src], self.device_group)
return tensor
def destroy(self):
pass
@@ -360,3 +389,6 @@ class DeviceCommunicatorBase:
This is a no-op in the base class.
"""
return hidden_states
def batch_isend_irecv(self, p2p_ops: list):
raise NotImplementedError
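The constructor change above treats groups missing from `_world.pg_map` as stateless and resolves rank and world size through the group object itself, with the global rank taken from the explicitly passed `global_ranks`. A hedged standalone sketch of that branch, with a stub object standing in for a real process group:

```python
class StubStatelessGroup:
    """Stands in for a stateless group: it only knows its own rank()/size()."""

    def __init__(self, rank: int, size: int):
        self._rank, self._size = rank, size

    def rank(self) -> int:
        return self._rank

    def size(self) -> int:
        return self._size


def resolve_identity(group, registered: bool, global_ranks=None, global_world_size=None):
    # `registered` models `_world.pg_map.get(cpu_group) is not None`.
    if not registered:
        # Stateless path: torch.distributed helpers cannot see this group.
        rank, world_size = group.rank(), group.size()
        assert global_ranks is not None and global_world_size is not None
        return rank, world_size, global_ranks[rank], global_world_size
    raise NotImplementedError("regular groups would use dist.get_rank() etc.")


rank, ws, global_rank, global_ws = resolve_identity(
    StubStatelessGroup(rank=1, size=2), registered=False,
    global_ranks=[4, 5], global_world_size=8,
)
assert (rank, ws, global_rank, global_ws) == (1, 2, 5, 8)
```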

View File

@@ -35,8 +35,15 @@ class CpuCommunicator(DeviceCommunicatorBase):
)
and hasattr(torch.ops._C, "init_shm_manager")
and (unique_name.startswith("tp") or unique_name.startswith("pp"))
and self._all_group_ranks_share_shm_group_name()
):
self.dist_module = _CPUSHMDistributed(self)
elif unique_name.startswith("tp") or unique_name.startswith("pp"):
logger.info(
"CPU SHM communicator disabled for group %s: ranks do not share "
"the same SHM group name, falling back to torch.distributed.",
unique_name,
)
if self.use_all2all:
if self.all2all_backend != "naive": # type: ignore[has-type]
@@ -52,6 +59,20 @@ class CpuCommunicator(DeviceCommunicatorBase):
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
logger.info("Using naive all2all manager.")
def _all_group_ranks_share_shm_group_name(self) -> bool:
"""
CPUSHM requires all ranks in this group to agree on one SHM group name.
This is a lightweight consistency check for VLLM_DIST_IDENT/name inputs.
"""
local_name = _CPUSHMDistributed.make_group_name(self)
names: list[str] = [""] * self.world_size
torch.distributed.all_gather_object(
names,
local_name,
group=self.device_group,
)
return len(set(names)) == 1
def all_reduce(self, input_):
self.dist_module.all_reduce(input_, group=self.device_group)
return input_
@@ -193,16 +214,20 @@ class CpuCommunicator(DeviceCommunicatorBase):
class _CPUSHMDistributed:
def __init__(self, communicator: CpuCommunicator):
self.communicator = communicator
self.group_name = self.make_group_name(communicator)
self.handle = self._init_cpu_shm()
@staticmethod
def make_group_name(communicator: CpuCommunicator) -> str:
instance_identifier = os.environ["VLLM_DIST_IDENT"]
unique_name = communicator.unique_name
instance_identifier = f"{instance_identifier}-{unique_name}"
self.communicator = communicator
group_ranks = [str(rank) for rank in self.communicator.ranks]
group_ranks = [str(rank) for rank in communicator.ranks]
shm_group_identifier = f"[{'-'.join(group_ranks)}]"
self.group_name = f"{instance_identifier}-{shm_group_identifier}-cpushm"
self.handle = self._init_cpu_shm()
return f"{instance_identifier}-{shm_group_identifier}-cpushm"
def _init_cpu_shm(self) -> int:
thread_num_tensor = torch.tensor(

View File

@@ -16,6 +16,7 @@ from vllm.distributed.device_communicators.pynccl_allocator import (
from vllm.logger import init_logger
from vllm.platforms import current_platform
from ..utils import StatelessProcessGroup
from .base_device_communicator import DeviceCommunicatorBase
import ixformer.distributed as ixfd
import os
@@ -29,8 +30,18 @@ class CudaCommunicator(DeviceCommunicatorBase):
device: torch.device | None = None,
device_group: ProcessGroup | None = None,
unique_name: str = "",
global_ranks: list[int] | None = None,
global_world_size: int | None = None,
tcp_store_group: StatelessProcessGroup | None = None,
):
super().__init__(cpu_group, device, device_group, unique_name)
super().__init__(
cpu_group,
device,
device_group,
unique_name,
global_ranks,
global_world_size,
)
if "tp" not in unique_name:
# custom allreduce or torch symm mem can be used only by tp
use_custom_allreduce = False
@@ -48,6 +59,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.use_flashinfer_allreduce = use_flashinfer_allreduce
self.use_vllm_comm = os.environ.get("VLLM_FORCE_NCCL_COMM",None) not in ["1", "Y", "y"]
# lazy import to avoid documentation build error
from vllm.distributed.device_communicators.custom_all_reduce import (
CustomAllreduce,
@@ -64,7 +76,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.pynccl_comm: PyNcclCommunicator | None = None
if self.world_size > 1:
self.pynccl_comm = PyNcclCommunicator(
group=self.cpu_group,
group=self.cpu_group if tcp_store_group is None else tcp_store_group,
device=self.device,
)
if is_symmetric_memory_enabled():
@@ -109,23 +121,27 @@ class CudaCommunicator(DeviceCommunicatorBase):
if self.all2all_backend == "naive":
from .all2all import NaiveAll2AllManager
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
self.all2all_manager = NaiveAll2AllManager(
self.cpu_group, tcp_store_group
)
elif self.all2all_backend == "allgather_reducescatter":
from .all2all import AgRsAll2AllManager
self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
elif self.all2all_backend == "pplx":
from .all2all import PPLXAll2AllManager
self.all2all_manager = PPLXAll2AllManager(self.cpu_group)
self.all2all_manager = AgRsAll2AllManager(
self.cpu_group, tcp_store_group
)
elif self.all2all_backend == "deepep_high_throughput":
from .all2all import DeepEPHTAll2AllManager
self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group)
self.all2all_manager = DeepEPHTAll2AllManager(
self.cpu_group, tcp_store_group
)
elif self.all2all_backend == "deepep_low_latency":
from .all2all import DeepEPLLAll2AllManager
self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
self.all2all_manager = DeepEPLLAll2AllManager(
self.cpu_group, tcp_store_group
)
elif self.all2all_backend == "mori":
from .all2all import MoriAll2AllManager
@@ -133,7 +149,9 @@ class CudaCommunicator(DeviceCommunicatorBase):
elif self.all2all_backend == "flashinfer_all2allv":
from .all2all import FlashInferAllToAllManager
self.all2all_manager = FlashInferAllToAllManager(self.cpu_group)
self.all2all_manager = FlashInferAllToAllManager(
self.cpu_group, tcp_store_group
)
else:
raise ValueError(f"Unknown all2all backend: {self.all2all_backend}")
@@ -188,27 +206,12 @@ class CudaCommunicator(DeviceCommunicatorBase):
return out
if self.world_size == 1:
return input_
if self.use_vllm_comm:
# torch.ops.ixf_ops.vllm_all_reduce(input_, async_op=True)
ixfd.all_reduce(input_, group=self.device_group, async_op=True)
else:
torch.distributed.all_reduce(input_, group=self.device_group)
return input_
pynccl_comm = self.pynccl_comm
if pynccl_comm is None or pynccl_comm.disabled:
out = input_.clone()
torch.distributed.all_reduce(out, group=self.device_group)
return out
assert pynccl_comm is not None
out = pynccl_comm.all_reduce(input_)
if out is None:
# fall back to the default all-reduce using PyTorch.
# this usually happens during testing.
# when we run the model, allreduce only happens for the TP
# group, where we always have either custom allreduce or pynccl.
out = input_.clone()
torch.distributed.all_reduce(out, group=self.device_group)
return out
def reduce_scatter(self, input_: torch.Tensor, dim: int = -1):
world_size = self.world_size
@@ -230,10 +233,8 @@ class CudaCommunicator(DeviceCommunicatorBase):
output_shape, dtype=input_tensor.dtype, device=input_tensor.device
)
# pynccl_comm.reduce_scatter(output, input_tensor)
torch.distributed.reduce_scatter_tensor(output,
input_tensor,
group=self.device_group)
# Perform reduce-scatter operation
ixfd.reduce_scatter_tensor(output,input_tensor,group=self.device_group, async_op=True)
# Reshape before returning
return output.movedim(0, dim).contiguous()
@@ -278,12 +279,6 @@ class CudaCommunicator(DeviceCommunicatorBase):
"""NOTE: `dst` is the local rank of the destination rank."""
if dst is None:
dst = (self.rank_in_group + 1) % self.world_size
pynccl_comm = self.pynccl_comm
# if pynccl_comm is not None and not pynccl_comm.disabled:
# pynccl_comm.send(tensor, dst)
# else:
# torch.distributed.send(tensor, self.ranks[dst], self.device_group)
if self.use_vllm_comm:
ixfd.send(tensor, self.ranks[dst], self.device_group)
else:
@@ -298,17 +293,24 @@ class CudaCommunicator(DeviceCommunicatorBase):
src = (self.rank_in_group - 1) % self.world_size
tensor = torch.empty(size, dtype=dtype, device=self.device)
# pynccl_comm = self.pynccl_comm
# if pynccl_comm is not None and not pynccl_comm.disabled:
# pynccl_comm.recv(tensor, src)
# else:
# torch.distributed.recv(tensor, self.ranks[src], self.device_group)
if self.use_vllm_comm:
ixfd.recv(tensor, self.ranks[src], self.device_group)
else:
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
return tensor
def broadcast(self, tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
"""Broadcast a tensor from source rank to all ranks."""
if self.world_size == 1:
return tensor
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.broadcast(tensor, src)
return tensor
else:
raise ValueError("No PyNCCL communicator found")
def destroy(self):
if self.pynccl_comm is not None:
self.pynccl_comm = None
@@ -319,7 +321,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.fi_ar_comm = None
if self.all2all_manager is not None:
self.all2all_manager.destroy()
self.all2all_manager = None
self.all2all_manager = None # type: ignore[assignment]
def all_gatherv(
self,
@@ -372,7 +374,6 @@ class CudaCommunicator(DeviceCommunicatorBase):
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
extra_residual:torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
@@ -409,16 +410,13 @@ class CudaCommunicator(DeviceCommunicatorBase):
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
# return self.all2all_manager.dispatch(
# hidden_states,
# topk_weights,
# topk_ids,
# is_sequence_parallel,
# extra_tensors=extra_tensors,
# )
hidden_states, extra_residual, router_logits = self.all2all_manager.dispatch(
hidden_states, extra_residual, router_logits)
return hidden_states, extra_residual, router_logits
return self.all2all_manager.dispatch(
hidden_states,
topk_weights,
topk_ids,
is_sequence_parallel,
extra_tensors=extra_tensors,
)
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
@@ -432,3 +430,10 @@ class CudaCommunicator(DeviceCommunicatorBase):
hidden_states,
is_sequence_parallel,
)
def batch_isend_irecv(self, p2p_ops: list):
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.batch_isend_irecv(p2p_ops)
else:
raise ValueError("No PyNCCL communicator found")

View File

@@ -312,10 +312,19 @@ class PyNcclCommunicator:
)
if stream is None:
stream = current_stream()
if tensor.dtype in [
torch.float8_e5m2,
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
torch.float8_e5m2fnuz,
]:
nccl_dtype = ncclDataTypeEnum.from_torch(torch.uint8)
else:
nccl_dtype = ncclDataTypeEnum.from_torch(tensor.dtype)
self.nccl.ncclSend(
buffer_type(tensor.data_ptr()),
tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype),
nccl_dtype,
dst,
self.comm,
cudaStream_t(stream.cuda_stream),
@@ -330,10 +339,19 @@ class PyNcclCommunicator:
)
if stream is None:
stream = current_stream()
if tensor.dtype in [
torch.float8_e5m2,
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
torch.float8_e5m2fnuz,
]:
nccl_dtype = ncclDataTypeEnum.from_torch(torch.uint8)
else:
nccl_dtype = ncclDataTypeEnum.from_torch(tensor.dtype)
self.nccl.ncclRecv(
buffer_type(tensor.data_ptr()),
tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype),
nccl_dtype,
src,
self.comm,
cudaStream_t(stream.cuda_stream),
@@ -384,3 +402,17 @@ class PyNcclCommunicator:
def deregister_comm_window(self, window):
return self.nccl.ncclCommWindowDeregister(self.comm, window)
def batch_isend_irecv(self, p2p_ops: list, stream=None):
if self.disabled:
return
if stream is None:
stream = current_stream()
self.group_start()
for op in p2p_ops:
if op.op is torch.distributed.isend:
self.send(op.tensor, op.group_peer, stream)
elif op.op is torch.distributed.irecv:
self.recv(op.tensor, op.group_peer, stream)
self.group_end()
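The send/recv changes above transfer FP8 tensors as `uint8` on the wire, which is lossless because every FP8 format is one byte per element. A small helper sketch of the dtype mapping (standalone, hypothetical function name):

```python
import torch

# FP8 formats are 1 byte per element, so byte-wise transfer as uint8 is lossless.
_FP8_DTYPES = (
    torch.float8_e5m2,
    torch.float8_e4m3fn,
    torch.float8_e4m3fnuz,
    torch.float8_e5m2fnuz,
)


def nccl_wire_dtype(dtype: torch.dtype) -> torch.dtype:
    """Dtype to report to NCCL for a tensor of `dtype`."""
    return torch.uint8 if dtype in _FP8_DTYPES else dtype


assert nccl_wire_dtype(torch.float8_e4m3fn) is torch.uint8
assert nccl_wire_dtype(torch.bfloat16) is torch.bfloat16
```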

View File

@@ -0,0 +1,529 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import gc
import weakref
from collections.abc import Iterable, Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed import P2POp
from vllm.compilation.counter import compilation_counter
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.compilation.wrapper import reset_compile_wrapper
from vllm.config import (
CompilationMode,
set_current_vllm_config,
)
from vllm.distributed import (
get_dp_group,
get_ep_group,
get_pcp_group,
get_tp_group,
)
from vllm.distributed.elastic_ep.standby_state import (
create_standby_groups,
get_standby_dp_group,
get_standby_ep_group,
pop_standby_groups,
)
from vllm.distributed.parallel_state import (
_replace_active_groups,
prepare_communication_buffer_for_model,
)
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import FusedMoEParallelConfig
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
from vllm.v1.worker.workspace import lock_workspace, unlock_workspace
logger = init_logger(__name__)
def batch_transfer_weights(
model: nn.Module,
is_sender: bool,
peer_rank: int,
dp_group: StatelessGroupCoordinator,
expert_weights: Sequence[Iterable[torch.Tensor]],
) -> None:
device_comm = dp_group.device_communicator
if device_comm is None:
raise ValueError("No device communicator found")
expert_weights_set = set()
for weight_group in expert_weights:
for weight in weight_group:
expert_weights_set.add(weight.data_ptr())
state_dict = model.state_dict()
all_params = []
for name, param in state_dict.items():
if name.endswith("expert_map"):
continue
if param.data_ptr() not in expert_weights_set:
all_params.append(param.data)
assert len(all_params) > 0
p2p_ops = []
for param in all_params:
op = object.__new__(P2POp)
if is_sender:
op.op = torch.distributed.isend
op.tensor = param
else:
op.op = torch.distributed.irecv
op.tensor = param
op.group_peer = peer_rank
p2p_ops.append(op)
device_comm.batch_isend_irecv(p2p_ops)
def broadcast_expert_mapping(
physical_to_logical: torch.Tensor | None,
num_local_physical_experts: int | None,
num_logical_experts: int | None,
dp_group: StatelessGroupCoordinator,
device: torch.device,
src_rank: int = 0,
) -> tuple[torch.Tensor, int, int]:
if dp_group.rank_in_group == src_rank:
assert physical_to_logical is not None
assert num_local_physical_experts is not None
assert num_logical_experts is not None
assert physical_to_logical.dtype == torch.int64
shape_tensor = torch.tensor(
list(physical_to_logical.shape), dtype=torch.int64, device="cpu"
)
metadata_tensor = torch.tensor(
[num_local_physical_experts, num_logical_experts],
dtype=torch.int64,
device="cpu",
)
else:
shape_tensor = torch.empty(2, dtype=torch.int64, device="cpu")
metadata_tensor = torch.empty(2, dtype=torch.int64, device="cpu")
shape_tensor = dp_group.tcp_store_group.broadcast(shape_tensor, src_rank)
metadata_tensor = dp_group.tcp_store_group.broadcast(metadata_tensor, src_rank)
if dp_group.rank_in_group != src_rank:
assert device is not None
physical_to_logical = torch.empty(
tuple(shape_tensor.tolist()),
dtype=torch.int64,
device=device,
)
assert physical_to_logical is not None
physical_to_logical = dp_group.broadcast(physical_to_logical, src_rank)
num_local_physical_experts = int(metadata_tensor[0].item())
num_logical_experts = int(metadata_tensor[1].item())
return physical_to_logical, num_local_physical_experts, num_logical_experts
class ElasticEPScalingExecutor:
def __init__(self, worker):
self.worker_ref = weakref.ref(worker)
self.reconfig_request = None
@property
def worker(self):
worker = self.worker_ref()
if worker is None:
raise RuntimeError("Worker has been garbage collected")
return worker
def execute(self, execute_method: str, *args, **kwargs):
method = getattr(self, execute_method, None)
if method is None:
raise ValueError(f"Unknown execute method: {execute_method}")
return method(*args, **kwargs)
def create_standby_groups(
self, reconfig_request: ReconfigureDistributedRequest
) -> None:
self.reconfig_request = reconfig_request
new_dp_size = reconfig_request.new_data_parallel_size
world_size = self.worker.vllm_config.parallel_config.world_size
new_world_size_across_dp = world_size * new_dp_size
updated_config = copy.copy(self.worker.vllm_config)
updated_config.parallel_config = copy.deepcopy(
self.worker.vllm_config.parallel_config
)
updated_config.parallel_config.data_parallel_size = new_dp_size
with set_current_vllm_config(updated_config):
create_standby_groups(
new_dp_size=new_dp_size,
new_world_size_across_dp=new_world_size_across_dp,
master_ip=reconfig_request.new_data_parallel_master_ip,
world_group_ports=reconfig_request.new_stateless_world_group_port_list,
dp_group_ports=reconfig_request.new_stateless_dp_group_port_list,
ep_group_ports=reconfig_request.new_stateless_ep_group_port_list,
eplb_group_ports=reconfig_request.new_stateless_eplb_group_port_list,
)
self.worker.model_runner.eep_eplb_suppressed = True
standby_ep_group = get_standby_ep_group()
assert standby_ep_group is not None
if standby_ep_group.rank == 0:
logger.info("[Elastic EP] EPLB disabled during elastic scaling transition")
def transfer_weights(self, old_dp_size: int, new_dp_size: int) -> None:
standby_dp_group = get_standby_dp_group()
assert standby_dp_group is not None
# Broadcast old_dp_size to all workers in standby group
if standby_dp_group.rank_in_group < old_dp_size:
old_dp_size_tensor = torch.tensor(
[old_dp_size], dtype=torch.int64, device="cpu"
)
else:
old_dp_size_tensor = torch.empty(1, dtype=torch.int64, device="cpu")
old_dp_size_tensor = standby_dp_group.tcp_store_group.broadcast(
old_dp_size_tensor, 0
)
num_new_workers = new_dp_size - old_dp_size
dp_rank = self.worker.vllm_config.parallel_config.data_parallel_rank
# Sender-receiver pairing: the first (num_new_workers % old_dp_size)
# senders each get (num_dst_per_sender + 1) contiguous receivers; the
# remaining senders each get num_dst_per_sender receivers.
num_dst_per_sender = num_new_workers // old_dp_size
remainder = num_new_workers % old_dp_size
if dp_rank < remainder:
recv_begin = dp_rank * (num_dst_per_sender + 1)
recv_end = recv_begin + num_dst_per_sender + 1
else:
recv_begin = (
remainder * (num_dst_per_sender + 1)
+ (dp_rank - remainder) * num_dst_per_sender
)
recv_end = recv_begin + num_dst_per_sender
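# Illustrative example (hypothetical sizes): with old_dp_size=3 and
# new_dp_size=7 there are 4 new workers, num_dst_per_sender=1 and
# remainder=1, so sender dp_rank 0 feeds new ranks [3, 4] while senders
# 1 and 2 feed [5] and [6] respectively.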
ranks_to_send = list(range(old_dp_size + recv_begin, old_dp_size + recv_end))
model = self.worker.model_runner.get_model()
for new_worker_rank in sorted(ranks_to_send):
batch_transfer_weights(
model=model,
is_sender=True,
peer_rank=new_worker_rank,
dp_group=standby_dp_group,
expert_weights=model.expert_weights,
)
torch.cuda.synchronize()
def broadcast_expert_mapping(self) -> None:
standby_dp_group = get_standby_dp_group()
assert standby_dp_group is not None
model_config = self.worker.model_runner.model_config
eplb_state = self.worker.model_runner.eplb_state
assert eplb_state is not None
eplb_model_state = eplb_state.model_states[model_config.compute_hash()]
physical_to_logical = eplb_model_state.physical_to_logical_map
num_physical_experts = physical_to_logical.shape[1]
num_local_physical_experts = num_physical_experts // get_ep_group().world_size
num_logical_experts = eplb_model_state.logical_replica_count.shape[1]
broadcast_expert_mapping(
physical_to_logical=physical_to_logical,
num_local_physical_experts=num_local_physical_experts,
num_logical_experts=num_logical_experts,
dp_group=standby_dp_group,
src_rank=0,
device=self.worker.device,
)
def switch_and_remove(self) -> None:
_replace_active_groups(world=None, dp=None, ep=None, eplb=None, node_count=None)
def switch_and_prepare(self) -> None:
old_dp_size = get_dp_group().world_size
old_ep_size = get_ep_group().world_size
_replace_active_groups(**pop_standby_groups())
parallel_config = self.worker.vllm_config.parallel_config
reconfig_request = self.reconfig_request
assert reconfig_request is not None
new_dp_size = reconfig_request.new_data_parallel_size
new_ep_size = get_ep_group().world_size
parallel_config.data_parallel_size = new_dp_size
if (
reconfig_request.new_data_parallel_rank
!= ReconfigureRankType.KEEP_CURRENT_RANK
):
parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
if (
reconfig_request.new_data_parallel_rank_local
!= ReconfigureRankType.KEEP_CURRENT_RANK
):
parallel_config.data_parallel_rank_local = (
reconfig_request.new_data_parallel_rank_local
)
parallel_config.data_parallel_master_ip = (
reconfig_request.new_data_parallel_master_ip
)
parallel_config.data_parallel_master_port = (
reconfig_request.new_data_parallel_master_port
)
# Reconfigure MoE modules with new EP size
moe_modules = [
module
for module in self.worker.model_runner.model.modules()
if (
module.__class__.__name__ == "FusedMoE"
or module.__class__.__name__ == "SharedFusedMoE"
)
]
num_local_experts = moe_modules[0].moe_config.num_local_experts
assert all(
module.moe_config.num_local_experts == num_local_experts
for module in moe_modules
), "All MoE modules must have the same number of experts"
for module in moe_modules:
module.moe_config.num_experts = num_local_experts * new_ep_size
module.global_num_experts = module.moe_config.num_experts
tp_size = get_tp_group().world_size
is_sequence_parallel = parallel_config.use_sequence_parallel_moe
sp_size = tp_size if is_sequence_parallel else 1
module.moe_parallel_config = FusedMoEParallelConfig.make(
tp_size_=tp_size,
pcp_size_=get_pcp_group().world_size,
dp_size_=get_dp_group().world_size,
sp_size_=sp_size,
vllm_parallel_config=parallel_config,
)
module.moe_config.moe_parallel_config = module.moe_parallel_config
# Update EPLB state
eplb_state = self.worker.model_runner.eplb_state
assert eplb_state is not None
model_config = self.worker.model_runner.model_config
eplb_model_state = eplb_state.model_states[model_config.compute_hash()]
num_physical_experts = num_local_experts * new_ep_size
num_logical_experts = eplb_model_state.logical_replica_count.shape[1]
parallel_config.eplb_config.num_redundant_experts = (
num_physical_experts - num_logical_experts
)
old_physical_to_logical = eplb_model_state.physical_to_logical_map
num_moe_layers = old_physical_to_logical.shape[0]
num_local_experts = eplb_model_state.expert_load_pass.shape[1] // old_ep_size
if new_dp_size > old_dp_size:
expanded_physical_to_logical = torch.full(
(num_moe_layers, num_local_experts * new_ep_size),
-1,
dtype=old_physical_to_logical.dtype,
device=old_physical_to_logical.device,
)
expanded_physical_to_logical[:, : num_local_experts * old_ep_size] = (
old_physical_to_logical
)
eplb_model_state.physical_to_logical_map = expanded_physical_to_logical
old_num_physical_experts = eplb_model_state.expert_load_pass.shape[1]
pad_size = num_physical_experts - old_num_physical_experts
if new_dp_size > old_dp_size:
assert pad_size > 0
expanded_expert_load_pass = F.pad(
eplb_model_state.expert_load_pass, (0, pad_size), value=0
)
expanded_expert_load_window = F.pad(
eplb_model_state.expert_load_window, (0, pad_size), value=0
)
eplb_model_state.expert_load_pass = expanded_expert_load_pass
eplb_model_state.expert_load_window = expanded_expert_load_window
eplb_state.num_valid_physical_experts = old_num_physical_experts
else:
assert pad_size < 0
eplb_model_state.expert_load_pass = eplb_model_state.expert_load_pass[
:, :num_physical_experts
]
eplb_model_state.expert_load_window = eplb_model_state.expert_load_window[
:, :, :num_physical_experts
]
eplb_state.num_valid_physical_experts = num_physical_experts
model = self.worker.model_runner.get_model()
model.expert_weights = []
with set_current_vllm_config(self.worker.vllm_config):
model.set_eplb_state(
eplb_model_state.expert_load_pass,
eplb_model_state.logical_to_physical_map,
eplb_model_state.logical_replica_count,
)
model.update_physical_experts_metadata(
num_physical_experts=num_physical_experts,
num_local_physical_experts=num_local_experts,
)
# Force re-creation of the modular kernel (and all2all manager)
# for the new EP size by resetting quant_method to base
for module in moe_modules:
if hasattr(module.quant_method, "old_quant_method"):
module.quant_method = module.quant_method.old_quant_method
module.runner = module._init_runner()
prepare_communication_buffer_for_model(self.worker.model_runner.model)
if (
self.worker.vllm_config.compilation_config.mode
== CompilationMode.STOCK_TORCH_COMPILE
):
# NOTE(yongji): when using stock torch.compile,
# torch.compile is triggered during GPUModelRunner's load_model()
# TODO(yongji): check whether we need to re-trigger torch.compile here.
# any changes to the tensor shapes in execution should already
# be handled internally by torch.compile.
backend = self.worker.vllm_config.compilation_config.init_backend(
self.worker.vllm_config
)
compilation_counter.stock_torch_compile_count += 1
self.worker.model_runner.model.compile(fullgraph=True, backend=backend)
# release all previously captured CUDA graphs
if isinstance(self.worker.model_runner.model, CUDAGraphWrapper):
wrapper = self.worker.model_runner.model
wrapper.concrete_cudagraph_entries = {}
elif isinstance(self.worker.model_runner.model, UBatchWrapper):
raise RuntimeError("DBO is not yet supported in elastic EP")
multi_block_table = self.worker.model_runner.input_batch.block_table
saved_block_tables: list[tuple[torch.Tensor, torch.Tensor]] = []
for bt in multi_block_table.block_tables:
saved_block_tables.append(
(bt.block_table.gpu.clone(), bt.block_table.cpu.clone())
)
multi_block_table.clear()
# reset the compile wrapper
torch.compiler.reset()
with set_current_vllm_config(self.worker.vllm_config):
reset_compile_wrapper(self.worker.model_runner.get_model())
gc.collect()
torch.cuda.synchronize()
torch.cuda.empty_cache()
unlock_workspace()
self.worker.compile_or_warm_up_model()
lock_workspace()
for bt, (saved_gpu, saved_cpu) in zip(
multi_block_table.block_tables, saved_block_tables
):
bt.block_table.gpu.copy_(saved_gpu)
bt.block_table.cpu.copy_(saved_cpu)
def perform_eplb_reshuffle(self, new_dp_size: int | None = None) -> None:
if get_ep_group().rank == 0:
logger.info("[Elastic EP] Starting expert resharding...")
eplb_state = self.worker.model_runner.eplb_state
assert eplb_state is not None
model_config = self.worker.model_runner.model_config
eplb_model_state = eplb_state.model_states[model_config.compute_hash()]
is_async_enabled = eplb_state.is_async
eplb_state.is_async = False
if new_dp_size is None:
eplb_state.rearrange()
else:
# scale down
parallel_config = self.worker.vllm_config.parallel_config
tp_size = parallel_config.tensor_parallel_size
old_ep_size = parallel_config.data_parallel_size * tp_size
new_ep_size = new_dp_size * tp_size
rank_mapping = {
old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
for old_ep_rank in range(old_ep_size)
}
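# Example (illustrative): scaling down from old_ep_size=8 to new_ep_size=4
# yields {0: 0, 1: 1, 2: 2, 3: 3, 4: -1, 5: -1, 6: -1, 7: -1}; ranks mapped
# to -1 are the ones being removed.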
eplb_state.rearrange(rank_mapping=rank_mapping)
# NOTE(yongji): check whether we need to synchronize here
torch.cuda.synchronize()
# reset expert_rearrangement_step to ensure all ranks are synchronized
eplb_state.expert_rearrangement_step = 0
eplb_state.num_valid_physical_experts = (
eplb_model_state.physical_to_logical_map.shape[1]
)
eplb_state.is_async = is_async_enabled
self.worker.model_runner.eep_eplb_suppressed = False
if get_ep_group().rank == 0:
logger.info("[Elastic EP] Expert resharding completed")
def receive_weights(self) -> None:
dp_group = get_dp_group()
assert isinstance(dp_group, StatelessGroupCoordinator)
new_dp_size = dp_group.world_size
dp_rank = self.worker.vllm_config.parallel_config.data_parallel_rank
# Receive old_dp_size broadcasted during transfer_weights
old_dp_size_tensor = torch.empty(1, dtype=torch.int64, device="cpu")
old_dp_size_tensor = dp_group.tcp_store_group.broadcast(old_dp_size_tensor, 0)
old_dp_size = int(old_dp_size_tensor[0].item())
# Calculate which existing worker will send to this new worker
num_new_workers = new_dp_size - old_dp_size
new_worker_idx = dp_rank - old_dp_size
num_dst_per_sender = num_new_workers // old_dp_size
remainder = num_new_workers % old_dp_size
if new_worker_idx < remainder * (num_dst_per_sender + 1):
sender_rank = new_worker_idx // (num_dst_per_sender + 1)
else:
sender_rank = (
remainder
+ (new_worker_idx - remainder * (num_dst_per_sender + 1))
// num_dst_per_sender
)
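# This mirrors the pairing computed on the sender side in transfer_weights:
# in the hypothetical old_dp_size=3 / new_dp_size=7 case, new ranks 3 and 4
# resolve to sender 0, rank 5 to sender 1, and rank 6 to sender 2.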
model = self.worker.model_runner.get_model()
batch_transfer_weights(
model=model,
is_sender=False,
peer_rank=sender_rank,
dp_group=dp_group,
expert_weights=model.expert_weights,
)
torch.cuda.synchronize()
def receive_expert_mapping(self) -> tuple[torch.Tensor, int, int]:
dp_group = get_dp_group()
assert isinstance(dp_group, StatelessGroupCoordinator)
physical_to_logical, num_local_physical_experts, num_logical_experts = (
broadcast_expert_mapping(
physical_to_logical=None,
num_local_physical_experts=None,
num_logical_experts=None,
dp_group=dp_group,
src_rank=0,
device=self.worker.device,
)
)
num_moe_layers = physical_to_logical.shape[0]
new_dp_size = get_dp_group().world_size
tp_size = self.worker.vllm_config.parallel_config.tensor_parallel_size
new_ep_size = new_dp_size * tp_size
expanded_physical_to_logical = torch.full(
(num_moe_layers, num_local_physical_experts * new_ep_size),
-1,
dtype=physical_to_logical.dtype,
device=physical_to_logical.device,
)
old_num_physical_experts = physical_to_logical.shape[1]
expanded_physical_to_logical[:, :old_num_physical_experts] = physical_to_logical
return (
expanded_physical_to_logical,
num_logical_experts,
old_num_physical_experts,
)
def prepare_new_worker(self) -> None:
with set_current_vllm_config(self.worker.vllm_config):
prepare_communication_buffer_for_model(self.worker.model_runner.get_model())

View File

@@ -0,0 +1,563 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
import time
import weakref
from datetime import timedelta
from typing import TYPE_CHECKING, Literal
import torch.distributed
from vllm.config import ParallelConfig
from vllm.distributed import (
sched_yield,
stateless_destroy_torch_distributed_process_group,
)
from vllm.logger import init_logger
from vllm.v1.engine import (
EEPNotificationType,
ReconfigureDistributedRequest,
ReconfigureRankType,
)
from vllm.v1.engine.core import DPEngineCoreProc
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.v1.executor.abstract import Executor
logger = init_logger(__name__)
WorkerType = Literal["existing", "new", "removing"]
class ScaleUpExistingEngineState(enum.IntEnum):
WAIT_NEW_CORE_ENGINES_INIT = 0
CREATE_STANDBY_GROUPS = 1
TRANSFER_EXPERT_MAPPING = 2
WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT = 3
TRANSFER_WEIGHTS = 4
SYNC_KV_CACHE_MEMORY_SIZE = 5
SWITCH_AND_PREPARE = 6
EPLB_RESHUFFLE = 7
COMPLETE = 8
class ScaleUpNewEngineState(enum.IntEnum):
PREPARE = 0
EPLB_RESHUFFLE = 1
COMPLETE = 2
class ScaleDownRemainingEngineState(enum.IntEnum):
PREPARE = 0
EPLB_RESHUFFLE = 1
SWITCH_AND_PREPARE = 2
COMPLETE = 3
class ScaleDownRemovingEngineState(enum.IntEnum):
PREPARE = 0
EPLB_RESHUFFLE = 1
COMPLETE = 2
class _BarrierTimeoutError(RuntimeError):
"""
Exception raised for timeout
in the first stage of our two-staged
TCPStore based barrier to synchronize the
execution of all engines in the DP group.
"""
class ElasticEPScalingState:
def __init__(
self,
model_executor: "Executor",
engine_core: "DPEngineCoreProc",
vllm_config: "VllmConfig",
new_parallel_config: ParallelConfig,
worker_type: WorkerType,
scale_type: Literal["scale_up", "scale_down"],
reconfig_request: ReconfigureDistributedRequest | None = None,
):
self.model_executor_ref = weakref.ref(model_executor)
self.engine_core_ref = weakref.ref(engine_core)
self.vllm_config = vllm_config
self.old_dp_group = self.engine_core.dp_group if worker_type != "new" else None
self.old_dp_store = self.engine_core.dp_store if worker_type != "new" else None
self.new_parallel_config: ParallelConfig = new_parallel_config
self.new_dp_group: torch.distributed.ProcessGroup | None = (
self.engine_core.dp_group if worker_type == "new" else None
)
self.new_dp_store = self.engine_core.dp_store if worker_type == "new" else None
self.worker_type = worker_type
self.scale_type = scale_type
self.reconfig_request = reconfig_request
if scale_type == "scale_up":
self.state = (
ScaleUpNewEngineState.PREPARE
if worker_type == "new"
else ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT
)
else:
self.state = (
ScaleDownRemovingEngineState.PREPARE
if worker_type == "removing"
else ScaleDownRemainingEngineState.PREPARE
)
@property
def model_executor(self) -> "Executor":
model_executor = self.model_executor_ref()
if model_executor is None:
raise RuntimeError("Model executor has been garbage collected")
return model_executor
@property
def engine_core(self) -> "DPEngineCoreProc":
engine_core = self.engine_core_ref()
if engine_core is None:
raise RuntimeError("Engine core has been garbage collected")
return engine_core
def progress(self) -> bool:
if self.scale_type == "scale_up":
return (
self._progress_new_engine()
if self.worker_type == "new"
else self._progress_existing_engine()
)
return (
self._progress_removing_engine()
if self.worker_type == "removing"
else self._progress_remaining_engine()
)
def _execute_tcp_store_barrier(
self, dp_store, group_rank, group_size, barrier_id, timeout=None
):
arrival_key = f"arrival_{barrier_id}_{group_rank}"
dp_store.set(arrival_key, b"1")
start_time = time.time()
processes_arrived: set[int] = set()
while len(processes_arrived) < group_size:
if (
timeout is not None
and time.time() - start_time > timeout.total_seconds()
):
raise _BarrierTimeoutError(
f"Barrier timed out after {timeout.total_seconds()} seconds"
)
for i in range(group_size):
if i in processes_arrived:
continue
key = f"arrival_{barrier_id}_{i}"
present = dp_store.check([key])
if present:
processes_arrived.add(i)
if len(processes_arrived) < group_size:
sched_yield()
def _staged_barrier(self, use_new_group: bool, barrier_name: str) -> bool:
"""
Execute a two-staged barrier to synchronize all engines in the DP group.
Some DP EngineCores may receive the reconfiguration notification
later than others and may have already proceeded to the engine step
(model forward) in the busy loop.
In this case, the EngineCores that have already reached
reconfiguration should skip it and execute model forward for one more
step, so that in the next step all EngineCores are synchronized.
We use a two-staged barrier to achieve this. The first time each
EngineCore executes the barrier, if a timeout is reached before the
barrier completes, that means some EngineCores have already entered
engine step. The EngineCores that timed out will then proceed to
engine step, and will synchronize with the other EngineCores in the
next step with a barrier without timeout.
"""
dp_store = self.new_dp_store if use_new_group else self.old_dp_store
dp_group = self.new_dp_group if use_new_group else self.old_dp_group
assert dp_group is not None
group_rank = dp_group.rank()
group_size = dp_group.size()
barrier_id = f"eep_barrier_{barrier_name}"
sync_key = f"{barrier_id}_sync"
# TODO(yongji): figure out appropriate timeout for the barrier
timeout = None if dp_store.check([sync_key]) else timedelta(seconds=5)
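# First pass: use a short timeout so an engine that reached the barrier early
# can detect that some peers are still in the model-forward step and fall back
# to one more forward step. Once the sync key has been set by a timed-out
# engine, the retry waits without a timeout.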
try:
self._execute_tcp_store_barrier(
dp_store, group_rank, group_size, barrier_id, timeout=timeout
)
torch.distributed.barrier(dp_group)
if group_rank == 0:
dp_store.delete_key(sync_key)
for i in range(group_size):
dp_store.delete_key(f"arrival_{barrier_id}_{i}")
return True
except _BarrierTimeoutError as e:
if timeout is None:
raise RuntimeError("Unexpected timeout encountered") from e
dp_store.compare_set(sync_key, "", b"1")
return False
def _progress_existing_engine(self) -> bool:
state = self.state
if state == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT:
return False
elif state == ScaleUpExistingEngineState.CREATE_STANDBY_GROUPS:
# NOTE(yongji): wait for all existing workers to receive the request
if (
int(self.old_dp_store.get("eep_barrier_engine_count"))
< self.old_dp_group.size()
):
return False
if not self._staged_barrier(
use_new_group=False, barrier_name="create_standby_groups"
):
return False
if self.old_dp_group.rank() == 0:
self.old_dp_store.delete_key("eep_barrier_engine_count")
self._create_standby_groups()
self.state = ScaleUpExistingEngineState.TRANSFER_EXPERT_MAPPING
return True
elif state == ScaleUpExistingEngineState.TRANSFER_EXPERT_MAPPING:
self._transfer_expert_mapping()
self.state = ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT
return True
elif state == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT:
return False
elif state == ScaleUpExistingEngineState.TRANSFER_WEIGHTS:
if (
int(self.old_dp_store.get("eep_barrier_engine_count"))
< self.old_dp_group.size()
):
return False
if not self._staged_barrier(
use_new_group=False, barrier_name="transfer_weights"
):
return False
if self.old_dp_group.rank() == 0:
self.old_dp_store.delete_key("eep_barrier_engine_count")
self._transfer_weights()
self.state = ScaleUpExistingEngineState.SYNC_KV_CACHE_MEMORY_SIZE
return True
elif state == ScaleUpExistingEngineState.SYNC_KV_CACHE_MEMORY_SIZE:
self._sync_kv_cache_memory_size()
self.state = ScaleUpExistingEngineState.SWITCH_AND_PREPARE
return True
elif state == ScaleUpExistingEngineState.SWITCH_AND_PREPARE:
self._switch_and_prepare()
self.state = ScaleUpExistingEngineState.EPLB_RESHUFFLE
self.new_dp_store.add("eep_barrier_engine_count", 1)
return True
elif state == ScaleUpExistingEngineState.EPLB_RESHUFFLE:
assert self.new_dp_group is not None
if (
int(self.new_dp_store.get("eep_barrier_engine_count"))
< self.new_dp_group.size()
):
return False
if not self._staged_barrier(
use_new_group=True, barrier_name="eplb_reshuffle"
):
return False
if self.new_dp_group.rank() == 0:
self.new_dp_store.delete_key("eep_barrier_engine_count")
self._eplb_reshuffle()
self.state = ScaleUpExistingEngineState.COMPLETE
self._update_parallel_config()
return True
else:
assert self.state == ScaleUpExistingEngineState.COMPLETE
return True
def _progress_new_engine(self) -> bool:
state = self.state
assert self.new_dp_group is not None
if state == ScaleUpNewEngineState.PREPARE:
tensor = torch.tensor([0, 0, 0], dtype=torch.int32, device="cpu")
torch.distributed.all_reduce(
tensor,
op=torch.distributed.ReduceOp.MAX,
group=self.new_dp_group,
)
data = tensor.tolist()
self.engine_core.engines_running = bool(data[0])
self.engine_core.current_wave = int(data[1])
self.engine_core.step_counter = int(data[2])
self.state = ScaleUpNewEngineState.EPLB_RESHUFFLE
self.new_dp_store.add("eep_barrier_engine_count", 1)
return True
elif state == ScaleUpNewEngineState.EPLB_RESHUFFLE:
if (
int(self.new_dp_store.get("eep_barrier_engine_count"))
< self.new_dp_group.size()
):
return False
if not self._staged_barrier(
use_new_group=True, barrier_name="eplb_reshuffle"
):
return False
assert self.new_dp_group.rank() > 0
self._eplb_reshuffle()
self.state = ScaleUpNewEngineState.COMPLETE
return True
else:
assert self.state == ScaleUpNewEngineState.COMPLETE
return True
def _progress_remaining_engine(self) -> bool:
state = self.state
if state == ScaleDownRemainingEngineState.PREPARE:
self.state = ScaleDownRemainingEngineState.EPLB_RESHUFFLE
self.old_dp_store.add("eep_barrier_engine_count", 1)
return True
elif state == ScaleDownRemainingEngineState.EPLB_RESHUFFLE:
if (
int(self.old_dp_store.get("eep_barrier_engine_count"))
< self.old_dp_group.size()
):
return False
if not self._staged_barrier(
use_new_group=False, barrier_name="eplb_reshuffle"
):
return False
if self.old_dp_group.rank() == 0:
self.old_dp_store.delete_key("eep_barrier_engine_count")
self._eplb_reshuffle_before_scale_down()
self.state = ScaleDownRemainingEngineState.SWITCH_AND_PREPARE
# NOTE(yongji): currently, after EPLB reshuffle
# that redistributes experts to remaining workers, workers
# to be removed will immediately initiate shutdown;
# existing workers can no longer execute forward steps using
# the old setup. In the future, we may keep
# the removing workers alive a bit longer,
# e.g., to drain in-batch requests.
self._create_standby_groups()
self._switch_and_prepare()
self._update_parallel_config()
self.state = ScaleDownRemainingEngineState.COMPLETE
return True
else:
assert self.state == ScaleDownRemainingEngineState.COMPLETE
return True
def _progress_removing_engine(self) -> bool:
state = self.state
if state == ScaleDownRemovingEngineState.PREPARE:
self.state = ScaleDownRemovingEngineState.EPLB_RESHUFFLE
self.old_dp_store.add("eep_barrier_engine_count", 1)
return True
if state == ScaleDownRemovingEngineState.EPLB_RESHUFFLE:
if (
int(self.old_dp_store.get("eep_barrier_engine_count"))
< self.old_dp_group.size()
):
return False
if not self._staged_barrier(
use_new_group=False, barrier_name="eplb_reshuffle"
):
return False
assert self.old_dp_group.rank() > 0
self._eplb_reshuffle_before_scale_down()
self._switch_and_remove()
self.state = ScaleDownRemovingEngineState.COMPLETE
self.engine_core._eep_send_engine_core_notification(
EEPNotificationType.SHUTDOWN_COMPLETE
)
self.engine_core.shutdown()
return True
else:
assert self.state == ScaleDownRemovingEngineState.COMPLETE
return True
def handle_notification(self, notification_type: EEPNotificationType):
assert self.worker_type != "new"
if (
notification_type == EEPNotificationType.NEW_CORE_ENGINES_INIT_READY
and self.state == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT
):
self.old_dp_store.add("eep_barrier_engine_count", 1)
self.state = ScaleUpExistingEngineState.CREATE_STANDBY_GROUPS
elif (
notification_type == EEPNotificationType.NEW_CORE_ENGINES_WEIGHTS_INIT_READY
and self.state
== ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT
):
self.old_dp_store.add("eep_barrier_engine_count", 1)
self.state = ScaleUpExistingEngineState.TRANSFER_WEIGHTS
def is_complete(self) -> bool:
if self.scale_type == "scale_up":
return (
self.state == ScaleUpNewEngineState.COMPLETE
if self.worker_type == "new"
else self.state == ScaleUpExistingEngineState.COMPLETE
)
return (
self.state == ScaleDownRemovingEngineState.COMPLETE
if self.worker_type == "removing"
else self.state == ScaleDownRemainingEngineState.COMPLETE
)
def _create_standby_groups(self):
self.new_dp_group, self.new_dp_store = (
self.new_parallel_config.stateless_init_dp_group(return_store=True)
)
self.model_executor.collective_rpc(
"elastic_ep_execute", args=("create_standby_groups", self.reconfig_request)
)
if self.old_dp_group.rank() == 0:
logger.info("[Elastic EP] Created standby communication groups")
def _transfer_weights(self):
assert self.reconfig_request is not None
old_dp_size = self.old_dp_group.size()
new_dp_size = self.reconfig_request.new_data_parallel_size
self.model_executor.collective_rpc(
"elastic_ep_execute", args=("transfer_weights", old_dp_size, new_dp_size)
)
if self.old_dp_group.rank() == 0:
logger.info("[Elastic EP] Transferred weights to new workers")
def _transfer_expert_mapping(self):
self.model_executor.collective_rpc(
"elastic_ep_execute", args=("broadcast_expert_mapping",)
)
if self.old_dp_group.rank() == 0:
logger.info("[Elastic EP] Broadcasted expert mapping to new workers")
def _sync_kv_cache_memory_size(self):
assert self.engine_core.available_gpu_memory_for_kv_cache > 0
assert self.new_dp_group is not None
ParallelConfig.sync_kv_cache_memory_size(
self.new_dp_group,
self.engine_core.available_gpu_memory_for_kv_cache,
)
if self.old_dp_group.rank() == 0:
logger.info("[Elastic EP] Synced KV cache memory size to new workers")
def _switch_and_prepare(self):
self.model_executor.collective_rpc(
"elastic_ep_execute", args=("switch_and_prepare",)
)
old_dp_group = self.old_dp_group
stateless_destroy_torch_distributed_process_group(old_dp_group)
assert self.new_dp_group is not None
new_dp_group = self.new_dp_group
self.engine_core.dp_group = new_dp_group
self.engine_core.dp_rank = new_dp_group.rank()
self.engine_core.dp_store = self.new_dp_store
engines_running = int(self.engine_core.engines_running)
current_wave = self.engine_core.current_wave
step_counter = self.engine_core.step_counter
tensor = torch.tensor(
[engines_running, current_wave, step_counter],
dtype=torch.int32,
device="cpu",
)
torch.distributed.all_reduce(
tensor, op=torch.distributed.ReduceOp.MAX, group=new_dp_group
)
data = tensor.tolist()
self.engine_core.engines_running = bool(data[0])
self.engine_core.current_wave = int(data[1])
self.engine_core.step_counter = int(data[2])
if new_dp_group.rank() == 0:
self.engine_core._eep_send_engine_core_notification(
EEPNotificationType.RECONFIGURE_FINISHED
)
logger.info("[Elastic EP] Switched to new setup")
def _eplb_reshuffle(self):
self.model_executor.collective_rpc(
"elastic_ep_execute", args=("perform_eplb_reshuffle",)
)
assert self.new_dp_group is not None
if self.new_dp_group.rank() == 0:
logger.info("[Elastic EP] EPLB reshuffle completed")
def _eplb_reshuffle_before_scale_down(self):
assert self.reconfig_request is not None
self.model_executor.collective_rpc(
"elastic_ep_execute",
args=(
"perform_eplb_reshuffle",
self.reconfig_request.new_data_parallel_size,
),
)
if self.old_dp_group.rank() == 0:
logger.info("[Elastic EP] EPLB reshuffle completed")
def _switch_and_remove(self):
self.model_executor.collective_rpc(
"elastic_ep_execute", args=("switch_and_remove",)
)
def _update_parallel_config(self):
assert self.reconfig_request is not None
reconfig_request = self.reconfig_request
parallel_config = self.vllm_config.parallel_config
parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
if (
reconfig_request.new_data_parallel_rank
!= ReconfigureRankType.KEEP_CURRENT_RANK
):
parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
if (
reconfig_request.new_data_parallel_rank_local
!= ReconfigureRankType.KEEP_CURRENT_RANK
):
parallel_config.data_parallel_rank_local = (
reconfig_request.new_data_parallel_rank_local
)
parallel_config.data_parallel_master_ip = (
reconfig_request.new_data_parallel_master_ip
)
parallel_config.data_parallel_master_port = (
reconfig_request.new_data_parallel_master_port
)
parallel_config._data_parallel_master_port_list = (
reconfig_request.new_data_parallel_master_port_list
)
parallel_config._stateless_world_group_port_list = (
reconfig_request.new_stateless_world_group_port_list
)
parallel_config._stateless_dp_group_port_list = (
reconfig_request.new_stateless_dp_group_port_list
)
parallel_config._stateless_ep_group_port_list = (
reconfig_request.new_stateless_ep_group_port_list
)
parallel_config._stateless_eplb_group_port_list = (
reconfig_request.new_stateless_eplb_group_port_list
)

View File

@@ -0,0 +1,117 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.distributed.parallel_state import (
_init_stateless_group,
_node_count,
get_pp_group,
get_tp_group,
get_world_group,
)
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
_STANDBY_WORLD: StatelessGroupCoordinator | None = None
_STANDBY_WORLD_NODE_COUNT: int | None = None
_STANDBY_DP: StatelessGroupCoordinator | None = None
_STANDBY_EP: StatelessGroupCoordinator | None = None
_STANDBY_EPLB: StatelessGroupCoordinator | None = None
def get_standby_dp_group() -> StatelessGroupCoordinator | None:
return _STANDBY_DP
def get_standby_ep_group() -> StatelessGroupCoordinator | None:
return _STANDBY_EP
def get_standby_eplb_group() -> StatelessGroupCoordinator | None:
return _STANDBY_EPLB
def get_standby_world_group() -> StatelessGroupCoordinator | None:
return _STANDBY_WORLD
def create_standby_groups(
new_dp_size: int,
new_world_size_across_dp: int,
master_ip: str,
world_group_ports: list[list[int]],
dp_group_ports: list[list[int]],
ep_group_ports: list[list[int]],
eplb_group_ports: list[list[int]] | None = None,
backend: str | None = None,
) -> None:
global \
_STANDBY_WORLD, \
_STANDBY_WORLD_NODE_COUNT, \
_STANDBY_DP, \
_STANDBY_EP, \
_STANDBY_EPLB
assert new_world_size_across_dp == torch.distributed.get_world_size() * new_dp_size
world_group = get_world_group()
assert isinstance(world_group, StatelessGroupCoordinator)
backend = backend or world_group.backend
standby_world_ranks = [list(range(new_world_size_across_dp))]
_STANDBY_WORLD = _init_stateless_group(
standby_world_ranks,
"world",
world_group_ports,
master_ip,
backend,
use_device_communicator=False,
)
_STANDBY_WORLD_NODE_COUNT = _node_count(_STANDBY_WORLD.tcp_store_group)
tp_size = get_tp_group().world_size
pp_size = get_pp_group().world_size
all_ranks = torch.arange(new_world_size_across_dp).reshape(
-1, new_dp_size, pp_size, tp_size
)
standby_dp_ranks = all_ranks.transpose(1, 3).reshape(-1, new_dp_size).unbind(0)
standby_dp_ranks = [x.tolist() for x in standby_dp_ranks]
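# Illustrative layout (hypothetical sizes): with tp_size=2, pp_size=1 and
# new_dp_size=2 the flattened ranks 0..3 form DP groups [0, 2] and [1, 3],
# and the single EP group below becomes [0, 1, 2, 3].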
_STANDBY_DP = _init_stateless_group(
standby_dp_ranks, "dp", dp_group_ports, master_ip, backend
)
standby_ep_ranks = (
all_ranks.transpose(1, 2).reshape(-1, new_dp_size * tp_size).unbind(0)
)
standby_ep_ranks = [x.tolist() for x in standby_ep_ranks]
_STANDBY_EP = _init_stateless_group(
standby_ep_ranks, "ep", ep_group_ports, master_ip, backend
)
if eplb_group_ports is not None:
_STANDBY_EPLB = _init_stateless_group(
standby_ep_ranks, "eplb", eplb_group_ports, master_ip, backend
)
def pop_standby_groups() -> dict:
"""Return all standby groups and clear the standby state."""
global \
_STANDBY_WORLD, \
_STANDBY_WORLD_NODE_COUNT, \
_STANDBY_DP, \
_STANDBY_EP, \
_STANDBY_EPLB
result = dict(
world=_STANDBY_WORLD,
dp=_STANDBY_DP,
ep=_STANDBY_EP,
eplb=_STANDBY_EPLB,
node_count=_STANDBY_WORLD_NODE_COUNT,
)
_STANDBY_WORLD = None
_STANDBY_WORLD_NODE_COUNT = None
_STANDBY_DP = None
_STANDBY_EP = None
_STANDBY_EPLB = None
return result

View File

@@ -24,7 +24,6 @@ logger = init_logger(__name__)
def start_async_worker(
state: "EplbState",
rank_mapping: dict[int, int] | None = None,
is_profile: bool = False,
) -> threading.Thread:
eplb_group = get_eplb_group().device_group
@@ -45,7 +44,6 @@ def start_async_worker(
eplb_group=eplb_group,
cuda_stream=cuda_stream,
is_profile=is_profile,
rank_mapping=rank_mapping,
)
)
except Exception as exc: # pragma: no cover - diagnostic path
@@ -107,7 +105,6 @@ async def transfer_run_periodically(
eplb_group: ProcessGroup,
cuda_stream: torch.cuda.Stream,
is_profile: bool = False,
rank_mapping: dict[int, int] | None = None,
) -> None:
while True:
await asyncio.to_thread(state.rearrange_event.wait)
@@ -176,7 +173,6 @@ async def transfer_run_periodically(
ep_group=eplb_group,
is_profile=is_profile,
cuda_stream=cuda_stream,
rank_mapping=rank_mapping,
)
event = torch.cuda.Event(blocking=False)
cuda_stream.record_event(event)

View File

@@ -40,6 +40,7 @@ from vllm.distributed.parallel_state import (
get_node_count,
in_the_same_node_as,
)
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import MixtureOfExperts
@@ -159,7 +160,7 @@ class EplbModelState:
NOTE: The expert_load_view now records load for all physical experts
rather than just local experts. This ensures consistent load statistics
across different dispatch methods (naive all-to-all, DeepEP).
The recorded load will be multiplied by dp_size when using naive all-to-all
due to each DP rank contributing the same token set to the calculation.
See:
@@ -302,6 +303,14 @@ class EplbState:
"""
CUDA device index for the async EPLB worker thread.
"""
self.num_valid_physical_experts: int = 0
"""
Number of valid physical experts.
This is the number of physical experts that are
actually mapped to logical experts. In elastic EP,
newly started EP ranks may not have physical experts
mapped yet.
"""
if self.device.type == "cuda":
self.cuda_device_index = self.device.index
if self.cuda_device_index is None and torch.cuda.is_available():
@@ -367,9 +376,6 @@ class EplbState:
self,
model: MixtureOfExperts,
model_config: ModelConfig,
global_expert_load: torch.Tensor | None = None,
old_global_expert_indices: torch.Tensor | None = None,
rank_mapping: dict[int, int] | None = None,
):
"""
Build the initial EPLB state.
@@ -462,75 +468,15 @@ class EplbState:
)
self.expert_rearrangement_step_interval = eplb_step_interval
# Set the policy based on the selected eplb algorithm type.
policy_type = self.parallel_config.eplb_config.policy
self.policy = EPLB_POLICIES[policy_type]
logger.debug("Selected EPLB policy: %s", policy_type)
if global_expert_load is not None:
ep_group = get_ep_group().device_group
assert global_expert_load.shape == (
model.num_moe_layers,
model.num_logical_experts,
)
assert global_expert_load.dtype == torch.int64
num_replicas = model.num_physical_experts
num_groups = model.num_expert_groups
num_nodes = get_node_count()
num_gpus = ep_group.size()
if num_gpus % num_nodes != 0:
num_nodes = 1
logger.warning_once(
f"num_gpus % num_nodes != 0, "
"not using hierarchical rearrangement algorithm.\n"
f"{num_gpus=}, {num_nodes=}"
)
# Get new expert mappings
(
new_physical_to_logical_map,
new_logical_to_physical_map,
new_logical_replica_count,
) = self.policy.rebalance_experts(
global_expert_load,
num_replicas,
num_groups,
num_nodes,
num_gpus,
)
max_physical_slots = new_logical_to_physical_map.shape[-1]
assert max_physical_slots <= logical_to_physical_map.shape[-1]
new_logical_to_physical_map = torch.nn.functional.pad(
new_logical_to_physical_map,
(0, logical_to_physical_map.shape[-1] - max_physical_slots),
value=-1,
)
physical_to_logical_map = new_physical_to_logical_map.to(self.device)
logical_to_physical_map.copy_(new_logical_to_physical_map)
logical_replica_count.copy_(new_logical_replica_count)
else:
new_physical_to_logical_map = None
new_logical_to_physical_map = None
new_logical_replica_count = None
model.set_eplb_state(
expert_load_pass,
logical_to_physical_map,
logical_replica_count,
)
if global_expert_load is not None:
rearrange_expert_weights_inplace(
old_global_expert_indices,
new_physical_to_logical_map,
model.expert_weights,
ep_group,
False,
rank_mapping,
)
self.expert_rearrangement_step = 0
expert_buffer = [torch.empty_like(w) for w in model.expert_weights[0]]
@@ -561,11 +507,12 @@ class EplbState:
recv_dst_rows=np.array([]),
),
cuda_device_index=self.cuda_device_index,
new_physical_to_logical_map=new_physical_to_logical_map,
new_logical_to_physical_map=new_logical_to_physical_map,
new_logical_replica_count=new_logical_replica_count,
new_physical_to_logical_map=None,
new_logical_to_physical_map=None,
new_logical_replica_count=None,
)
self.model_states[model_config.compute_hash()] = model_state
self.num_valid_physical_experts = model.num_physical_experts
def step(
self,
@@ -696,8 +643,6 @@ class EplbState:
def rearrange(
self,
is_profile: bool = False,
execute_shuffle: bool = True,
global_expert_loads: list[torch.Tensor] | None = None,
rank_mapping: dict[int, int] | None = None,
) -> torch.Tensor | None:
"""
@@ -707,12 +652,6 @@ class EplbState:
is_profile (bool): If `True`, perform a dummy rearrangement.
This is used in `profile_run` to reserve enough memory,
no memory movement will be performed. Default is False.
execute_shuffle (bool): If `True`, execute the shuffle
in elastic expert parallel (EEP). Default is True.
global_expert_loads (list[torch.Tensor] | None): The global expert
loads when scaling is done in EEP.
List of expert loads for the main and drafter
(when spec decode is used) models.
rank_mapping (dict[int, int] | None): The rank mapping
when scaling is done in EEP.
"""
@@ -734,18 +673,12 @@ class EplbState:
"(profile)" if is_profile else "",
)
if global_expert_loads is None:
# Map the physical expert load to global logical experts
global_expert_load_windows = []
if not execute_shuffle:
num_models = torch.tensor(
[len(self.model_states)], dtype=torch.int32, device="cpu"
)
torch.distributed.broadcast(
num_models, group=get_ep_group().cpu_group, group_src=0
)
for eplb_model_state in self.model_states.values():
expert_load_window = eplb_model_state.expert_load_window[
:, :, : self.num_valid_physical_experts
]
logical_expert_load_window = torch.zeros(
self.expert_load_window_size,
eplb_model_state.model.num_moe_layers,
@@ -755,46 +688,19 @@ class EplbState:
)
logical_expert_load_window.scatter_add_(
dim=-1,
index=eplb_model_state.physical_to_logical_map[
:, : self.num_valid_physical_experts
]
.unsqueeze(0)
.expand_as(expert_load_window)
.long(),
src=eplb_model_state.expert_load_window,
)
if not execute_shuffle:
metadata = torch.tensor(
[
eplb_model_state.model.num_moe_layers,
eplb_model_state.model.num_logical_experts,
eplb_model_state.physical_to_logical_map.shape[1],
],
dtype=torch.int32,
device="cpu",
)
torch.distributed.broadcast(
metadata, group=get_ep_group().cpu_group, group_src=0
src=expert_load_window,
)
global_expert_load_window = logical_expert_load_window.sum(dim=0)
global_expert_load_windows.append(global_expert_load_window)
# Perform all-reduce to get the expert load across all ranks for each model
global_expert_load_windows = self._allreduce_list(
global_expert_load_windows
)
if not execute_shuffle:
for eplb_model_state, global_expert_load_window in zip(
self.model_states.values(), global_expert_load_windows
):
# (num_moe_layers, old_num_physical_experts)
old_global_expert_indices = eplb_model_state.physical_to_logical_map
torch.distributed.broadcast(
old_global_expert_indices, group=ep_group, group_src=0
)
if not execute_shuffle:
return global_expert_load_windows
else:
assert execute_shuffle
global_expert_load_windows = global_expert_loads
global_expert_load_windows = self._allreduce_list(global_expert_load_windows)
# TODO(bowen): Treat differently for prefill and decode nodes
eplb_model_state = next(iter(self.model_states.values()))
@@ -806,8 +712,10 @@ class EplbState:
# NOTE(yongji): scale down, we need to rebalance the experts on
# remaining GPUs, transfer the experts while we haven't shutdown
# the GPUs to be released.
cpu_group = get_ep_group().cpu_group
num_nodes = _node_count_with_rank_mapping(cpu_group, rank_mapping)
coordinator = get_ep_group()
assert isinstance(coordinator, StatelessGroupCoordinator)
tcp_store_group = coordinator.tcp_store_group
num_nodes = _node_count_with_rank_mapping(tcp_store_group, rank_mapping)
num_gpus = sum(new_rank != -1 for new_rank in rank_mapping.values())
num_replicas = (
num_replicas // ep_group.size() * num_gpus
@@ -933,7 +841,6 @@ class EplbState:
if self.async_worker is None:
self.async_worker = start_async_worker(
self,
rank_mapping=rank_mapping,
is_profile=is_profile,
)
@@ -1089,83 +996,6 @@ class EplbState:
model_state.new_logical_to_physical_map = None
model_state.new_logical_replica_count = None
@staticmethod
def recv_state() -> tuple[list[torch.Tensor], list[torch.Tensor]]:
"""
Receive the expert load and old placement from the master rank.
"""
ep_group = get_ep_group()
num_models = torch.empty(1, dtype=torch.int32, device="cpu")
torch.distributed.broadcast(num_models, group=ep_group.cpu_group, group_src=0)
num_models = num_models.item()
global_expert_loads = []
old_global_expert_indices_per_model = []
for _ in range(num_models):
metadata = torch.empty(3, dtype=torch.int32, device="cpu")
torch.distributed.broadcast(metadata, group=ep_group.cpu_group, group_src=0)
num_moe_layers, num_logical_experts, num_old_physical_experts = (
metadata.tolist()
)
global_expert_load = torch.zeros(
(num_moe_layers, num_logical_experts),
dtype=torch.int64,
device=ep_group.device,
)
all_reduce(global_expert_load, group=ep_group.device_group)
old_global_expert_indices = torch.empty(
(num_moe_layers, num_old_physical_experts),
dtype=torch.int64,
device=ep_group.device,
)
torch.distributed.broadcast(
old_global_expert_indices,
group=ep_group.device_group,
group_src=0,
)
global_expert_loads.append(global_expert_load)
old_global_expert_indices_per_model.append(old_global_expert_indices)
return global_expert_loads, old_global_expert_indices_per_model
@classmethod
def get_eep_state(
cls, parallel_config: ParallelConfig
) -> tuple[
list[torch.Tensor] | None,
list[torch.Tensor] | None,
dict[int, int] | None,
]:
num_local_physical_experts = torch.empty(1, dtype=torch.int32, device="cpu")
torch.distributed.broadcast(
num_local_physical_experts,
group=get_ep_group().cpu_group,
group_src=0,
)
num_local_physical_experts = int(num_local_physical_experts.item())
new_ep_size = get_ep_group().world_size
global_expert_loads, old_global_expert_indices_per_model = (
EplbState.recv_state()
)
# EP configuration for all models has to be the same so as eplb config
num_logical_experts = global_expert_loads[0].shape[1]
parallel_config.eplb_config.num_redundant_experts = (
num_local_physical_experts * new_ep_size - num_logical_experts
)
assert (
old_global_expert_indices_per_model[0].shape[1] % num_local_physical_experts
== 0
)
old_ep_size = (
old_global_expert_indices_per_model[0].shape[1]
// num_local_physical_experts
)
rank_mapping = {old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)}
return (
global_expert_loads,
old_global_expert_indices_per_model,
rank_mapping,
)
def _allreduce_list(self, tensor_list: list[torch.Tensor]) -> list[torch.Tensor]:
"""
All-reduce a list of tensors.
@@ -1203,6 +1033,60 @@ class EplbState:
load_pass_list.append(eplb_model_state.expert_load_pass.clone())
return self._allreduce_list(load_pass_list)
@classmethod
def from_mapping(
cls,
model: MixtureOfExperts,
model_config: ModelConfig,
device: torch.device,
parallel_config: ParallelConfig,
expanded_physical_to_logical: torch.Tensor,
num_valid_physical_experts: int,
) -> "EplbState":
eplb_state = cls(
parallel_config=parallel_config,
device=device,
)
eplb_state.add_model(
model=model,
model_config=model_config,
)
eplb_state.num_valid_physical_experts = num_valid_physical_experts
num_moe_layers = expanded_physical_to_logical.shape[0]
num_physical_experts = expanded_physical_to_logical.shape[1]
eplb_model_state = eplb_state.model_states[model_config.compute_hash()]
eplb_model_state.physical_to_logical_map.copy_(expanded_physical_to_logical)
logical_to_physical_map = torch.full(
(
num_moe_layers,
model.num_logical_experts,
eplb_model_state.logical_to_physical_map.shape[2],
),
-1,
dtype=torch.int64,
)
logical_replica_count = torch.zeros(
(num_moe_layers, model.num_logical_experts),
dtype=torch.int64,
)
expanded_physical_to_logical_numpy = expanded_physical_to_logical.cpu().numpy()
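# Rebuild the inverse mapping on the CPU: every valid physical slot is
# appended to its logical expert's replica list, while entries equal to -1
# (slots on EP ranks that do not hold experts yet) are skipped.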
for layer_idx in range(num_moe_layers):
for phys_idx in range(num_physical_experts):
logical_idx = expanded_physical_to_logical_numpy[layer_idx, phys_idx]
if logical_idx >= 0:
replica_idx = logical_replica_count[layer_idx, logical_idx]
logical_to_physical_map[layer_idx, logical_idx, replica_idx] = (
phys_idx
)
logical_replica_count[layer_idx, logical_idx] += 1
logical_to_physical_map = logical_to_physical_map.to(device)
logical_replica_count = logical_replica_count.to(device)
eplb_model_state.logical_to_physical_map.copy_(logical_to_physical_map)
eplb_model_state.logical_replica_count.copy_(logical_replica_count)
return eplb_state
@dataclass
class EplbLayerState:

View File

@@ -19,6 +19,8 @@ from torch.distributed import (
get_global_rank,
)
from vllm.distributed.parallel_state import get_ep_group
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
from vllm.logger import init_logger
logger = init_logger(__name__)
@@ -249,10 +251,18 @@ def move_to_buffer(
b[dst].copy_(w[src_local], non_blocking=True)
p2p_ops: list[P2POp] = []
if isinstance(get_ep_group(), StatelessGroupCoordinator):
ep_group = get_ep_group()
is_stateless = True
else:
is_stateless = False
# Pre-compute global ranks mapping (only needed for non-stateless groups)
ep_size = ep_group.size()
if not is_stateless:
rank_to_global = {
rank: get_global_rank(ep_group, rank) for rank in range(ep_size)
}
# 2. Post sends
if send_count > 0:
@@ -284,6 +294,14 @@ def move_to_buffer(
if recver_pos < len(ranks_to_recv):
recv_ranks.append(ranks_to_recv[recver_pos])
for dst in recv_ranks:
if is_stateless:
for w in expert_weights:
op = object.__new__(P2POp)
op.op = torch.distributed.isend
op.tensor = w[src]
op.group_peer = dst
p2p_ops.append(op)
else:
dst_global = rank_to_global[dst]
p2p_ops += [
P2POp(
@@ -321,6 +339,14 @@ def move_to_buffer(
src = ranks_to_send[recver_pos // num_dst_per_sender]
else:
src = ranks_to_send[recver_pos - remainder_start]
if is_stateless:
for b in expert_weights_buffers:
op = object.__new__(P2POp)
op.op = torch.distributed.irecv
op.tensor = b[dst]
op.group_peer = src
p2p_ops.append(op)
else:
src_global = rank_to_global[src]
p2p_ops += [
P2POp(
@@ -334,10 +360,16 @@ def move_to_buffer(
# 4. Execute the P2P operations. The real communication happens here.
if p2p_ops and cuda_stream is not None:
with torch.cuda.stream(cuda_stream):
if is_stateless:
ep_group.device_communicator.batch_isend_irecv(p2p_ops)
else:
reqs = batch_isend_irecv(p2p_ops)
for req in reqs:
req.wait()
elif p2p_ops:
if is_stateless:
ep_group.device_communicator.batch_isend_irecv(p2p_ops)
else:
reqs = batch_isend_irecv(p2p_ops)
for req in reqs:
req.wait()

View File

@@ -209,6 +209,10 @@ class KVConnectorKVEvents(ABC):
def clear_events(self) -> None:
raise NotImplementedError
def merge(self, other: "KVConnectorKVEvents") -> "KVConnectorKVEvents":
self.add_events(other.get_all_events())
return self
class EventPublisher(ABC):
"""Lightweight publisher for EventBatch batches with data parallelism

View File

@@ -149,6 +149,12 @@ KVConnectorFactory.register_connector(
"ExampleConnector",
)
KVConnectorFactory.register_connector(
"ExampleHiddenStatesConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.example_hidden_states_connector",
"ExampleHiddenStatesConnector",
)
KVConnectorFactory.register_connector(
"P2pNcclConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_connector",

View File

@@ -413,7 +413,20 @@ class TpKVTopology:
f"by local tensor parallel size {self.tp_size}."
)
return remote_tp_size // self.tp_size
def pp_ratio(
self,
remote_pp_size: int,
) -> int:
"""
Calculate the pipeline parallel ratio between local and remote PP.
"""
assert self.pp_size % remote_pp_size == 0 or remote_pp_size % self.pp_size == 0, (
f"Local pipline parallel size {self.tp_size} is not divisible "
f"by remote pipline parallel size {remote_pp_size} or vice versa."
)
return self.pp_size // remote_pp_size if self.pp_size % remote_pp_size == 0 else remote_pp_size // self.pp_size
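# Example for the ratio above (illustrative): local pp_size=2 with
# remote_pp_size=4 (or the reverse) both yield a ratio of 2; the caller
# decides the direction by comparing the two sizes.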
def block_size_ratio(
self,
@@ -457,6 +470,7 @@ class TpKVTopology:
def get_target_remote_ranks(
self,
remote_tp_size: int,
remote_pp_size: int
) -> list[int]:
"""
Get the remote TP rank (on P) that the current local TP rank
@@ -464,19 +478,36 @@ class TpKVTopology:
read from multiple remote ranks.
"""
tp_ratio = self.tp_ratio(remote_tp_size)
if tp_ratio > 0:
return [self.tp_rank // tp_ratio]
pp_ratio = self.pp_ratio(remote_pp_size)
target_pp_rank_list = []
target_tp_rank_list = []
if self.pp_size < remote_pp_size:
for i in range(pp_ratio):
target_pp_rank_list.append(self.pp_rank * pp_ratio + i)
else:
target_pp_rank_list.append(self.pp_rank // pp_ratio)
# P TP > D TP case, D reads from |tp_ratio| remote workers.
tp_ratio = -tp_ratio
return [self.tp_rank * tp_ratio + i for i in range(tp_ratio)]
if self.tp_size < remote_tp_size:
for i in range(tp_ratio):
target_tp_rank_list.append(self.tp_rank * tp_ratio + i)
else:
target_tp_rank_list.append(self.tp_rank // tp_ratio)
target_rank_list = []
for pp_rank in target_pp_rank_list:
for tp_rank in target_tp_rank_list:
target_rank = pp_rank * remote_tp_size + tp_rank
target_rank_list.append((target_rank, pp_rank, tp_rank))
return target_rank_list
def get_target_remote_ranks_from_engine_id(
self,
remote_engine_id: EngineId,
) -> list[int]:
remote_tp_size = self.remote_tp_size[remote_engine_id]
remote_pp_size = self.remote_pp_size[remote_engine_id]
return self.get_target_remote_ranks(remote_tp_size, remote_pp_size)
def get_current_attn_backend(vllm_config: VllmConfig):

View File

@@ -543,6 +543,28 @@ class KVConnectorBase_V1(ABC):
)
return None
@classmethod
def requires_piecewise_for_cudagraph(cls, extra_config: dict[str, Any]) -> bool:
"""
Check if this connector requires PIECEWISE CUDA graph mode.
Connectors that use asynchronous layer-by-layer operations
(wait_for_layer_load/save_kv_layer) should override this method
to return True when those operations are enabled. These operations
cannot be captured in CUDA graphs and will be skipped during replay,
causing data races. PIECEWISE mode allows Python code to execute
between graph pieces, ensuring proper synchronization.
Args:
extra_config: The kv_connector_extra_config dict from
KVTransferConfig.
Returns:
True if this connector requires PIECEWISE CUDA graph mode,
False otherwise.
"""
return False
def get_finished_count(self) -> int | None:
"""
Get the count of requests expected to complete send/receive operations

View File

@@ -17,6 +17,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadata
from vllm.utils.hashing import safe_hash
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
@@ -118,12 +119,12 @@ class ExampleConnector(KVConnectorBase_V1):
The number of elements in kv_caches and layer_names should be
the same.
"""
attn_metadata = forward_context.attn_metadata
def inject_kv_into_layer(
dst_kv_cache_layer: torch.Tensor,
src_kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> None:
"""Inject the KV cache into the layer.
@@ -145,6 +146,10 @@ class ExampleConnector(KVConnectorBase_V1):
num_pages * page_size, -1
)
dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
elif isinstance(attn_metadata, TritonAttentionMetadata):
block_idxs = slot_mapping // self._block_size
offsets = slot_mapping % self._block_size
dst_kv_cache_layer[block_idxs, :, offsets] = src_kv_cache
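# A flat slot index is split into (block, offset) within the block:
# e.g. slot 37 with block_size 16 maps to block 2, offset 5.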
else:
num_pages = dst_kv_cache_layer_shape[1]
page_size = dst_kv_cache_layer_shape[2]
@@ -186,7 +191,13 @@ class ExampleConnector(KVConnectorBase_V1):
layer_name, request.token_ids, request.mm_hashes
)
kv_cache = safetensors.torch.load_file(filename)["kv_cache"].cuda()
if isinstance(attn_metadata, dict):
inject_kv_into_layer(
kv_cache_layer,
kv_cache,
request.slot_mapping,
attn_metadata[layer_name],
)
def wait_for_layer_load(self, layer_name: str) -> None:
"""Blocking until the KV for a specific layer is loaded into vLLM's
@@ -229,6 +240,10 @@ class ExampleConnector(KVConnectorBase_V1):
if isinstance(attn_metadata, MLACommonMetadata):
num_pages, page_size = layer.shape[0], layer.shape[1]
return layer.reshape(num_pages * page_size, -1)[slot_mapping, ...]
elif isinstance(attn_metadata, TritonAttentionMetadata):
block_idxs = slot_mapping // self._block_size
offsets = slot_mapping % self._block_size
return layer[block_idxs, :, offsets]
num_pages, page_size = layer.shape[1], layer.shape[2]
return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, ...]

View File

@@ -0,0 +1,354 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Optional
import safetensors
import torch
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
KVConnectorMetadata,
KVConnectorRole,
)
from vllm.logger import init_logger
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
if TYPE_CHECKING:
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
logger = init_logger(__name__)
def extract_from_kv_cache(
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
num_tokens: int,
) -> torch.Tensor:
"""Extract data from KV cache
Assume the shape of the kv_cache is (num_pages, page_size, num_heads, head_size)
"""
padded_kv = kv_cache.flatten(0, 1)[slot_mapping]
# shape: [len(slot_mapping), num_heads, head_size]
return padded_kv[:num_tokens] # shape: [num_tokens, num_heads, head_size]
@dataclass
class ReqMeta:
# Request ID
req_id: str
# Request filename
filename: str
# Request tokens
token_ids: torch.Tensor
# Slot mappings, should have the same length as token_ids
slot_mapping: torch.Tensor
# Whether this request is a new request or partially computed already
new_req: bool
@staticmethod
def make_meta(
req_id: str,
filename: str,
token_ids: list[int],
block_ids: list[int],
block_size: int,
new_req: bool,
) -> "ReqMeta":
token_ids_tensor = torch.tensor(token_ids)
block_ids_tensor = torch.tensor(block_ids)
num_blocks = block_ids_tensor.shape[0]
block_offsets = torch.arange(0, block_size)
slot_mapping = (
block_offsets.reshape((1, block_size))
+ block_ids_tensor.reshape((num_blocks, 1)) * block_size
)
slot_mapping = slot_mapping.flatten()
return ReqMeta(
req_id=req_id,
filename=filename,
token_ids=token_ids_tensor,
slot_mapping=slot_mapping,
new_req=new_req,
)
@dataclass
class ExampleHiddenStatesConnectorMetadata(KVConnectorMetadata):
requests: list[ReqMeta] = field(default_factory=list)
def add_request(
self,
req_id: str,
filename: str,
token_ids: list[int],
block_ids: list[int],
block_size: int,
new_req: bool = True,
) -> None:
self.requests.append(
ReqMeta.make_meta(
req_id, filename, token_ids, block_ids, block_size, new_req
)
)
class ExampleHiddenStatesConnector(KVConnectorBase_V1):
"""
Simple debug implementation of a HiddenStatesConnector.
Simply extracts the hidden states from the kv cache and stores them to disk.
Must be used in conjunction with the `extract_hidden_states` spec decoding method.
"""
@property
def prefer_cross_layer_blocks(self) -> bool:
"""
Indicates whether this connector prefers KV blocks that hold KV data for all
layers, which can speed up KV data transfers. Defaults to False.
"""
# Must be False so that drafter kv cache isn't merged with verifier's
return False
def __init__(
self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None,
):
super().__init__(
vllm_config=vllm_config,
role=role,
kv_cache_config=kv_cache_config,
)
self._block_size = vllm_config.cache_config.block_size
self._storage_path = self._kv_transfer_config.get_from_extra_config(
"shared_storage_path", "/tmp"
)
self.cache_layers: list[str] = [] # set by self.register_kv_caches
logger.info(self._kv_transfer_config)
logger.info("Shared storage path is %s", self._storage_path)
assert self._vllm_config.speculative_config is not None, (
"ExampleHiddenStatesConnector only works when using "
"'extract_hidden_states' speculative method"
)
spec_config = self._vllm_config.speculative_config.draft_model_config.hf_config
self.num_hidden_states = len(
getattr(spec_config, "eagle_aux_hidden_state_layer_ids", [])
)
self._request_filenames: dict[str, str] = {}
self._active_requests: dict[str, NewRequestData] = {}
self._req_blocks: dict[str, list[int]] = {}
# ==============================
# Worker-side methods
# ==============================
def start_load_kv(self, *args, **kwargs: Any) -> None:
pass # Empty implementation of abstract method
def wait_for_layer_load(self, layer_name: str) -> None:
pass # Empty implementation of abstract method
def wait_for_save(self):
pass # Empty implementation of abstract method
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
from vllm.model_executor.models.extract_hidden_states import (
CacheOnlyAttentionLayer,
)
# Filter layers to only include CacheOnlyAttentionLayers
layers = get_layers_from_vllm_config(
self._vllm_config, CacheOnlyAttentionLayer, list(kv_caches.keys())
)
self.cache_layers = list(layers.keys())
assert len(self.cache_layers) == 1, (
f"Expected 1 CacheOnlyAttentionLayer, got {len(self.cache_layers)}"
)
def save_kv_layer(
self,
layer_name: str,
kv_layer: torch.Tensor,
attn_metadata: AttentionMetadata,
**kwargs: Any,
) -> None:
"""Start saving the KV cache of the layer from vLLM's paged buffer
to the connector.
Args:
layer_name (str): the name of the layer.
kv_layer (torch.Tensor): the paged KV buffer of the current
layer in vLLM.
attn_metadata (AttentionMetadata): the attention metadata.
**kwargs: additional arguments for the save operation.
"""
if layer_name not in self.cache_layers:
return
from vllm.model_executor.models.extract_hidden_states import (
CacheOnlyAttentionMetadata,
)
assert isinstance(attn_metadata, CacheOnlyAttentionMetadata), (
"ExampleHiddenStatesConnector only supports CacheOnlyAttentionBackend"
)
connector_metadata = self._get_connector_metadata()
assert isinstance(connector_metadata, ExampleHiddenStatesConnectorMetadata)
os.makedirs(self._storage_path, exist_ok=True)
for request in connector_metadata.requests:
hidden_states = extract_from_kv_cache(
kv_layer, request.slot_mapping, request.token_ids.shape[0]
)
tensors = {
"hidden_states": hidden_states.detach().cpu(),
"token_ids": request.token_ids.detach().cpu(),
}
safetensors.torch.save_file(tensors, request.filename)
# ==============================
# Scheduler-side methods
# ==============================
def get_num_new_matched_tokens(
self,
request: "Request",
num_computed_tokens: int,
) -> tuple[int | None, bool]:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
"""
# This connector is store-only, so we don't need to load any tokens
return 0, False
def update_state_after_alloc(
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
):
# Usually used to handle allocation of new blocks for requests that are loading
# tokens from the connector's external KV cache. We never load from the
# external cache, so this is a no-op.
assert num_external_tokens == 0, "This connector is store-only"
def build_connector_meta(
self,
scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
"""Build the connector metadata for this step.
This function should NOT modify any fields in the scheduler_output.
Also, calling this function will reset the state of the connector.
Args:
scheduler_output (SchedulerOutput): the scheduler output object.
"""
meta = ExampleHiddenStatesConnectorMetadata()
for new_req in scheduler_output.scheduled_new_reqs:
token_ids = new_req.prompt_token_ids or []
filename = os.path.join(self._storage_path, f"{new_req.req_id}.safetensors")
meta.add_request(
new_req.req_id,
filename=filename,
token_ids=token_ids,
block_ids=new_req.block_ids[0],
block_size=self._block_size,
)
self._request_filenames[new_req.req_id] = filename
self._active_requests[new_req.req_id] = new_req
self._req_blocks[new_req.req_id] = list(new_req.block_ids[0])
cached_reqs = scheduler_output.scheduled_cached_reqs
for i, req_id in enumerate(cached_reqs.req_ids):
if req_id not in self._active_requests:
continue
new_block_ids = cached_reqs.new_block_ids[i]
cached_req = self._active_requests[req_id]
req_block_ids = self._req_blocks[req_id]
assert new_block_ids is not None
block_ids = new_block_ids[0]
req_block_ids.extend(block_ids)
filename = os.path.join(self._storage_path, f"{req_id}.safetensors")
meta.add_request(
req_id=req_id,
filename=filename,
token_ids=cached_req.prompt_token_ids or [],
block_ids=req_block_ids,
block_size=self._block_size,
new_req=False,
)
return meta
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, dict[str, Any] | None]:
"""
Called exactly once when a request has finished, before its blocks are
freed.
The connector may assume responsibility for freeing the blocks
asynchronously by returning True.
Returns:
True if the request is being saved/sent asynchronously and blocks
should not be freed until the request_id is returned from
get_finished().
Optional KVTransferParams to be included in the request outputs
returned by the engine.
"""
req_id = request.request_id
req_filename = self._request_filenames.pop(req_id, None)
_ = self._active_requests.pop(req_id, None)
_ = self._req_blocks.pop(req_id, None)
return False, {"hidden_states_path": req_filename}
@classmethod
def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None:
"""
Get the required KV cache layout for this connector.
Args:
vllm_config (VllmConfig): the vllm config.
Returns:
str: the required KV cache layout. e.g. HND, or NHD.
None if the connector does not require a specific layout.
"""
if cls is KVConnectorBase_V1:
raise TypeError(
"get_required_kvcache_layout should not be called "
"on the abstract base class"
)
# NHD means we have (num_tokens, num_heads)
# HND means we have (num_heads, num_tokens)
# For now, we only support NHD layout since this keeps the
# hidden states for each token together in memory.
# HND is primarily used when sharding heads across devices.
return "NHD"

View File

@@ -70,6 +70,16 @@ class LMCacheKVEvents(KVConnectorKVEvents):
class LMCacheConnectorV1(KVConnectorBase_V1):
@classmethod
def requires_piecewise_for_cudagraph(cls, extra_config: dict[str, Any]) -> bool:
"""
LMCache requires PIECEWISE CUDA graph mode when layerwise
operations are enabled. The wait_for_layer_load and save_kv_layer
methods perform actual async synchronization that cannot be
captured in CUDA graphs.
"""
return extra_config.get("use_layerwise", False)
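
For reference, a hedged sketch of the connector configuration that makes this override return True; the `kv_connector_extra_config` key is taken from this diff, while the surrounding field names follow the usual KVTransferConfig layout and are assumed here.

```python
# Passed via KVTransferConfig (e.g. --kv-transfer-config on the CLI).
kv_transfer_config = {
    "kv_connector": "LMCacheConnectorV1",
    "kv_role": "kv_both",
    "kv_connector_extra_config": {"use_layerwise": True},
}
```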
def __init__(
self,
vllm_config: "VllmConfig",

View File

@@ -173,6 +173,29 @@ class MooncakeConnector(KVConnectorBase_V1):
self.connector_scheduler = None
self.connector_worker = MooncakeConnectorWorker(vllm_config, self.engine_id)
############################################################
# Class Methods
############################################################
@classmethod
def get_required_kvcache_layout(cls, vllm_config: VllmConfig):
if vllm_config.model_config is None:
logger.warning_once(
"Unable to detect current VLLM config. "
"Fallback to default kv cache layout."
)
return None
use_mla = vllm_config.model_config.use_mla
if use_mla:
# Return None when we have MLA, as the layout should not matter
# in that case; fall back to the default behavior.
return None
logger.info_once(
"MooncakeConnector setting KV cache layout to HND for better xfer performance."
)
return "HND"
############################################################
# Scheduler Side Methods
############################################################
@@ -941,7 +964,13 @@ class MooncakeConnectorWorker:
assert tensor_size_bytes == curr_tensor_size_bytes, (
"All kv cache tensors must have the same size"
)
kernel_block_size = cache.shape[-2 if self.use_mla else -3]
cache_layout = get_kv_cache_layout()
if cache_layout == "HND":
kernel_block_size = cache.shape[-2]
else:
kernel_block_size = cache.shape[-3]
assert self.block_size == kernel_block_size
kv_data_ptrs.append(base_addr)
kv_data_lens.append(tensor_size_bytes)
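
The layout-aware lookup above is needed because the block-size axis moves with the KV cache layout; a toy check of where that axis sits in each case (shapes are illustrative, not the exact kernel shapes):

```python
import torch

num_blocks, block_size, num_heads, head_size = 8, 16, 4, 32

# NHD: (..., block_size, num_heads, head_size) -> block size at index -3
cache_nhd = torch.empty(2, num_blocks, block_size, num_heads, head_size)
# HND: (..., num_heads, block_size, head_size) -> block size at index -2
cache_hnd = torch.empty(2, num_blocks, num_heads, block_size, head_size)

assert cache_nhd.shape[-3] == block_size
assert cache_hnd.shape[-2] == block_size
```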

View File

@@ -112,6 +112,21 @@ class MultiConnector(KVConnectorBase_V1):
- Save to all connectors.
"""
@classmethod
def requires_piecewise_for_cudagraph(cls, extra_config: dict[str, Any]) -> bool:
"""
MultiConnector requires PIECEWISE CUDA graph mode if any of its
child connectors require it.
"""
connectors_config = extra_config.get("connectors", [])
for conn_config in connectors_config:
temp_ktc = KVTransferConfig(**conn_config)
connector_cls = KVConnectorFactory.get_connector_class(temp_ktc)
child_extra_config = conn_config.get("kv_connector_extra_config", {})
if connector_cls.requires_piecewise_for_cudagraph(child_extra_config):
return True
return False
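
A hedged sketch of the nested configuration this scan walks over; the connector names and roles are illustrative, only the `connectors` and `kv_connector_extra_config` keys are taken from the code above.

```python
# MultiConnector's extra config: the parent requires PIECEWISE capture as
# soon as any child connector does (here, LMCache with layerwise enabled).
kv_connector_extra_config = {
    "connectors": [
        {
            "kv_connector": "LMCacheConnectorV1",
            "kv_role": "kv_both",
            "kv_connector_extra_config": {"use_layerwise": True},
        },
        {
            "kv_connector": "SharedStorageConnector",
            "kv_role": "kv_both",
        },
    ]
}
```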
def __init__(
self,
vllm_config: "VllmConfig",

File diff suppressed because it is too large

View File

@@ -25,6 +25,7 @@ If you only need to use the distributed environment without model/pipeline
import contextlib
import gc
import os
import pickle
import weakref
from collections import namedtuple
@@ -33,7 +34,7 @@ from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from datetime import timedelta
from multiprocessing import shared_memory
from typing import Any, Protocol
from typing import TYPE_CHECKING, Any, Protocol
from unittest.mock import patch
import torch
@@ -54,6 +55,10 @@ from vllm.utils.system_utils import suppress_stdout
from vllm.utils.torch_utils import (
direct_register_custom_op,
)
if TYPE_CHECKING:
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
import ixformer.distributed as ixfd
import vllm._custom_ops as ops
@@ -327,6 +332,8 @@ class GroupCoordinator:
self.rank = torch.distributed.get_rank()
self.local_rank = local_rank
use_vllm_comm = os.environ.get("VLLM_FORCE_NCCL_COMM", None) not in {"1", "Y", "y"}
self_device_group = None
self_cpu_group = None
@@ -339,7 +346,7 @@ class GroupCoordinator:
with suppress_stdout():
cpu_group = torch.distributed.new_group(ranks, backend="gloo")
if self.rank in ranks:
self.ixfd_group = ixfd.init_comm_with_store(device_group)
self.ixfd_group = ixfd.init_comm_with_store(device_group) if use_vllm_comm else None
self.ranks = ranks
self.world_size = len(ranks)
self.rank_in_group = ranks.index(self.rank)
@@ -372,8 +379,7 @@ class GroupCoordinator:
self.device_communicator = device_comm_cls(
cpu_group=self.cpu_group,
device=self.device,
# device_group=self.device_group,
device_group=self.ixfd_group if envs.VLLM_FORCE_NCCL_COMM else self.device_group,
device_group=self.ixfd_group if use_vllm_comm else self.device_group,
unique_name=self.unique_name,
)
@@ -385,11 +391,6 @@ class GroupCoordinator:
self.cpu_group, 1 << 22, 6
)
from vllm.platforms import current_platform
# self.use_custom_op_call = (
# current_platform.is_cuda_alike() or current_platform.is_tpu()
# )
self.use_custom_op_call = False
self.use_cpu_custom_send_recv = current_platform.is_cpu() and hasattr(
@@ -468,14 +469,12 @@ class GroupCoordinator:
# only cuda uses this function,
# so we don't abstract it into the base class
maybe_ca_context = nullcontext()
# from vllm.distributed.device_communicators.cuda_communicator import (
# CudaCommunicator,
# )
from vllm.distributed.device_communicators.base_device_communicator import DeviceCommunicatorBase
from vllm.distributed.device_communicators.cuda_communicator import (
CudaCommunicator,
)
if self.device_communicator is not None:
# assert isinstance(self.device_communicator, CudaCommunicator)
assert isinstance(self.device_communicator, DeviceCommunicatorBase)
assert isinstance(self.device_communicator, CudaCommunicator)
ca_comm = self.device_communicator.ca_comm
if ca_comm is not None:
maybe_ca_context = ca_comm.capture() # type: ignore
@@ -608,9 +607,9 @@ class GroupCoordinator:
src=self.ranks[src],
group=self.device_group)
else:
torch.distributed.broadcast(input_,
src=self.ranks[src],
group=self.device_group)
torch.distributed.broadcast(
input_, src=self.ranks[src], group=self.device_group
)
return input_
def broadcast_object(self, obj: Any | None = None, src: int = 0):
@@ -764,10 +763,9 @@ class GroupCoordinator:
group=group,
async_op=True)
else:
handle = torch.distributed.broadcast(tensor,
src=self.ranks[src],
group=group,
async_op=True)
handle = torch.distributed.broadcast(
tensor, src=self.ranks[src], group=group, async_op=True
)
async_handles.append(handle)
for async_handle in async_handles:
async_handle.wait()
@@ -802,10 +800,8 @@ class GroupCoordinator:
async_op=True)
else:
handle = torch.distributed.broadcast(
tensor,
src=self.ranks[src],
group=group,
async_op=True)
tensor, src=self.ranks[src], group=group, async_op=True
)
async_handles.append(handle)
tensor_dict[key] = tensor
else:
@@ -876,6 +872,10 @@ class GroupCoordinator:
if self.world_size <= 1:
return []
if dst is None:
dst = (self.rank_in_group + 1) % self.world_size
assert dst < self.world_size, f"Invalid dst rank ({dst})"
if self.use_cpu_custom_send_recv:
if self.device_communicator is None:
raise ValueError("No device communicator found")
@@ -893,10 +893,6 @@ class GroupCoordinator:
group = self.device_group
metadata_group = self.cpu_group
if dst is None:
dst = (self.rank_in_group + 1) % self.world_size
assert dst < self.world_size, f"Invalid dst rank ({dst})"
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
self.send_object(metadata_list, dst=dst)
@@ -917,6 +913,7 @@ class GroupCoordinator:
handle = torch.distributed.isend(
tensor, dst=self.ranks[dst], group=comm_group
)
if tensor.is_cuda:
tensor.record_stream(torch.cuda.current_stream(tensor.device))
handles.append(handle)
@@ -973,6 +970,11 @@ class GroupCoordinator:
]:
if not torch.distributed.is_initialized() or self.world_size == 1:
return None, [], []
if src is None:
src = (self.rank_in_group - 1) % self.world_size
assert src < self.world_size, f"Invalid src rank ({src})"
if self.use_cpu_custom_send_recv:
if self.device_communicator is None:
raise ValueError("No device communicator found")
@@ -990,10 +992,6 @@ class GroupCoordinator:
group = self.device_group
metadata_group = self.cpu_group
if src is None:
src = (self.rank_in_group - 1) % self.world_size
assert src < self.world_size, f"Invalid src rank ({src})"
recv_metadata_list = self.recv_object(src=src)
tensor_dict: dict[str, Any] = {}
handles: list[Handle] = []
@@ -1072,14 +1070,13 @@ class GroupCoordinator:
return self.device_communicator.recv(size, dtype, src)
def destroy(self):
if hasattr(self, "device_group"):
# torch.distributed.destroy_process_group(self.device_group)
if self.device_group is not None:
if self.device_communicator and self.device_communicator.use_vllm_comm:
ixfd.destroy_process_group(self.device_group)
else:
torch.distributed.destroy_process_group(self.device_group)
del self.device_group
if hasattr(self, "cpu_group"):
self.device_group = None
if self.cpu_group is not None:
torch.distributed.destroy_process_group(self.cpu_group)
del self.cpu_group
if self.device_communicator is not None:
@@ -1094,7 +1091,6 @@ class GroupCoordinator:
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
extra_residual:torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
@@ -1105,13 +1101,12 @@ class GroupCoordinator:
if self.device_communicator is not None:
return self.device_communicator.dispatch_router_logits(
hidden_states,
extra_residual,
router_logits,
is_sequence_parallel,
extra_tensors,
)
else:
return hidden_states, extra_residual, router_logits
return hidden_states, router_logits
def dispatch(
self,
@@ -1189,6 +1184,55 @@ def init_model_parallel_group(
)
def _init_stateless_group(
group_ranks: list[list[int]],
group_name: str,
group_ports: list[list[int]],
host: str,
backend: str,
use_device_communicator: bool = True,
) -> "StatelessGroupCoordinator":
"""Create a StatelessGroupCoordinator with the given parameters."""
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
world = get_world_group()
return StatelessGroupCoordinator(
group_ranks=group_ranks,
local_rank=world.local_rank,
torch_distributed_backend=backend,
use_device_communicator=use_device_communicator,
group_name=group_name,
host=host,
group_ports=group_ports,
global_rank=world.rank,
global_world_size=world.world_size,
)
def _replace_active_groups(
*,
world: GroupCoordinator | None,
dp: GroupCoordinator | None,
ep: GroupCoordinator | None,
eplb: GroupCoordinator | None,
node_count: int | None,
) -> None:
"""Destroy the current DP/EP/WORLD/EPLB groups and replace them.
Destruction is collective — all ranks in the old groups must call this
function together. Pass all-``None`` to tear down without replacement.
"""
global _WORLD, _DP, _EP, _EPLB, _NODE_COUNT
for group in (_DP, _EP, _WORLD, _EPLB):
if group is not None:
group.destroy()
_WORLD = world
_DP = dp
_EP = ep
_EPLB = eplb
_NODE_COUNT = node_count
_TP: GroupCoordinator | None = None
@@ -1286,6 +1330,39 @@ def set_custom_all_reduce(enable: bool):
_ENABLE_CUSTOM_ALL_REDUCE = enable
def _init_elastic_ep_world(
config, local_rank: int, backend: str, rank: int, world_size: int
) -> None:
from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
global _WORLD, _NODE_COUNT
assert _WORLD is None, "world group already initialized"
parallel_config = config.parallel_config
global_rank = parallel_config.data_parallel_rank * world_size + rank
global_world_size = parallel_config.world_size_across_dp
all_ranks = list(range(global_world_size))
group_ranks = [all_ranks[i : i + 1] for i in range(global_world_size)]
if global_rank in all_ranks:
group_ranks = [all_ranks]
group_ports = [parallel_config.get_next_stateless_world_group_port()]
world = StatelessGroupCoordinator(
group_ranks=group_ranks,
local_rank=local_rank,
torch_distributed_backend=backend,
use_device_communicator=False,
group_name="world",
host=parallel_config.data_parallel_master_ip,
group_ports=group_ports,
global_rank=global_rank,
global_world_size=global_world_size,
)
assert parallel_config.nnodes_within_dp == 1, (
"Elastic EP is not supported with multi-node TP/PP"
)
_NODE_COUNT = _node_count(world.tcp_store_group)
_WORLD = world
def init_distributed_environment(
world_size: int = -1,
rank: int = -1,
@@ -1305,6 +1382,7 @@ def init_distributed_environment(
from vllm.config import get_current_vllm_config_or_none
config = get_current_vllm_config_or_none()
enable_elastic_ep = config is not None and config.parallel_config.enable_elastic_ep
if (
config is not None
and config.parallel_config.distributed_executor_backend != "external_launcher"
@@ -1312,6 +1390,7 @@ def init_distributed_environment(
config.parallel_config.nnodes > 1
or config.parallel_config.data_parallel_size > 1
)
and not enable_elastic_ep
):
parallel_config = config.parallel_config
# adjust to take into account data parallelism
@@ -1365,6 +1444,18 @@ def init_distributed_environment(
rank=rank,
timeout=timeout,
)
if enable_elastic_ep:
tp_pp_cpu_group = torch.distributed.new_group(
backend="gloo", timeout=timeout
)
if _node_count(tp_pp_cpu_group) > 1:
# NOTE(yongji): StatelessGroupCoordinator uses data_parallel_master_ip
# to initialize all DP/EP groups, hence all ranks within TP/PP group
# must reside on the same node
raise RuntimeError(
"Elastic EP is not yet supported with multi-node TP/PP"
)
# set the local rank
# local_rank is not available in torch ProcessGroup,
# see https://github.com/pytorch/pytorch/issues/122816
@@ -1373,6 +1464,9 @@ def init_distributed_environment(
# setting, where we can use rank as local rank
local_rank = envs.LOCAL_RANK if distributed_init_method == "env://" else rank
global _WORLD, _NODE_COUNT, _INNER_DP_WORLD
if enable_elastic_ep:
_init_elastic_ep_world(config, local_rank, backend, rank, world_size)
return
if _WORLD is None:
ranks = list(range(torch.distributed.get_world_size()))
_WORLD = init_world_group(ranks, local_rank, backend)
@@ -1436,16 +1530,33 @@ def initialize_model_parallel(
"""
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size: int = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
backend = backend or torch.distributed.get_backend(get_world_group().device_group)
data_parallel_size = 1
from vllm.config import get_current_vllm_config_or_none
from vllm.config import get_current_vllm_config
config = get_current_vllm_config_or_none()
if config is not None:
config = get_current_vllm_config()
data_parallel_size = config.parallel_config.data_parallel_size
enable_elastic_ep = config.parallel_config.enable_elastic_ep
if enable_elastic_ep:
# Use stateless world group for global information
world_size = get_world_group().world_size
rank = get_world_group().rank
backend = backend or "nccl"
tp_pp_pcp_size = (
tensor_model_parallel_size
* pipeline_model_parallel_size
* prefill_context_model_parallel_size
)
local_all_ranks = torch.arange(tp_pp_pcp_size).reshape(
pipeline_model_parallel_size,
prefill_context_model_parallel_size,
tensor_model_parallel_size,
)
else:
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
backend = backend or torch.distributed.get_backend(
get_world_group().device_group
)
# the layout order is: ExternalDP x DP x PP x TP
# ExternalDP is the data parallel group that is not part of the model,
@@ -1469,7 +1580,9 @@ def initialize_model_parallel(
assert _TP is None, "tensor model parallel group is already initialized"
group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
if enable_elastic_ep:
group_ranks = local_all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
# message queue broadcaster is only used in tensor model parallel group
_TP = init_model_parallel_group(
group_ranks,
@@ -1488,6 +1601,11 @@ def initialize_model_parallel(
# TP group into tp_size//dcp_size DCP groups.
group_ranks = all_ranks.reshape(-1, decode_context_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
if enable_elastic_ep:
group_ranks = local_all_ranks.reshape(
-1, decode_context_model_parallel_size
).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
_DCP = init_model_parallel_group(
group_ranks,
get_world_group().local_rank,
@@ -1504,6 +1622,13 @@ def initialize_model_parallel(
.unbind(0)
)
group_ranks = [x.tolist() for x in group_ranks]
if enable_elastic_ep:
group_ranks = (
local_all_ranks.transpose(1, 2)
.reshape(-1, prefill_context_model_parallel_size)
.unbind(0)
)
group_ranks = [x.tolist() for x in group_ranks]
_PCP = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="pcp"
)
@@ -1515,6 +1640,13 @@ def initialize_model_parallel(
all_ranks.transpose(2, 4).reshape(-1, pipeline_model_parallel_size).unbind(0)
)
group_ranks = [x.tolist() for x in group_ranks]
if enable_elastic_ep:
group_ranks = (
local_all_ranks.transpose(0, 2)
.reshape(-1, pipeline_model_parallel_size)
.unbind(0)
)
group_ranks = [x.tolist() for x in group_ranks]
_PP = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="pp"
)
@@ -1523,6 +1655,19 @@ def initialize_model_parallel(
assert _DP is None, "data parallel group is already initialized"
group_ranks = all_ranks.transpose(1, 4).reshape(-1, data_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
if enable_elastic_ep:
parallel_config = config.parallel_config
dp_ports = [
parallel_config.get_next_stateless_dp_group_port() for _ in group_ranks
]
_DP = _init_stateless_group(
group_ranks,
"dp",
dp_ports,
parallel_config.data_parallel_master_ip,
backend,
)
else:
_DP = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="dp"
)
@@ -1530,7 +1675,7 @@ def initialize_model_parallel(
global _EP
assert _EP is None, "expert parallel group is already initialized"
# Don't create EP group for dense models.
if config is None or config.model_config is None or config.model_config.is_moe:
if config.model_config is None or config.model_config.is_moe:
group_ranks = (
all_ranks.transpose(1, 2)
.reshape(
@@ -1542,6 +1687,19 @@ def initialize_model_parallel(
.unbind(0)
)
group_ranks = [x.tolist() for x in group_ranks]
if enable_elastic_ep:
parallel_config = config.parallel_config
ep_ports = [
parallel_config.get_next_stateless_ep_group_port() for _ in group_ranks
]
_EP = _init_stateless_group(
group_ranks,
"ep",
ep_ports,
parallel_config.data_parallel_master_ip,
backend,
)
else:
_EP = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="ep"
)
@@ -1557,9 +1715,24 @@ def initialize_model_parallel(
and config.parallel_config is not None
and config.parallel_config.enable_eplb
):
# Reuse the same group_ranks from EP
if enable_elastic_ep:
eplb_ports = [
parallel_config.get_next_stateless_eplb_group_port()
for _ in group_ranks
]
_EPLB = _init_stateless_group(
group_ranks,
"eplb",
eplb_ports,
parallel_config.data_parallel_master_ip,
backend,
)
else:
_EPLB = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="eplb"
group_ranks,
get_world_group().local_rank,
backend,
group_name="eplb",
)
# If no EP group needed, _EP remains None
# If no EPLB group needed, _EPLB remains None
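
The elastic-EP branches above re-derive group ranks from a local rank grid rather than the global one; the reshaping itself is ordinary tensor manipulation. A standalone sketch with a two-stage pipeline and TP size 2 (sizes chosen only for the illustration):

```python
import torch

pipeline_model_parallel_size = 2
prefill_context_model_parallel_size = 1
tensor_model_parallel_size = 2

local_all_ranks = torch.arange(
    pipeline_model_parallel_size
    * prefill_context_model_parallel_size
    * tensor_model_parallel_size
).reshape(
    pipeline_model_parallel_size,
    prefill_context_model_parallel_size,
    tensor_model_parallel_size,
)

# TP groups: contiguous ranks along the last axis.
tp_groups = [
    x.tolist() for x in local_all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
]
# -> [[0, 1], [2, 3]]

# PP groups: ranks that share a TP index across pipeline stages.
pp_groups = [
    x.tolist()
    for x in local_all_ranks.transpose(0, 2)
    .reshape(-1, pipeline_model_parallel_size)
    .unbind(0)
]
# -> [[0, 2], [1, 3]]
```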
@@ -1590,7 +1763,11 @@ def ensure_model_parallel_initialized(
or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
values if the model parallel groups are initialized.
"""
backend = backend or torch.distributed.get_backend(get_world_group().device_group)
world_group = get_world_group()
if hasattr(world_group, "backend"):
backend = backend or world_group.backend
else:
backend = backend or torch.distributed.get_backend(world_group.device_group)
if not model_parallel_is_initialized():
initialize_model_parallel(
tensor_model_parallel_size,

View File

@@ -0,0 +1,322 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
import torch
from torch.distributed import Backend, ProcessGroup
from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator
from vllm.distributed.parallel_state import (
GroupCoordinator,
TensorMetadata,
_get_unique_name,
_register_group,
_split_tensor_dict,
)
from vllm.distributed.utils import (
StatelessProcessGroup,
stateless_destroy_torch_distributed_process_group,
stateless_init_torch_distributed_process_group,
)
from vllm.logger import init_logger
from vllm.utils.import_utils import resolve_obj_by_qualname
logger = init_logger(__name__)
class StatelessGroupCoordinator(GroupCoordinator):
"""
A stateless version of the GroupCoordinator class in parallel_state.
It creates CPU, device, and TCPStore-based communication groups that
are independent of PyTorch's WORLD group. Hence, communication groups
with a different set of participant GPUs can be created without
destroying the existing ones.
"""
def __init__(
self,
group_ranks: list[list[int]],
local_rank: int,
torch_distributed_backend: str | Backend,
use_device_communicator: bool,
use_message_queue_broadcaster: bool = False,
group_name: str | None = None,
host: str = "127.0.0.1",
group_ports: list[list[int]] | None = None,
global_rank: int = 0,
global_world_size: int = 1,
):
group_name = group_name or "anonymous"
self.unique_name = _get_unique_name(group_name)
_register_group(self)
self.rank = global_rank
self.local_rank = local_rank
self_device_group = None
self_cpu_group = None
self_tcp_store_group = None
from vllm.platforms import current_platform
backend = str(torch_distributed_backend)
self.backend = backend
assert group_ports is not None, "group_ports is not provided"
for idx, ranks in enumerate(group_ranks):
if self.rank in ranks:
self.ranks = ranks
self.world_size = len(ranks)
self.rank_in_group = ranks.index(self.rank)
ports = group_ports[idx]
device_port = ports[0]
cpu_port = ports[1]
tcp_store_port = ports[2]
device_group = stateless_init_torch_distributed_process_group(
host=host,
port=device_port,
rank=self.rank_in_group,
world_size=self.world_size,
backend=backend,
group_name=f"{self.unique_name}_device",
)
cpu_group = stateless_init_torch_distributed_process_group(
host=host,
port=cpu_port,
rank=self.rank_in_group,
world_size=self.world_size,
backend="gloo",
group_name=f"{self.unique_name}_cpu",
)
tcp_store_group = StatelessProcessGroup.create(
host=host,
port=tcp_store_port,
rank=self.rank_in_group,
world_size=self.world_size,
)
self_device_group = device_group
self_cpu_group = cpu_group
self_tcp_store_group = tcp_store_group
assert self_cpu_group is not None
assert self_device_group is not None
assert self_tcp_store_group is not None
self.cpu_group = self_cpu_group
self.device_group = self_device_group
self.tcp_store_group = self_tcp_store_group
if current_platform.is_cuda_alike():
self.device = torch.device(f"cuda:{local_rank}")
elif current_platform.is_xpu():
self.device = torch.device(f"xpu:{local_rank}")
elif current_platform.is_out_of_tree():
self.device = torch.device(f"{current_platform.device_name}:{local_rank}")
else:
self.device = torch.device("cpu")
self.use_device_communicator = use_device_communicator
self.device_communicator = None
if use_device_communicator and self.world_size > 1:
device_comm_cls = resolve_obj_by_qualname(
current_platform.get_device_communicator_cls()
)
assert device_comm_cls == CudaCommunicator
self.device_communicator = CudaCommunicator(
cpu_group=self.cpu_group,
device=self.device,
device_group=self.device_group,
unique_name=self.unique_name,
global_ranks=self.ranks,
global_world_size=global_world_size,
tcp_store_group=self.tcp_store_group,
)
self.mq_broadcaster = None
self.use_custom_op_call = (
current_platform.is_cuda_alike() or current_platform.is_tpu()
)
self.use_cpu_custom_send_recv = False
def destroy(self):
if self.device_communicator:
self.device_communicator.destroy()
if self.device_group:
stateless_destroy_torch_distributed_process_group(self.device_group)
if self.cpu_group:
stateless_destroy_torch_distributed_process_group(self.cpu_group)
def size(self) -> int:
"""Return the world size of this group."""
return self.world_size
def broadcast(self, input_: torch.Tensor, src: int = 0):
if self.world_size == 1:
return input_
if self.device_communicator and input_.is_cuda:
return self.device_communicator.broadcast(input_, src)
else:
return self.tcp_store_group.broadcast(input_, src)
def broadcast_object(self, obj=None, src: int = 0):
if self.world_size == 1:
return obj
return self.tcp_store_group.broadcast_obj(obj, src)
def broadcast_object_list(
self, obj_list: list[Any], src: int = 0, group: ProcessGroup | None = None
):
assert src < self.world_size
if self.world_size == 1:
return obj_list
if self.rank_in_group == src:
for obj in obj_list:
self.tcp_store_group.broadcast_obj(obj, src)
else:
for i in range(len(obj_list)):
obj_list[i] = self.tcp_store_group.broadcast_obj(None, src)
return obj_list
def broadcast_tensor_dict(
self,
tensor_dict: dict[str, torch.Tensor | Any] | None = None,
src: int = 0,
group: ProcessGroup | None = None,
metadata_group: ProcessGroup | None = None,
) -> dict[str, torch.Tensor | Any] | None:
if self.world_size == 1:
return tensor_dict
if self.rank_in_group == src:
assert isinstance(tensor_dict, dict), (
f"Expecting a dictionary, got {type(tensor_dict)}"
)
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
else:
metadata_list = None
tensor_list = []
recv_metadata_list: list[tuple[str, Any]] = self.tcp_store_group.broadcast_obj(
metadata_list, src
)
if self.rank_in_group != src:
tensor_dict = {}
for key, value in recv_metadata_list:
if isinstance(value, TensorMetadata):
tensor = torch.empty(
value.size, dtype=value.dtype, device=value.device
)
tensor_list.append(tensor)
tensor_dict[key] = tensor
else:
tensor_dict[key] = value
for tensor in tensor_list:
if tensor.numel() == 0:
continue
if self.device_communicator and tensor.is_cuda:
tensor.copy_(self.device_communicator.broadcast(tensor, src))
else:
tensor.copy_(self.tcp_store_group.broadcast(tensor, src))
return tensor_dict
def send_object(self, obj, dst: int) -> None:
assert dst < self.world_size
assert dst != self.rank_in_group
self.tcp_store_group.send_obj(obj, dst)
def recv_object(self, src: int):
assert src < self.world_size
assert src != self.rank_in_group
return self.tcp_store_group.recv_obj(src)
def send_tensor_dict(
self,
tensor_dict: dict[str, torch.Tensor | Any],
dst: int | None = None,
all_gather_group: Optional["GroupCoordinator"] = None,
all_gather_tensors: dict[str, bool] | None = None,
) -> dict[str, torch.Tensor | Any] | None:
if self.world_size == 1:
return tensor_dict
if dst is None:
dst = (self.rank_in_group + 1) % self.world_size
assert dst < self.world_size
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
self.tcp_store_group.send_obj(metadata_list, dst)
for tensor in tensor_list:
if tensor.numel() == 0:
continue
if self.device_communicator and tensor.is_cuda:
self.device_communicator.send(tensor, dst)
else:
self.tcp_store_group.send(tensor, dst)
return None
def recv_tensor_dict(
self,
src: int | None = None,
all_gather_group: Optional["GroupCoordinator"] = None,
all_gather_tensors: dict[str, bool] | None = None,
) -> dict[str, torch.Tensor | Any] | None:
if self.world_size == 1:
return None
if src is None:
src = (self.rank_in_group - 1) % self.world_size
assert src < self.world_size
recv_metadata_list = self.tcp_store_group.recv_obj(src)
tensor_dict = {}
for key, value in recv_metadata_list:
if isinstance(value, TensorMetadata):
tensor = torch.empty(value.size, dtype=value.dtype, device=value.device)
if tensor.numel() > 0:
if self.device_communicator and tensor.is_cuda:
tensor = self.device_communicator.recv(
tensor.size(), tensor.dtype, src
)
else:
tensor = self.tcp_store_group.recv(tensor, src)
tensor_dict[key] = tensor
else:
tensor_dict[key] = value
return tensor_dict
def barrier(self):
self.tcp_store_group.barrier()
def gather(
self, input_: torch.Tensor, dst: int = 0, dim: int = -1
) -> torch.Tensor | None:
if self.world_size == 1:
return input_
if self.device_communicator is None:
raise ValueError("No device communicator found")
if self.rank_in_group == dst:
gathered_list = [torch.empty_like(input_) for _ in range(self.world_size)]
gathered_list[self.rank_in_group] = input_
for src_rank in range(self.world_size):
if src_rank != self.rank_in_group:
gathered_list[src_rank] = self.device_communicator.recv(
input_.size(), input_.dtype, src_rank
)
return torch.cat(gathered_list, dim=dim)
else:
self.device_communicator.send(input_, dst)
return None
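
`broadcast_tensor_dict` above ships the dict structure (keys, dtypes, shapes) through the TCP store and the tensor payloads through the device communicator or store. The split-and-rebuild idea can be sketched locally without any process group; the `Meta` tuple below is only a stand-in for vLLM's `TensorMetadata`.

```python
from typing import Any, NamedTuple

import torch


class Meta(NamedTuple):
    """Local stand-in for TensorMetadata: device, dtype and shape."""

    device: str
    dtype: torch.dtype
    size: torch.Size


def split_tensor_dict(
    tensor_dict: dict[str, Any],
) -> tuple[list[tuple[str, Any]], list[torch.Tensor]]:
    """Mirror of the _split_tensor_dict idea: metadata out, tensors out."""
    metadata: list[tuple[str, Any]] = []
    tensors: list[torch.Tensor] = []
    for key, value in tensor_dict.items():
        if isinstance(value, torch.Tensor):
            metadata.append((key, Meta(str(value.device), value.dtype, value.size())))
            tensors.append(value)
        else:
            metadata.append((key, value))
    return metadata, tensors


# "Sender": metadata would travel over the TCP store, tensors over the
# device communicator (or the store fallback for CPU tensors).
sent = {"hidden": torch.randn(2, 4), "step": 7}
metadata, payload = split_tensor_dict(sent)

# "Receiver": rebuild empty tensors from metadata, then fill them in the
# same order the sender transmitted them.
rebuilt: dict[str, Any] = {}
incoming = iter(payload)  # stands in for the per-tensor broadcast/recv
for key, value in metadata:
    if isinstance(value, Meta):
        tensor = torch.empty(value.size, dtype=value.dtype, device=value.device)
        tensor.copy_(next(incoming))
        rebuilt[key] = tensor
    else:
        rebuilt[key] = value

assert torch.equal(rebuilt["hidden"], sent["hidden"]) and rebuilt["step"] == 7
```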

View File

@@ -18,7 +18,7 @@ from datetime import timedelta
from typing import Any
import torch
from torch.distributed import ProcessGroup, TCPStore
from torch.distributed import ProcessGroup, Store, TCPStore
from torch.distributed.distributed_c10d import (
Backend,
PrefixStore,
@@ -228,6 +228,55 @@ class StatelessProcessGroup:
gathered_objs.append(recv_obj)
return gathered_objs
def broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
"""Broadcast a tensor from source rank to all other ranks."""
if self.rank == src:
tensor_bytes = pickle.dumps(tensor)
self.expire_data()
key = f"broadcast_tensor/{src}/{self.broadcast_send_counter}"
self.store.set(key, tensor_bytes)
self.broadcast_send_counter += 1
self.entries.append((key, time.time()))
return tensor
else:
key = f"broadcast_tensor/{src}/{self.broadcast_recv_src_counter[src]}"
tensor = pickle.loads(self.store.get(key))
self.broadcast_recv_src_counter[src] += 1
return tensor
def send(self, tensor: torch.Tensor, dst: int):
"""Send a tensor to a destination rank."""
self.expire_data()
key = f"send_tensor/{dst}/{self.send_dst_counter[dst]}"
self.store.set(key, pickle.dumps(tensor))
self.send_dst_counter[dst] += 1
self.entries.append((key, time.time()))
def recv(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
"""Receive a tensor from a source rank."""
key = f"send_tensor/{self.rank}/{self.recv_src_counter[src]}"
received = pickle.loads(self.store.get(key))
self.recv_src_counter[src] += 1
tensor.copy_(received)
return tensor
def all_reduce(
self, tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM
) -> torch.Tensor:
"""All-reduce a tensor across all ranks."""
tensors = self.all_gather_obj(tensor)
result = tensors[0].clone()
for t in tensors[1:]:
if op == torch.distributed.ReduceOp.SUM:
result.add_(t)
elif op == torch.distributed.ReduceOp.PRODUCT:
result.mul_(t)
elif op == torch.distributed.ReduceOp.MAX:
result = torch.maximum(result, t)
elif op == torch.distributed.ReduceOp.MIN:
result = torch.minimum(result, t)
return result
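
The new `all_reduce` is an all-gather followed by a local reduction, so the reduction branches can be checked with plain tensors; the list below stands in for what `all_gather_obj` would return.

```python
import torch

# One tensor per rank, as all_gather_obj would return them.
tensors = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0]), torch.tensor([5.0, 6.0])]

result = tensors[0].clone()
for t in tensors[1:]:
    result.add_(t)                      # ReduceOp.SUM branch
assert torch.equal(result, torch.tensor([9.0, 12.0]))

result = tensors[0].clone()
for t in tensors[1:]:
    result = torch.maximum(result, t)   # ReduceOp.MAX branch
assert torch.equal(result, torch.tensor([5.0, 6.0]))
```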
def barrier(self, timeout: float = 30.0):
"""A robust barrier to synchronize all ranks.
@@ -448,8 +497,14 @@ def init_gloo_process_group(
def stateless_init_torch_distributed_process_group(
host: str, port: int, rank: int, world_size: int, backend: str
) -> ProcessGroup:
host: str,
port: int,
rank: int,
world_size: int,
backend: str,
group_name: str | None = None,
return_store: bool = False,
) -> ProcessGroup | tuple[ProcessGroup, Store]:
"""
A replacement for `torch.distributed.init_process_group` that does not
pollute the global state. The created ProcessGroup object can be used for
@@ -496,25 +551,35 @@ def stateless_init_torch_distributed_process_group(
# Use a PrefixStore to avoid accidental overrides of keys used by
# different systems (e.g. RPC) in case the store is multi-tenant.
prefix_store = PrefixStore(init_method, store)
try:
if backend == "gloo":
pg = init_gloo_process_group(
prefix_store=prefix_store,
group_rank=group_rank,
group_size=group_size,
timeout=timeout,
)
else:
from vllm.platforms import current_platform
return current_platform.stateless_init_device_torch_dist_pg(
pg = current_platform.stateless_init_device_torch_dist_pg(
backend=backend,
prefix_store=prefix_store,
group_rank=group_rank,
group_size=group_size,
timeout=timeout,
)
except NotImplementedError:
# If platform doesn't implement stateless_init_device_torch_dist_pg, it
# will raise a NotImplementedError. In this case, we fall back to gloo.
return init_gloo_process_group(
prefix_store=prefix_store,
group_rank=group_rank,
group_size=group_size,
timeout=timeout,
)
if group_name is not None:
from torch._C._distributed_c10d import _register_process_group
pg._set_group_name(group_name)
_register_process_group(group_name, pg)
if return_store:
return pg, store
else:
return pg
def stateless_destroy_torch_distributed_process_group(pg: ProcessGroup) -> None:

View File

@@ -3,7 +3,7 @@
"""Base class for weight transfer engines."""
from abc import ABC, abstractmethod
from collections.abc import Callable
from collections.abc import Callable, Iterator
from dataclasses import KW_ONLY, dataclass, field
from typing import Any, Generic, TypeVar
@@ -156,3 +156,30 @@ class WeightTransferEngine(ABC, Generic[TInitInfo, TUpdateInfo]):
This should be called when the worker is shutting down.
"""
raise NotImplementedError
@staticmethod
@abstractmethod
def trainer_send_weights(
iterator: Iterator[tuple[str, torch.Tensor]],
trainer_args: dict[str, Any] | Any,
) -> None:
"""
Send weights from trainer to inference workers.
This is a static method that can be called from the trainer process
to send weights to all inference workers.
Args:
iterator: Iterator of model parameters. Returns (name, tensor) tuples.
The tensors should be on the appropriate device for the backend.
trainer_args: Dictionary containing backend-specific arguments needed
to send weights. The structure depends on the backend:
- NCCL: Contains 'group', 'src', 'packed', etc.
- IPC: Contains 'mode' ('http' or 'ray'),
'llm_handle' (for Ray), 'url' (for HTTP), etc.
Example:
>>> param_iter = ((n, p) for n, p in model.named_parameters())
>>> engine.trainer_send_weights(param_iter, trainer_args)
"""
raise NotImplementedError

View File

@@ -114,3 +114,9 @@ WeightTransferEngineFactory.register_engine(
"vllm.distributed.weight_transfer.nccl_engine",
"NCCLWeightTransferEngine",
)
WeightTransferEngineFactory.register_engine(
"ipc",
"vllm.distributed.weight_transfer.ipc_engine",
"IPCWeightTransferEngine",
)

View File

@@ -0,0 +1,291 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""IPC-based weight transfer engine using CUDA IPC for communication."""
import base64
import pickle
from collections.abc import Callable, Iterator
from dataclasses import asdict, dataclass
from typing import Any
import requests
import torch
from torch.multiprocessing.reductions import reduce_tensor
from vllm.config.parallel import ParallelConfig
from vllm.config.weight_transfer import WeightTransferConfig
from vllm.distributed.weight_transfer.base import (
WeightTransferEngine,
WeightTransferInitInfo,
WeightTransferUpdateInfo,
)
@dataclass
class IPCTrainerSendWeightsArgs:
"""Arguments for IPC trainer_send_weights method."""
mode: str
"""Transport mode: 'http' or 'ray'."""
llm_handle: Any = None
"""Ray ObjectRef to LLM handle (required for 'ray' mode)."""
url: str | None = None
"""Base URL for HTTP endpoint (required for 'http' mode)."""
def __post_init__(self):
"""Validate that required arguments are provided for the selected mode."""
if self.mode == "ray" and self.llm_handle is None:
raise ValueError("llm_handle is required for 'ray' mode")
if self.mode == "http" and self.url is None:
raise ValueError("url is required for 'http' mode")
if self.mode not in ("ray", "http"):
raise ValueError(f"mode must be 'ray' or 'http', got {self.mode}")
@dataclass
class IPCWeightTransferInitInfo(WeightTransferInitInfo):
"""Initialization info for IPC weight transfer backend. No init needed for IPC."""
pass
@dataclass
class IPCWeightTransferUpdateInfo(WeightTransferUpdateInfo):
"""Update info for IPC weight transfer backend.
Accepts IPC handles either directly via ``ipc_handles`` (Ray transport)
or as a base64-encoded pickle via ``ipc_handles_pickled`` (HTTP transport).
Exactly one of the two must be provided; if ``ipc_handles_pickled`` is set
it is unpickled into ``ipc_handles`` during ``__post_init__``.
"""
names: list[str]
dtype_names: list[str]
shapes: list[list[int]]
ipc_handles: list[dict[str, tuple[Callable, tuple]]] | None = None
"""IPC handles mapping physical GPU UUID to (func, args) tuple.
Each handle is a dictionary mapping GPU UUID strings to IPC handle tuples."""
ipc_handles_pickled: str | None = None
"""Base64-encoded pickled IPC handles, used for HTTP transport."""
def __post_init__(self):
if self.ipc_handles_pickled is not None:
if self.ipc_handles is not None:
raise ValueError(
"Cannot specify both `ipc_handles` and `ipc_handles_pickled`"
)
self.ipc_handles = pickle.loads(base64.b64decode(self.ipc_handles_pickled))
self.ipc_handles_pickled = None
if self.ipc_handles is None:
raise ValueError(
"Either `ipc_handles` or `ipc_handles_pickled` must be provided"
)
num_params = len(self.names)
if len(self.dtype_names) != num_params:
raise ValueError(
f"`dtype_names` should be of the same size as `names`: "
f"got {len(self.dtype_names)} and {len(self.names)}"
)
if len(self.shapes) != num_params:
raise ValueError(
f"`shapes` should be of the same size as `names`: "
f"got {len(self.shapes)} and {len(self.names)}"
)
if len(self.ipc_handles) != num_params:
raise ValueError(
f"`ipc_handles` should be of the same size as `names`: "
f"got {len(self.ipc_handles)} and {len(self.names)}"
)
class IPCWeightTransferEngine(
WeightTransferEngine[IPCWeightTransferInitInfo, IPCWeightTransferUpdateInfo]
):
"""
Weight transfer engine using CUDA IPC for communication between trainer and workers.
This implementation uses CUDA IPC to transfer weights from the trainer (rank 0)
to all inference workers in a process group. IPC handles are used to share
memory between processes on the same node.
"""
# Define backend-specific dataclass types
init_info_cls = IPCWeightTransferInitInfo
update_info_cls = IPCWeightTransferUpdateInfo
def __init__(
self, config: WeightTransferConfig, parallel_config: ParallelConfig
) -> None:
"""
Initialize the IPC weight transfer engine.
Args:
config: The configuration for the weight transfer engine
parallel_config: The configuration for the parallel setup
"""
super().__init__(config, parallel_config)
def init_transfer_engine(self, init_info: IPCWeightTransferInitInfo) -> None:
"""
Initialize the weight transfer mechanism.
This is called once at the beginning of training.
No initialization needed for IPC backend.
Args:
init_info: IPC initialization info (empty)
"""
pass
def receive_weights(
self,
update_info: IPCWeightTransferUpdateInfo,
load_weights: Callable[[list[tuple[str, torch.Tensor]]], None],
) -> None:
"""
Receive weights from the trainer via CUDA IPC handles.
Args:
update_info: IPC update info containing parameter names, dtypes, shapes,
and IPC handles. Each IPC handle is a mapping between physical
GPU UUID and the IPC handle tuple (func, args).
load_weights: Callable that loads weights into the model. Called
incrementally for each weight to avoid OOM.
"""
assert update_info.ipc_handles is not None
weights = []
for name, _dtype_name, _shape, ipc_handle in zip(
update_info.names,
update_info.dtype_names,
update_info.shapes,
update_info.ipc_handles,
):
device_index = torch.cuda.current_device()
props = torch.cuda.get_device_properties(device_index)
physical_gpu_id = str(props.uuid)
if physical_gpu_id not in ipc_handle:
raise ValueError(
f"IPC handle not found for GPU UUID {physical_gpu_id}. "
f"Available UUIDs: {list(ipc_handle.keys())}"
)
handle = ipc_handle[physical_gpu_id]
func, args = handle
list_args = list(args) # type: ignore
# Index 6 is the device_index parameter in torch's
# IPC handle tuple (rebuild_cuda_tensor). Update it
# to the current device since the logical index can
# differ between sender and receiver.
list_args[6] = device_index
weight = func(*list_args) # type: ignore
weights.append((name, weight))
load_weights(weights)
def shutdown(self) -> None:
"""
Shutdown the weight transfer engine.
"""
pass
@staticmethod
def trainer_send_weights(
iterator: Iterator[tuple[str, torch.Tensor]],
trainer_args: dict[str, Any] | IPCTrainerSendWeightsArgs,
) -> None:
"""
Send weights from trainer to inference workers via CUDA IPC.
Supports two modes:
- 'ray': Sends weights via Ray RPC to a Ray-based LLM handle
- 'http': Sends weights via HTTP POST to a vLLM HTTP server
Args:
iterator: Iterator of model parameters. Returns (name, tensor) tuples.
Tensors should be on the same GPU as the inference workers.
trainer_args: Dictionary containing IPC-specific arguments.
Should contain keys from IPCTrainerSendWeightsArgs:
- mode: 'ray' or 'http'
- llm_handle: Ray ObjectRef (for 'ray' mode)
- url: Base URL string (for 'http' mode)
Example (Ray mode):
>>> from vllm.distributed.weight_transfer.ipc_engine import (
... IPCWeightTransferEngine,
... IPCTrainerSendWeightsArgs,
... )
>>> param_iter = ((n, p) for n, p in model.named_parameters())
>>> args = IPCTrainerSendWeightsArgs(mode="ray", llm_handle=llm_handle)
>>> IPCWeightTransferEngine.trainer_send_weights(param_iter, asdict(args))
Example (HTTP mode):
>>> args = IPCTrainerSendWeightsArgs(
... mode="http", url="http://localhost:8000"
... )
>>> IPCWeightTransferEngine.trainer_send_weights(param_iter, asdict(args))
"""
# Parse trainer args - accept either dict or dataclass instance
if isinstance(trainer_args, dict):
args = IPCTrainerSendWeightsArgs(**trainer_args)
else:
args = trainer_args
# Get physical GPU UUID
device_index = torch.cuda.current_device()
props = torch.cuda.get_device_properties(device_index)
gpu_uuid = str(props.uuid)
# Collect weight metadata and create IPC handles
names = []
dtype_names = []
shapes = []
ipc_handles = []
for name, tensor in iterator:
names.append(name)
dtype_names.append(str(tensor.dtype).split(".")[-1])
shapes.append(list(tensor.shape))
# Create IPC handle for this weight tensor
# The tensor must remain in memory for IPC to work
weight = tensor.detach().contiguous()
ipc_handle = reduce_tensor(weight)
ipc_handles.append({gpu_uuid: ipc_handle})
# Send weights based on mode
if args.mode == "ray":
# Ray mode: send via Ray RPC
import ray
update_info = asdict(
IPCWeightTransferUpdateInfo(
names=names,
dtype_names=dtype_names,
shapes=shapes,
ipc_handles=ipc_handles,
)
)
ray.get(
args.llm_handle.update_weights.remote(dict(update_info=update_info))
)
elif args.mode == "http":
# HTTP mode: send via HTTP POST with pickled handles
# Pickle and base64 encode IPC handles for HTTP transmission
pickled_handles = base64.b64encode(pickle.dumps(ipc_handles)).decode(
"utf-8"
)
url = f"{args.url}/update_weights"
payload = {
"update_info": {
"names": names,
"dtype_names": dtype_names,
"shapes": shapes,
"ipc_handles_pickled": pickled_handles,
}
}
response = requests.post(url, json=payload, timeout=300)
response.raise_for_status()
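
For the HTTP transport, the IPC handles are pickled and base64-encoded in `trainer_send_weights` and decoded again in `IPCWeightTransferUpdateInfo.__post_init__`. The encoding round trip itself, sketched with a dummy payload in place of real CUDA IPC handles:

```python
import base64
import pickle

# Placeholder for the real list of {gpu_uuid: (rebuild_func, args)} handles;
# actual handles come from torch.multiprocessing.reductions.reduce_tensor.
ipc_handles = [{"GPU-0000": ("rebuild_cuda_tensor", ("args", "go", "here"))}]

# Trainer side (trainer_send_weights, HTTP branch).
ipc_handles_pickled = base64.b64encode(pickle.dumps(ipc_handles)).decode("utf-8")

# Worker side (IPCWeightTransferUpdateInfo.__post_init__).
decoded = pickle.loads(base64.b64decode(ipc_handles_pickled))
assert decoded == ipc_handles
```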

View File

@@ -35,6 +35,32 @@ class NCCLWeightTransferInitInfo(WeightTransferInitInfo):
world_size: int
@dataclass
class NCCLTrainerSendWeightsArgs:
"""Arguments for NCCL trainer_send_weights method."""
group: Any
"""Process group (PyNcclCommunicator) for NCCL communication."""
src: int = 0
"""Source rank (default 0, trainer is typically rank 0)."""
post_iter_func: Callable[[tuple[str, torch.Tensor]], torch.Tensor] | None = None
"""Optional function to apply to each (name, tensor) pair before broadcasting.
If None, extracts just the tensor."""
packed: bool = False
"""Whether to use packed tensor broadcasting for efficiency.
When True, multiple tensors are batched together before broadcasting
to reduce NCCL communication overhead."""
stream: torch.cuda.Stream | None = None
"""CUDA stream to use for broadcasting if packed is False.
If packed is True, new streams will be created for each buffer."""
packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES
"""Size in bytes for each packed tensor buffer.
Must match the value used in NCCLWeightTransferUpdateInfo."""
packed_num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS
"""Number of buffers for double/triple buffering during packed transfer.
Must match the value used in NCCLWeightTransferUpdateInfo."""
@dataclass
class NCCLWeightTransferUpdateInfo(WeightTransferUpdateInfo):
"""Update info for NCCL weight transfer backend."""
@@ -47,7 +73,7 @@ class NCCLWeightTransferUpdateInfo(WeightTransferUpdateInfo):
When True, multiple tensors are batched together before broadcasting
to reduce NCCL communication overhead."""
packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES
"""Size in bytes for each packed tensor buffer. Default is 1GB.
"""Size in bytes for each packed tensor buffer.
Both producer and consumer must use the same value."""
packed_num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS
"""Number of buffers for double/triple buffering during packed transfer.
@@ -186,47 +212,38 @@ class NCCLWeightTransferEngine(
@staticmethod
def trainer_send_weights(
iterator: Iterator[tuple[str, torch.Tensor]],
group: Any,
src: int = 0,
post_iter_func: Callable[[tuple[str, torch.Tensor]], torch.Tensor]
| None = None,
packed: bool = False,
stream: torch.cuda.Stream | None = None,
packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES,
packed_num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS,
trainer_args: dict[str, Any] | NCCLTrainerSendWeightsArgs,
) -> None:
"""Broadcast weights from trainer to vLLM workers.
Args:
iterator: Iterator of model parameters. Returns (name, tensor) tuples
group: Process group (PyNcclCommunicator)
src: Source rank (default 0, trainer is typically rank 0)
post_iter_func: Optional function to apply to each (name, tensor) pair
before broadcasting. If None, extracts just the tensor.
packed: Whether to use packed tensor broadcasting for efficiency.
When True, multiple tensors are batched together before
broadcasting to reduce NCCL communication overhead.
stream: CUDA stream to use for broadcasting if packed is False.
If packed is True, new streams will be created for each buffer.
packed_buffer_size_bytes: Size in bytes for each packed tensor buffer.
Must match the value used in NCCLWeightTransferUpdateInfo.
packed_num_buffers: Number of buffers for double/triple buffering.
Must match the value used in NCCLWeightTransferUpdateInfo.
trainer_args: Dictionary or NCCLTrainerSendWeightsArgs instance containing
NCCL-specific arguments. If a dict, should contain keys from
NCCLTrainerSendWeightsArgs.
Example:
>>> from vllm.distributed.weight_transfer.nccl_engine import (
... NCCLWeightTransferEngine,
... NCCLTrainerSendWeightsArgs,
... )
>>> param_iter = ((n, p) for n, p in model.named_parameters())
>>> NCCLWeightTransferEngine.trainer_send_weights(
... param_iter, group, packed=True
... )
>>> args = NCCLTrainerSendWeightsArgs(group=group, packed=True)
>>> NCCLWeightTransferEngine.trainer_send_weights(param_iter, args)
"""
if post_iter_func is None:
# Parse trainer args - accept either dict or dataclass instance
if isinstance(trainer_args, dict):
args = NCCLTrainerSendWeightsArgs(**trainer_args)
else:
args = trainer_args
if args.post_iter_func is None:
# Default: extract just the tensor from (name, tensor) tuple
post_iter_func = lambda x: x[1]
else:
post_iter_func = args.post_iter_func
if packed:
if args.packed:
# Use packed tensor broadcasting for efficiency
from vllm.distributed.weight_transfer.packed_tensor import (
packed_broadcast_producer,
@@ -234,18 +251,20 @@ class NCCLWeightTransferEngine(
packed_broadcast_producer(
iterator=iterator,
group=group,
src=src,
group=args.group,
src=args.src,
post_iter_func=post_iter_func,
buffer_size_bytes=packed_buffer_size_bytes,
num_buffers=packed_num_buffers,
buffer_size_bytes=args.packed_buffer_size_bytes,
num_buffers=args.packed_num_buffers,
)
else:
# Use simple one-by-one broadcasting
for item in iterator:
tensor = post_iter_func(item)
group.broadcast(
tensor, src=src, stream=stream or torch.cuda.current_stream()
args.group.broadcast(
tensor,
src=args.src,
stream=args.stream or torch.cuda.current_stream(),
)
@staticmethod
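
After this refactor, trainer code passes a single `trainer_args` argument, either as an `NCCLTrainerSendWeightsArgs` instance or as a plain dict with the same keys. A hedged sketch of the call site; `pynccl_group` is a placeholder for an already-initialized `PyNcclCommunicator`, and the model is a toy stand-in:

```python
# Hedged sketch of the trainer-side call after the refactor; `pynccl_group` is a
# placeholder for an initialized PyNcclCommunicator on the weight-transfer group.
import torch
from vllm.distributed.weight_transfer.nccl_engine import (
    NCCLTrainerSendWeightsArgs,
    NCCLWeightTransferEngine,
)

model = torch.nn.Linear(16, 16).cuda()  # toy stand-in for the trained model
param_iter = ((name, p.data) for name, p in model.named_parameters())

# Dataclass form, as in the docstring example above.
args = NCCLTrainerSendWeightsArgs(group=pynccl_group, src=0, packed=True)
NCCLWeightTransferEngine.trainer_send_weights(param_iter, args)

# A plain dict with the same keys is also accepted and converted internally:
#   NCCLWeightTransferEngine.trainer_send_weights(param_iter, {"group": pynccl_group})
```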

View File

@@ -419,6 +419,7 @@ class EngineArgs:
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
moe_backend: MoEBackend = KernelConfig.moe_backend
all2all_backend: All2AllBackend = ParallelConfig.all2all_backend
enable_elastic_ep: bool = ParallelConfig.enable_elastic_ep
enable_dbo: bool = ParallelConfig.enable_dbo
ubatch_size: int = ParallelConfig.ubatch_size
dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold
@@ -896,6 +897,9 @@ class EngineArgs:
"--ubatch-size",
**parallel_kwargs["ubatch_size"],
)
parallel_group.add_argument(
"--enable-elastic-ep", **parallel_kwargs["enable_elastic_ep"]
)
parallel_group.add_argument(
"--dbo-decode-token-threshold",
**parallel_kwargs["dbo_decode_token_threshold"],
@@ -1321,6 +1325,7 @@ class EngineArgs:
"launched vLLM.",
self.seed,
)
return ModelConfig(
model=self.model,
model_weights=self.model_weights,
@@ -1697,6 +1702,7 @@ class EngineArgs:
is_moe_model=model_config.is_moe,
enable_expert_parallel=self.enable_expert_parallel,
all2all_backend=self.all2all_backend,
enable_elastic_ep=self.enable_elastic_ep,
enable_dbo=self.enable_dbo,
ubatch_size=self.ubatch_size,
dbo_decode_token_threshold=self.dbo_decode_token_threshold,
@@ -1905,6 +1911,7 @@ class EngineArgs:
performance_mode=self.performance_mode,
weight_transfer_config=self.weight_transfer_config,
)
return config
def _check_feature_supported(self):
@@ -2074,20 +2081,19 @@ class EngineArgs:
)
# Disable chunked prefill and prefix caching for:
# POWER (ppc64le)/RISCV CPUs in V1
# RISCV CPUs in V1
if current_platform.is_cpu() and current_platform.get_cpu_architecture() in (
CpuArchEnum.POWERPC,
CpuArchEnum.RISCV,
):
logger.info(
"Chunked prefill is not supported for POWER, "
"and RISC-V CPUs; "
"Chunked prefill is not supported for"
"RISC-V CPUs; "
"disabling it for V1 backend."
)
self.enable_chunked_prefill = False
logger.info(
"Prefix caching is not supported for POWER, "
"and RISC-V CPUs; "
"Prefix caching is not supported for "
"RISC-V CPUs; "
"disabling it for V1 backend."
)
self.enable_prefix_caching = False
@@ -2181,14 +2187,10 @@ class AsyncEngineArgs(EngineArgs):
"--enable-log-requests",
action=argparse.BooleanOptionalAction,
default=AsyncEngineArgs.enable_log_requests,
help="Enable logging requests.",
)
parser.add_argument(
"--disable-log-requests",
action=argparse.BooleanOptionalAction,
default=not AsyncEngineArgs.enable_log_requests,
help="[DEPRECATED] Disable logging requests.",
deprecated=True,
help="Enable logging request information, dependant on log level:\n"
"- INFO: Request ID, parameters and LoRA request.\n"
"- DEBUG: Prompt inputs (e.g: text, token IDs).\n"
"You can set the minimum log level via `VLLM_LOGGING_LEVEL`.",
)
current_platform.pre_register_and_update(parser)
return parser
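
The new `--enable-elastic-ep` flag is generated from `ParallelConfig.enable_elastic_ep`, so the same option is available programmatically; a minimal sketch, with the model name as a placeholder:

```python
# Hedged sketch: enable_elastic_ep can be set on EngineArgs directly; the value is
# forwarded to ParallelConfig in create_engine_config (see the hunk above).
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="Qwen/Qwen2-0.5B",  # placeholder model name
    enable_elastic_ep=True,
)
```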

View File

@@ -8,6 +8,8 @@ from fastapi import APIRouter, Depends, FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from vllm.entrypoints.anthropic.protocol import (
AnthropicCountTokensRequest,
AnthropicCountTokensResponse,
AnthropicError,
AnthropicErrorResponse,
AnthropicMessagesRequest,
@@ -31,6 +33,18 @@ def messages(request: Request) -> AnthropicServingMessages:
return request.app.state.anthropic_serving_messages
def translate_error_response(response: ErrorResponse) -> JSONResponse:
anthropic_error = AnthropicErrorResponse(
error=AnthropicError(
type=response.error.type,
message=response.error.message,
)
)
return JSONResponse(
status_code=response.error.code, content=anthropic_error.model_dump()
)
@router.post(
"/v1/messages",
dependencies=[Depends(validate_json_request)],
@@ -44,17 +58,6 @@ def messages(request: Request) -> AnthropicServingMessages:
@with_cancellation
@load_aware_call
async def create_messages(request: AnthropicMessagesRequest, raw_request: Request):
def translate_error_response(response: ErrorResponse) -> JSONResponse:
anthropic_error = AnthropicErrorResponse(
error=AnthropicError(
type=response.error.type,
message=response.error.message,
)
)
return JSONResponse(
status_code=response.error.code, content=anthropic_error.model_dump()
)
handler = messages(raw_request)
if handler is None:
base_server = raw_request.app.state.openai_serving_tokenization
@@ -88,5 +91,46 @@ async def create_messages(request: AnthropicMessagesRequest, raw_request: Reques
return StreamingResponse(content=generator, media_type="text/event-stream")
@router.post(
"/v1/messages/count_tokens",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.OK.value: {"model": AnthropicCountTokensResponse},
HTTPStatus.BAD_REQUEST.value: {"model": AnthropicErrorResponse},
HTTPStatus.NOT_FOUND.value: {"model": AnthropicErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": AnthropicErrorResponse},
},
)
@load_aware_call
@with_cancellation
async def count_tokens(request: AnthropicCountTokensRequest, raw_request: Request):
handler = messages(raw_request)
if handler is None:
base_server = raw_request.app.state.openai_serving_tokenization
error = base_server.create_error_response(
message="The model does not support Messages API"
)
return translate_error_response(error)
try:
response = await handler.count_tokens(request, raw_request)
except Exception as e:
logger.exception("Error in count_tokens: %s", e)
return JSONResponse(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
content=AnthropicErrorResponse(
error=AnthropicError(
type="internal_error",
message=str(e),
)
).model_dump(),
)
if isinstance(response, ErrorResponse):
return translate_error_response(response)
return JSONResponse(content=response.model_dump(exclude_none=True))
def attach_router(app: FastAPI):
app.include_router(router)
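
A minimal client sketch for the new route; the host, port, and served model name are placeholders, and the payload shape follows `AnthropicCountTokensRequest`:

```python
# Hedged client sketch for the new POST /v1/messages/count_tokens endpoint.
import requests

payload = {
    "model": "llm",  # placeholder served model name
    "messages": [{"role": "user", "content": "Hello, world"}],
}
resp = requests.post(
    "http://localhost:8000/v1/messages/count_tokens",  # placeholder host/port
    json=payload,
    timeout=30,
)
# Expected shape (per AnthropicCountTokensResponse):
# {"input_tokens": <int>, "context_management": {"original_input_tokens": <int>}}
print(resp.json())
```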

View File

@@ -34,7 +34,7 @@ class AnthropicUsage(BaseModel):
class AnthropicContentBlock(BaseModel):
"""Content block in message"""
type: Literal["text", "image", "tool_use", "tool_result"]
type: Literal["text", "image", "tool_use", "tool_result", "thinking"]
text: str | None = None
# For image content
source: dict[str, Any] | None = None
@@ -45,6 +45,9 @@ class AnthropicContentBlock(BaseModel):
input: dict[str, Any] | None = None
content: str | list[dict[str, Any]] | None = None
is_error: bool | None = None
# For thinking content
thinking: str | None = None
signature: str | None = None
class AnthropicMessage(BaseModel):
@@ -74,7 +77,7 @@ class AnthropicTool(BaseModel):
class AnthropicToolChoice(BaseModel):
"""Tool Choice definition"""
type: Literal["auto", "any", "tool"]
type: Literal["auto", "any", "tool", "none"]
name: str | None = None
@model_validator(mode="after")
@@ -118,9 +121,14 @@ class AnthropicMessagesRequest(BaseModel):
class AnthropicDelta(BaseModel):
"""Delta for streaming responses"""
type: Literal["text_delta", "input_json_delta"] | None = None
type: (
Literal["text_delta", "input_json_delta", "thinking_delta", "signature_delta"]
| None
) = None
text: str | None = None
thinking: str | None = None
partial_json: str | None = None
signature: str | None = None
# Message delta
stop_reason: (
@@ -167,3 +175,33 @@ class AnthropicMessagesResponse(BaseModel):
def model_post_init(self, __context):
if not self.id:
self.id = f"msg_{int(time.time() * 1000)}"
class AnthropicContextManagement(BaseModel):
"""Context management information for token counting."""
original_input_tokens: int
class AnthropicCountTokensRequest(BaseModel):
"""Anthropic messages.count_tokens request"""
model: str
messages: list[AnthropicMessage]
system: str | list[AnthropicContentBlock] | None = None
tool_choice: AnthropicToolChoice | None = None
tools: list[AnthropicTool] | None = None
@field_validator("model")
@classmethod
def validate_model(cls, v):
if not v:
raise ValueError("Model is required")
return v
class AnthropicCountTokensResponse(BaseModel):
"""Anthropic messages.count_tokens response"""
input_tokens: int
context_management: AnthropicContextManagement | None = None
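
The extended models can be constructed directly; a hedged sketch that assumes `AnthropicMessage` accepts plain role/content dicts via Pydantic coercion and uses placeholder values:

```python
# Hedged sketch of the new "thinking" block and the count_tokens request model.
from vllm.entrypoints.anthropic.protocol import (
    AnthropicContentBlock,
    AnthropicCountTokensRequest,
)

thinking_block = AnthropicContentBlock(
    type="thinking",
    thinking="Work through the problem before answering...",
    signature="0123abcd",  # placeholder; the server generates a hex signature
)

count_req = AnthropicCountTokensRequest(
    model="llm",  # placeholder model name
    messages=[{"role": "user", "content": "Hello"}],  # assumes coercion to AnthropicMessage
)
```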

View File

@@ -8,6 +8,7 @@
import json
import logging
import time
import uuid
from collections.abc import AsyncGenerator
from typing import Any
@@ -16,6 +17,9 @@ from fastapi import Request
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.anthropic.protocol import (
AnthropicContentBlock,
AnthropicContextManagement,
AnthropicCountTokensRequest,
AnthropicCountTokensResponse,
AnthropicDelta,
AnthropicError,
AnthropicMessagesRequest,
@@ -85,14 +89,52 @@ class AnthropicServingMessages(OpenAIServingChat):
"tool_calls": "tool_use",
}
@staticmethod
def _convert_image_source_to_url(source: dict[str, Any]) -> str:
"""Convert an Anthropic image source to an OpenAI-compatible URL.
Anthropic supports two image source types:
- base64: {"type": "base64", "media_type": "image/jpeg", "data": "..."}
- url: {"type": "url", "url": "https://..."}
For base64 sources, this constructs a proper data URI that
downstream processors (e.g. vLLM's media connector) can handle.
"""
source_type = source.get("type")
if source_type == "url":
return source.get("url", "")
# Default to base64 processing if type is "base64"
# or missing, ensuring a proper data URI is always
# constructed for non-URL sources.
media_type = source.get("media_type", "image/jpeg")
data = source.get("data", "")
return f"data:{media_type};base64,{data}"
@classmethod
def _convert_anthropic_to_openai_request(
self, anthropic_request: AnthropicMessagesRequest
cls, anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest
) -> ChatCompletionRequest:
"""Convert Anthropic message format to OpenAI format"""
openai_messages = []
openai_messages: list[dict[str, Any]] = []
cls._convert_system_message(anthropic_request, openai_messages)
cls._convert_messages(anthropic_request.messages, openai_messages)
req = cls._build_base_request(anthropic_request, openai_messages)
cls._handle_streaming_options(req, anthropic_request)
cls._convert_tool_choice(anthropic_request, req)
cls._convert_tools(anthropic_request, req)
return req
@classmethod
def _convert_system_message(
cls,
anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest,
openai_messages: list[dict[str, Any]],
) -> None:
"""Convert Anthropic system message to OpenAI format"""
if not anthropic_request.system:
return
# Add system message if provided
if anthropic_request.system:
if isinstance(anthropic_request.system, str):
openai_messages.append(
{"role": "system", "content": anthropic_request.system}
@@ -104,27 +146,83 @@ class AnthropicServingMessages(OpenAIServingChat):
system_prompt += block.text
openai_messages.append({"role": "system", "content": system_prompt})
for msg in anthropic_request.messages:
@classmethod
def _convert_messages(
cls, messages: list, openai_messages: list[dict[str, Any]]
) -> None:
"""Convert Anthropic messages to OpenAI format"""
for msg in messages:
openai_msg: dict[str, Any] = {"role": msg.role} # type: ignore
if isinstance(msg.content, str):
openai_msg["content"] = msg.content
else:
# Handle complex content blocks
cls._convert_message_content(msg, openai_msg, openai_messages)
openai_messages.append(openai_msg)
@classmethod
def _convert_message_content(
cls,
msg,
openai_msg: dict[str, Any],
openai_messages: list[dict[str, Any]],
) -> None:
"""Convert complex message content blocks"""
content_parts: list[dict[str, Any]] = []
tool_calls: list[dict[str, Any]] = []
reasoning_parts: list[str] = []
for block in msg.content:
cls._convert_block(
block,
msg.role,
content_parts,
tool_calls,
reasoning_parts,
openai_messages,
)
if reasoning_parts:
openai_msg["reasoning"] = "".join(reasoning_parts)
if tool_calls:
openai_msg["tool_calls"] = tool_calls # type: ignore
if content_parts:
if len(content_parts) == 1 and content_parts[0]["type"] == "text":
openai_msg["content"] = content_parts[0]["text"]
else:
openai_msg["content"] = content_parts # type: ignore
elif not tool_calls and not reasoning_parts:
return
@classmethod
def _convert_block(
cls,
block,
role: str,
content_parts: list[dict[str, Any]],
tool_calls: list[dict[str, Any]],
reasoning_parts: list[str],
openai_messages: list[dict[str, Any]],
) -> None:
"""Convert individual content block"""
if block.type == "text" and block.text:
content_parts.append({"type": "text", "text": block.text})
elif block.type == "image" and block.source:
content_parts.append(
{
"type": "image_url",
"image_url": {"url": block.source.get("data", "")},
}
)
image_url = cls._convert_image_source_to_url(block.source)
content_parts.append({"type": "image_url", "image_url": {"url": image_url}})
elif block.type == "thinking" and block.thinking is not None:
reasoning_parts.append(block.thinking)
elif block.type == "tool_use":
# Convert tool use to function call format
cls._convert_tool_use_block(block, tool_calls)
elif block.type == "tool_result":
cls._convert_tool_result_block(block, role, openai_messages, content_parts)
@classmethod
def _convert_tool_use_block(cls, block, tool_calls: list[dict[str, Any]]) -> None:
"""Convert tool_use block to OpenAI function call format"""
tool_call = {
"id": block.id or f"call_{int(time.time())}",
"type": "function",
@@ -134,45 +232,82 @@ class AnthropicServingMessages(OpenAIServingChat):
},
}
tool_calls.append(tool_call)
elif block.type == "tool_result":
if msg.role == "user":
@classmethod
def _convert_tool_result_block(
cls,
block,
role: str,
openai_messages: list[dict[str, Any]],
content_parts: list[dict[str, Any]],
) -> None:
"""Convert tool_result block to OpenAI format"""
if role == "user":
cls._convert_user_tool_result(block, openai_messages)
else:
tool_result_text = str(block.content) if block.content else ""
content_parts.append(
{"type": "text", "text": f"Tool result: {tool_result_text}"}
)
@classmethod
def _convert_user_tool_result(
cls, block, openai_messages: list[dict[str, Any]]
) -> None:
"""Convert user tool_result with text and image support"""
tool_text = ""
tool_image_urls: list[str] = []
if isinstance(block.content, str):
tool_text = block.content
elif isinstance(block.content, list):
text_parts: list[str] = []
for item in block.content:
if not isinstance(item, dict):
continue
item_type = item.get("type")
if item_type == "text":
text_parts.append(item.get("text", ""))
elif item_type == "image":
source = item.get("source", {})
url = cls._convert_image_source_to_url(source)
if url:
tool_image_urls.append(url)
tool_text = "\n".join(text_parts)
openai_messages.append(
{
"role": "tool",
"tool_call_id": block.tool_use_id or "",
"content": str(block.content)
if block.content
else "",
"content": tool_text or "",
}
)
else:
# Assistant tool result becomes regular text
tool_result_text = (
str(block.content) if block.content else ""
)
content_parts.append(
if tool_image_urls:
openai_messages.append(
{
"type": "text",
"text": f"Tool result: {tool_result_text}",
"role": "user",
"content": [ # type: ignore[dict-item]
{"type": "image_url", "image_url": {"url": img}}
for img in tool_image_urls
],
}
)
# Add tool calls to the message if any
if tool_calls:
openai_msg["tool_calls"] = tool_calls # type: ignore
@classmethod
def _build_base_request(
cls,
anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest,
openai_messages: list[dict[str, Any]],
) -> ChatCompletionRequest:
"""Build base ChatCompletionRequest"""
if isinstance(anthropic_request, AnthropicCountTokensRequest):
return ChatCompletionRequest(
model=anthropic_request.model,
messages=openai_messages,
)
# Add content parts if any
if content_parts:
if len(content_parts) == 1 and content_parts[0]["type"] == "text":
openai_msg["content"] = content_parts[0]["text"]
else:
openai_msg["content"] = content_parts # type: ignore
elif not tool_calls:
continue
openai_messages.append(openai_msg)
req = ChatCompletionRequest(
return ChatCompletionRequest(
model=anthropic_request.model,
messages=openai_messages,
max_tokens=anthropic_request.max_tokens,
@@ -183,19 +318,40 @@ class AnthropicServingMessages(OpenAIServingChat):
top_k=anthropic_request.top_k,
)
@classmethod
def _handle_streaming_options(
cls,
req: ChatCompletionRequest,
anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest,
) -> None:
"""Handle streaming configuration"""
if isinstance(anthropic_request, AnthropicCountTokensRequest):
return
if anthropic_request.stream:
req.stream = anthropic_request.stream
req.stream_options = StreamOptions.validate(
req.stream_options = StreamOptions.model_validate(
{"include_usage": True, "continuous_usage_stats": True}
)
@classmethod
def _convert_tool_choice(
cls,
anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest,
req: ChatCompletionRequest,
) -> None:
"""Convert Anthropic tool_choice to OpenAI format"""
if anthropic_request.tool_choice is None:
req.tool_choice = None
elif anthropic_request.tool_choice.type == "auto":
return
tool_choice_type = anthropic_request.tool_choice.type
if tool_choice_type == "auto":
req.tool_choice = "auto"
elif anthropic_request.tool_choice.type == "any":
elif tool_choice_type == "any":
req.tool_choice = "required"
elif anthropic_request.tool_choice.type == "tool":
elif tool_choice_type == "none":
req.tool_choice = "none"
elif tool_choice_type == "tool":
req.tool_choice = ChatCompletionNamedToolChoiceParam.model_validate(
{
"type": "function",
@@ -203,9 +359,17 @@ class AnthropicServingMessages(OpenAIServingChat):
}
)
tools = []
@classmethod
def _convert_tools(
cls,
anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest,
req: ChatCompletionRequest,
) -> None:
"""Convert Anthropic tools to OpenAI format"""
if anthropic_request.tools is None:
return req
return
tools = []
for tool in anthropic_request.tools:
tools.append(
ChatCompletionToolsParam.model_validate(
@@ -219,10 +383,10 @@ class AnthropicServingMessages(OpenAIServingChat):
}
)
)
if req.tool_choice is None:
req.tool_choice = "auto"
req.tools = tools
return req
async def create_messages(
self,
@@ -263,23 +427,32 @@ class AnthropicServingMessages(OpenAIServingChat):
output_tokens=generator.usage.completion_tokens,
),
)
if generator.choices[0].finish_reason == "stop":
choice = generator.choices[0]
if choice.finish_reason == "stop":
result.stop_reason = "end_turn"
elif generator.choices[0].finish_reason == "length":
elif choice.finish_reason == "length":
result.stop_reason = "max_tokens"
elif generator.choices[0].finish_reason == "tool_calls":
elif choice.finish_reason == "tool_calls":
result.stop_reason = "tool_use"
content: list[AnthropicContentBlock] = [
content: list[AnthropicContentBlock] = []
if choice.message.reasoning:
content.append(
AnthropicContentBlock(
type="thinking",
thinking=choice.message.reasoning,
signature=uuid.uuid4().hex,
)
)
if choice.message.content:
content.append(
AnthropicContentBlock(
type="text",
text=generator.choices[0].message.content
if generator.choices[0].message.content
else "",
text=choice.message.content,
)
)
]
for tool_call in generator.choices[0].message.tool_calls:
for tool_call in choice.message.tool_calls:
anthropic_tool_call = AnthropicContentBlock(
type="tool_use",
id=tool_call.id,
@@ -297,10 +470,85 @@ class AnthropicServingMessages(OpenAIServingChat):
generator: AsyncGenerator[str, None],
) -> AsyncGenerator[str, None]:
try:
class _ActiveBlockState:
def __init__(self) -> None:
self.content_block_index = 0
self.block_type: str | None = None
self.block_index: int | None = None
self.block_signature: str | None = None
self.signature_emitted: bool = False
self.tool_use_id: str | None = None
def reset(self) -> None:
self.block_type = None
self.block_index = None
self.block_signature = None
self.signature_emitted = False
self.tool_use_id = None
def start(self, block: AnthropicContentBlock) -> None:
self.block_type = block.type
self.block_index = self.content_block_index
if block.type == "thinking":
self.block_signature = uuid.uuid4().hex
self.signature_emitted = False
self.tool_use_id = None
elif block.type == "tool_use":
self.block_signature = None
self.signature_emitted = True
self.tool_use_id = block.id
else:
self.block_signature = None
self.signature_emitted = True
self.tool_use_id = None
first_item = True
finish_reason = None
content_block_index = 0
content_block_started = False
state = _ActiveBlockState()
# Map from tool call index to tool_use_id
tool_index_to_id: dict[int, str] = {}
def stop_active_block():
events: list[str] = []
if state.block_type is None:
return events
if (
state.block_type == "thinking"
and state.block_signature is not None
and not state.signature_emitted
):
chunk = AnthropicStreamEvent(
index=state.block_index,
type="content_block_delta",
delta=AnthropicDelta(
type="signature_delta",
signature=state.block_signature,
),
)
data = chunk.model_dump_json(exclude_unset=True)
events.append(wrap_data_with_event(data, "content_block_delta"))
state.signature_emitted = True
stop_chunk = AnthropicStreamEvent(
index=state.block_index,
type="content_block_stop",
)
data = stop_chunk.model_dump_json(exclude_unset=True)
events.append(wrap_data_with_event(data, "content_block_stop"))
state.reset()
state.content_block_index += 1
return events
def start_block(block: AnthropicContentBlock):
chunk = AnthropicStreamEvent(
index=state.content_block_index,
type="content_block_start",
content_block=block,
)
data = chunk.model_dump_json(exclude_unset=True)
event = wrap_data_with_event(data, "content_block_start")
state.start(block)
return event
async for item in generator:
if item.startswith("data:"):
@@ -326,6 +574,8 @@ class AnthropicServingMessages(OpenAIServingChat):
id=origin_chunk.id,
content=[],
model=origin_chunk.model,
stop_reason=None,
stop_sequence=None,
usage=AnthropicUsage(
input_tokens=origin_chunk.usage.prompt_tokens
if origin_chunk.usage
@@ -341,13 +591,8 @@ class AnthropicServingMessages(OpenAIServingChat):
# last chunk including usage info
if len(origin_chunk.choices) == 0:
if content_block_started:
stop_chunk = AnthropicStreamEvent(
index=content_block_index,
type="content_block_stop",
)
data = stop_chunk.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(data, "content_block_stop")
for event in stop_active_block():
yield event
stop_reason = self.stop_reason_map.get(
finish_reason or "stop"
)
@@ -369,26 +614,55 @@ class AnthropicServingMessages(OpenAIServingChat):
if origin_chunk.choices[0].finish_reason is not None:
finish_reason = origin_chunk.choices[0].finish_reason
continue
# continue
# content
if origin_chunk.choices[0].delta.content is not None:
if not content_block_started:
# thinking / text content
reasoning_delta = origin_chunk.choices[0].delta.reasoning
if reasoning_delta is not None:
if reasoning_delta == "":
pass
else:
if state.block_type != "thinking":
for event in stop_active_block():
yield event
start_event = start_block(
AnthropicContentBlock(
type="thinking", thinking=""
)
)
yield start_event
chunk = AnthropicStreamEvent(
index=content_block_index,
type="content_block_start",
content_block=AnthropicContentBlock(
type="text", text=""
index=(
state.block_index
if state.block_index is not None
else state.content_block_index
),
type="content_block_delta",
delta=AnthropicDelta(
type="thinking_delta",
thinking=reasoning_delta,
),
)
data = chunk.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(data, "content_block_start")
content_block_started = True
yield wrap_data_with_event(data, "content_block_delta")
if origin_chunk.choices[0].delta.content is not None:
if origin_chunk.choices[0].delta.content == "":
continue
pass
else:
if state.block_type != "text":
for event in stop_active_block():
yield event
start_event = start_block(
AnthropicContentBlock(type="text", text="")
)
yield start_event
chunk = AnthropicStreamEvent(
index=content_block_index,
index=(
state.block_index
if state.block_index is not None
else state.content_block_index
),
type="content_block_delta",
delta=AnthropicDelta(
type="text_delta",
@@ -397,44 +671,47 @@ class AnthropicServingMessages(OpenAIServingChat):
)
data = chunk.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(data, "content_block_delta")
continue
# tool calls
elif len(origin_chunk.choices[0].delta.tool_calls) > 0:
tool_call = origin_chunk.choices[0].delta.tool_calls[0]
# tool calls - process all tool calls in the delta
if len(origin_chunk.choices[0].delta.tool_calls) > 0:
for tool_call in origin_chunk.choices[0].delta.tool_calls:
if tool_call.id is not None:
if content_block_started:
stop_chunk = AnthropicStreamEvent(
index=content_block_index,
type="content_block_stop",
# Update mapping for incremental updates
tool_index_to_id[tool_call.index] = tool_call.id
# Only create new block if different tool call
# AND has a name
tool_name = (
tool_call.function.name
if tool_call.function
else None
)
data = stop_chunk.model_dump_json(
exclude_unset=True
)
yield wrap_data_with_event(
data, "content_block_stop"
)
content_block_started = False
content_block_index += 1
chunk = AnthropicStreamEvent(
index=content_block_index,
type="content_block_start",
content_block=AnthropicContentBlock(
if (
state.tool_use_id != tool_call.id
and tool_name is not None
):
for event in stop_active_block():
yield event
start_event = start_block(
AnthropicContentBlock(
type="tool_use",
id=tool_call.id,
name=tool_call.function.name
if tool_call.function
else None,
name=tool_name,
input={},
),
)
data = chunk.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(data, "content_block_start")
content_block_started = True
if tool_call.function and tool_call.function.arguments:
)
yield start_event
# Handle initial arguments if present
if (
tool_call.function
and tool_call.function.arguments
and state.tool_use_id == tool_call.id
):
chunk = AnthropicStreamEvent(
index=content_block_index,
index=(
state.block_index
if state.block_index is not None
else state.content_block_index
),
type="content_block_delta",
delta=AnthropicDelta(
type="input_json_delta",
@@ -445,20 +722,31 @@ class AnthropicServingMessages(OpenAIServingChat):
yield wrap_data_with_event(
data, "content_block_delta"
)
else:
# Incremental update - use index to find tool_use_id
tool_use_id = tool_index_to_id.get(tool_call.index)
if (
tool_use_id is not None
and tool_call.function
and tool_call.function.arguments
and state.tool_use_id == tool_use_id
):
chunk = AnthropicStreamEvent(
index=content_block_index,
index=(
state.block_index
if state.block_index is not None
else state.content_block_index
),
type="content_block_delta",
delta=AnthropicDelta(
type="input_json_delta",
partial_json=tool_call.function.arguments
if tool_call.function
else None,
partial_json=tool_call.function.arguments,
),
)
data = chunk.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(data, "content_block_delta")
yield wrap_data_with_event(
data, "content_block_delta"
)
continue
else:
error_response = AnthropicStreamEvent(
@@ -481,3 +769,31 @@ class AnthropicServingMessages(OpenAIServingChat):
data = error_response.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(data, "error")
yield "data: [DONE]\n\n"
async def count_tokens(
self,
request: AnthropicCountTokensRequest,
raw_request: Request | None = None,
) -> AnthropicCountTokensResponse | ErrorResponse:
"""Implements Anthropic's messages.count_tokens endpoint."""
chat_req = self._convert_anthropic_to_openai_request(request)
result = await self.render_chat_request(chat_req)
if isinstance(result, ErrorResponse):
return result
_, engine_prompts = result
input_tokens = sum( # type: ignore
len(prompt["prompt_token_ids"]) # type: ignore[typeddict-item, misc]
for prompt in engine_prompts
if "prompt_token_ids" in prompt
)
response = AnthropicCountTokensResponse(
input_tokens=input_tokens,
context_management=AnthropicContextManagement(
original_input_tokens=input_tokens
),
)
return response
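
The count itself is just the sum of rendered prompt lengths, as the loop above shows; a tiny self-contained illustration with placeholder token IDs:

```python
# Hedged illustration of the token counting performed above.
engine_prompts = [
    {"prompt_token_ids": [101, 7592, 1010, 2088, 102]},  # placeholder token IDs
    {"prompt_token_ids": [101, 102]},
]
input_tokens = sum(
    len(p["prompt_token_ids"]) for p in engine_prompts if "prompt_token_ids" in p
)
print(input_tokens)  # 7
```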

View File

@@ -7,6 +7,7 @@ import warnings
from abc import ABC, abstractmethod
from collections import Counter, defaultdict
from collections.abc import Awaitable, Callable, Iterable
from dataclasses import dataclass
from functools import cached_property, lru_cache, partial
from itertools import accumulate
from pathlib import Path
@@ -1024,6 +1025,13 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
self._add_placeholder("video", placeholder)
@dataclass
class ChatTemplateConfig:
chat_template: str | None = None
chat_template_content_format: ChatTemplateContentFormatOption = "auto"
trust_request_chat_template: bool = False
def validate_chat_template(chat_template: Path | str | None):
"""Raises if the provided chat template appears invalid."""
if chat_template is None:

View File

@@ -1,12 +1,19 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.entrypoints.cli.benchmark.latency import BenchmarkLatencySubcommand
from vllm.entrypoints.cli.benchmark.mm_processor import (
BenchmarkMMProcessorSubcommand,
)
from vllm.entrypoints.cli.benchmark.serve import BenchmarkServingSubcommand
from vllm.entrypoints.cli.benchmark.startup import BenchmarkStartupSubcommand
from vllm.entrypoints.cli.benchmark.sweep import BenchmarkSweepSubcommand
from vllm.entrypoints.cli.benchmark.throughput import BenchmarkThroughputSubcommand
# Keep this package init import-free.
#
# The `vllm` console script imports `vllm.entrypoints.cli.main`, which causes
# Python to import this package before loading the `main` submodule.
# Eagerly importing benchmark subcommands here makes every `vllm serve ...`
# startup depend on optional benchmark-only modules.
#
# Benchmark subcommands are loaded on demand in
# `vllm.entrypoints.cli.benchmark.main`.
__all__: list[str] = [
"BenchmarkLatencySubcommand",
"BenchmarkMMProcessorSubcommand",
"BenchmarkServingSubcommand",
"BenchmarkStartupSubcommand",
"BenchmarkSweepSubcommand",
"BenchmarkThroughputSubcommand",
]

View File

@@ -2,8 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import importlib
import logging
import typing
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
@@ -15,30 +13,6 @@ if typing.TYPE_CHECKING:
else:
FlexibleArgumentParser = argparse.ArgumentParser
logger = logging.getLogger(__name__)
def _load_benchmark_subcommands() -> None:
modules = [
"vllm.entrypoints.cli.benchmark.latency",
"vllm.entrypoints.cli.benchmark.mm_processor",
"vllm.entrypoints.cli.benchmark.serve",
"vllm.entrypoints.cli.benchmark.startup",
"vllm.entrypoints.cli.benchmark.sweep",
"vllm.entrypoints.cli.benchmark.throughput",
]
for module_name in modules:
try:
importlib.import_module(module_name)
except ModuleNotFoundError as e:
logger.warning(
"Skipping benchmark subcommand module %s because an optional "
"dependency could not be imported: %r",
module_name,
e,
)
class BenchmarkSubcommand(CLISubcommand):
"""The `bench` subcommand for the vLLM CLI."""
@@ -64,8 +38,6 @@ class BenchmarkSubcommand(CLISubcommand):
)
bench_subparsers = bench_parser.add_subparsers(required=True, dest="bench_type")
_load_benchmark_subcommands()
for cmd_cls in BenchmarkSubcommandBase.__subclasses__():
cmd_subparser = bench_subparsers.add_parser(
cmd_cls.name,

View File

@@ -220,6 +220,12 @@ def run_multi_api_server(args: argparse.Namespace):
num_api_servers: int = args.api_server_count
assert num_api_servers > 0
if num_api_servers > 1 and getattr(args, "use_gpu_for_pooling_score", False):
# TODO(wentao): remove this once well tested
raise ValueError(
"--use-gpu-for-pooling-score cannot be used with api_server_count > 1 now"
)
if num_api_servers > 1:
setup_multiprocess_prometheus()
@@ -246,8 +252,12 @@ def run_multi_api_server(args: argparse.Namespace):
api_server_manager: APIServerProcessManager | None = None
from vllm.v1.engine.utils import get_engine_zmq_addresses
addresses = get_engine_zmq_addresses(vllm_config, num_api_servers)
with launch_core_engines(
vllm_config, executor_class, log_stats, num_api_servers
vllm_config, executor_class, log_stats, addresses, num_api_servers
) as (local_engine_manager, coordinator, addresses):
# Construct common args for the APIServerProcessManager up-front.
api_server_manager_kwargs = dict(

View File

@@ -101,11 +101,15 @@ class VllmEngineServicer(vllm_engine_pb2_grpc.VllmEngineServicer):
sampling_params = self._sampling_params_from_proto(
request.sampling_params, stream=request.stream
)
tokenization_kwargs = self._tokenization_kwargs_from_proto(
request.sampling_params
)
async for output in self.async_llm.generate(
prompt=prompt,
sampling_params=sampling_params,
request_id=request_id,
tokenization_kwargs=tokenization_kwargs,
):
# Convert vLLM output to protobuf
# For streaming, always send chunks
@@ -308,9 +312,6 @@ class VllmEngineServicer(vllm_engine_pb2_grpc.VllmEngineServicer):
seed=params.seed if params.HasField("seed") else None,
include_stop_str_in_output=params.include_stop_str_in_output,
logit_bias=dict(params.logit_bias) if params.logit_bias else None,
truncate_prompt_tokens=params.truncate_prompt_tokens
if params.HasField("truncate_prompt_tokens")
else None,
structured_outputs=structured_outputs,
# detokenize must be True if stop strings are used
detokenize=bool(stop),
@@ -319,6 +320,14 @@ class VllmEngineServicer(vllm_engine_pb2_grpc.VllmEngineServicer):
else RequestOutputKind.FINAL_ONLY,
)
@staticmethod
def _tokenization_kwargs_from_proto(
params: vllm_engine_pb2.SamplingParams,
) -> dict[str, int] | None:
if params.HasField("truncate_prompt_tokens"):
return {"truncate_prompt_tokens": params.truncate_prompt_tokens}
return None
@staticmethod
def _chunk_response(output: RequestOutput) -> vllm_engine_pb2.GenerateResponse:
"""

View File

@@ -2,8 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
import warnings
from collections.abc import Callable, Iterable, Sequence
from pathlib import Path
from typing import TYPE_CHECKING, Any
import cloudpickle
@@ -41,8 +41,11 @@ from vllm.distributed.weight_transfer.base import (
from vllm.engine.arg_utils import EngineArgs
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ChatTemplateConfig,
ChatTemplateContentFormatOption,
load_chat_template,
)
from vllm.entrypoints.pooling.io_processor_factories import init_pooling_io_processors
from vllm.entrypoints.pooling.score.utils import (
ScoreData,
ScoreMultiModalParam,
@@ -146,6 +149,7 @@ class LLM:
a tag name, or a commit id.
tokenizer_revision: The specific tokenizer version to use. It can be a
branch name, a tag name, or a commit id.
chat_template: The chat template to apply.
seed: The seed to initialize the random number generator for sampling.
gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
reserve for the model weights, activations, and KV cache. Higher
@@ -233,6 +237,7 @@ class LLM:
quantization: QuantizationMethods | None = None,
revision: str | None = None,
tokenizer_revision: str | None = None,
chat_template: Path | str | None = None,
seed: int = 0,
gpu_memory_utilization: float = 0.9,
swap_space: float = 4,
@@ -385,9 +390,16 @@ class LLM:
self.model_config = self.llm_engine.model_config
self.renderer = self.llm_engine.renderer
self.chat_template = load_chat_template(chat_template)
self.io_processor = self.llm_engine.io_processor
self.input_processor = self.llm_engine.input_processor
self.chat_template_config = ChatTemplateConfig(chat_template=self.chat_template)
self.init_pooling_io_processors = init_pooling_io_processors(
supported_tasks=supported_tasks,
model_config=self.model_config,
renderer=self.renderer,
chat_template_config=self.chat_template_config,
)
# Cache for __repr__ to avoid repeated collective_rpc calls
self._cached_repr: str | None = None
@@ -1030,7 +1042,6 @@ class LLM:
prompts: PromptType | Sequence[PromptType] | DataPrompt,
pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
*,
truncate_prompt_tokens: int | None = None,
use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: list[LoRARequest] | LoRARequest | None = None,
pooling_task: PoolingTask | None = None,
@@ -1088,21 +1099,7 @@ class LLM:
"pooling model."
)
if truncate_prompt_tokens is not None:
warnings.warn(
"The `truncate_prompt_tokens` parameter in `LLM.encode()` "
"is deprecated and will be removed in v0.16. "
"Please pass it via `tokenization_kwargs` instead.",
DeprecationWarning,
stacklevel=2,
)
tokenization_kwargs = merge_kwargs(
tokenization_kwargs,
dict(truncate_prompt_tokens=truncate_prompt_tokens),
)
if use_io_processor := (isinstance(prompts, dict) and "data" in prompts):
if isinstance(prompts, dict) and "data" in prompts:
if self.io_processor is None:
raise ValueError(
"No IOProcessor plugin installed. Please refer "
@@ -1136,6 +1133,31 @@ class LLM:
for p in params_seq:
if p.task is None:
p.task = "plugin"
outputs = self._run_completion(
prompts=prompts_seq,
params=params_seq,
output_type=PoolingRequestOutput,
use_tqdm=use_tqdm,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)
# get the post-processed model outputs
assert self.io_processor is not None
processed_outputs = self.io_processor.post_process(outputs)
return [
PoolingRequestOutput[Any](
request_id="",
outputs=processed_outputs,
num_cached_tokens=getattr(
processed_outputs, "num_cached_tokens", 0
),
prompt_token_ids=[],
finished=True,
)
]
else:
if pooling_params is None:
# Use default pooling params.
@@ -1153,6 +1175,28 @@ class LLM:
)
raise ValueError(msg)
if pooling_task in self.init_pooling_io_processors:
io_processor = self.init_pooling_io_processors[pooling_task]
processor_inputs = io_processor.pre_process_offline(
prompts_seq, tokenization_kwargs
)
seq_lora_requests = self._lora_request_to_seq(
lora_request, len(prompts_seq)
)
seq_priority = self._priority_to_seq(None, len(prompts))
self._render_and_add_requests(
prompts=processor_inputs,
params=params_seq,
lora_requests=seq_lora_requests,
priorities=seq_priority,
)
outputs = self._run_engine(
use_tqdm=use_tqdm, output_type=PoolingRequestOutput
)
outputs = io_processor.post_process(outputs)
else:
outputs = self._run_completion(
prompts=prompts_seq,
params=params_seq,
@@ -1161,31 +1205,12 @@ class LLM:
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)
if use_io_processor:
# get the post-processed model outputs
assert self.io_processor is not None
processed_outputs = self.io_processor.post_process(outputs)
return [
PoolingRequestOutput[Any](
request_id="",
outputs=processed_outputs,
num_cached_tokens=getattr(
processed_outputs, "num_cached_tokens", 0
),
prompt_token_ids=[],
finished=True,
)
]
return outputs
def embed(
self,
prompts: PromptType | Sequence[PromptType],
*,
truncate_prompt_tokens: int | None = None,
use_tqdm: bool | Callable[..., tqdm] = True,
pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
lora_request: list[LoRARequest] | LoRARequest | None = None,
@@ -1221,12 +1246,6 @@ class LLM:
"Try converting the model using `--convert embed`."
)
if truncate_prompt_tokens is not None:
tokenization_kwargs = merge_kwargs(
tokenization_kwargs,
dict(truncate_prompt_tokens=truncate_prompt_tokens),
)
items = self.encode(
prompts,
use_tqdm=use_tqdm,
@@ -1294,7 +1313,6 @@ class LLM:
/,
*,
pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
truncate_prompt_tokens: int | None = None,
use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: list[LoRARequest] | LoRARequest | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
@@ -1319,13 +1337,11 @@ class LLM:
A list of `PoolingRequestOutput` objects containing the
pooled hidden states in the same order as the input prompts.
"""
return self.encode(
prompts,
use_tqdm=use_tqdm,
lora_request=lora_request,
pooling_params=pooling_params,
truncate_prompt_tokens=truncate_prompt_tokens,
pooling_task="token_classify",
tokenization_kwargs=tokenization_kwargs,
)
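
With the dedicated `truncate_prompt_tokens` keyword removed from `encode()`/`embed()`, truncation now travels through `tokenization_kwargs`; a hedged sketch, with a placeholder model path and assuming an embedding-capable model:

```python
# Hedged sketch: prompt truncation after this change goes through tokenization_kwargs.
from vllm import LLM

llm = LLM(model="/model")  # placeholder path; assumes an embedding/pooling model
outputs = llm.embed(
    ["a very long document ..."],
    tokenization_kwargs={"truncate_prompt_tokens": 128},
)
```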
@@ -1771,23 +1787,15 @@ class LLM:
seq_prompts = prompt_to_seq(prompts)
seq_params = self._params_to_seq(params, len(seq_prompts))
seq_lora_requests = self._lora_request_to_seq(lora_request, len(seq_prompts))
seq_tok_kwargs = [
merge_kwargs(
tokenization_kwargs,
dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
)
for param in seq_params
]
seq_priority = self._priority_to_seq(priority, len(prompts))
return self._render_and_add_requests(
prompts=(
self._preprocess_cmpl_one(prompt, tok_kwargs)
for prompt, tok_kwargs in zip(
maybe_tqdm(
seq_prompts, use_tqdm=use_tqdm, desc="Rendering prompts"
),
seq_tok_kwargs,
self._preprocess_cmpl_one(prompt, tokenization_kwargs)
for prompt in maybe_tqdm(
seq_prompts,
use_tqdm=use_tqdm,
desc="Rendering prompts",
)
),
params=seq_params,
@@ -1841,13 +1849,6 @@ class LLM:
seq_convs = conversation_to_seq(messages)
seq_params = self._params_to_seq(params, len(seq_convs))
seq_lora_requests = self._lora_request_to_seq(lora_request, len(seq_convs))
seq_tok_kwargs = [
merge_kwargs(
tokenization_kwargs,
dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
)
for param in seq_params
]
return self._render_and_run_requests(
prompts=(
@@ -1859,16 +1860,13 @@ class LLM:
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
tools=tools,
tokenization_kwargs=tok_kwargs,
tokenization_kwargs=tokenization_kwargs,
mm_processor_kwargs=mm_processor_kwargs,
)
for conversation, tok_kwargs in zip(
maybe_tqdm(
for conversation in maybe_tqdm(
seq_convs,
use_tqdm=use_tqdm,
desc="Rendering conversations",
),
seq_tok_kwargs,
)
),
params=seq_params,

View File

@@ -18,6 +18,20 @@ class RequestLogger:
def __init__(self, *, max_log_len: int | None) -> None:
self.max_log_len = max_log_len
if not logger.isEnabledFor(logging.INFO):
logger.warning_once(
"`--enable-log-requests` is set but "
"the minimum log level is higher than INFO. "
"No request information will be logged."
)
elif not logger.isEnabledFor(logging.DEBUG):
logger.info_once(
"`--enable-log-requests` is set but "
"the minimum log level is higher than DEBUG. "
"Only limited information will be logged to minimize overhead. "
"To view more details, set `VLLM_LOGGING_LEVEL=DEBUG`."
)
def log_inputs(
self,
request_id: str,

View File

@@ -38,6 +38,7 @@ from vllm.logprobs import Logprob
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
from vllm.sampling_params import (
BeamSearchParams,
RepetitionDetectionParams,
RequestOutputKind,
SamplingParams,
StructuredOutputsParams,
@@ -336,6 +337,16 @@ class ChatCompletionRequest(OpenAIBaseModel):
),
)
repetition_detection: RepetitionDetectionParams | None = Field(
default=None,
description="Parameters for detecting repetitive N-gram patterns "
"in output tokens. If such repetition is detected, generation will "
"be ended early. LLMs can sometimes generate repetitive, unhelpful "
"token patterns, stopping only when they hit the maximum output length "
"(e.g. 'abcdabcdabcd...' or '\emoji \emoji \emoji ...'). This feature "
"can detect such behavior and terminate early, saving time and tokens.",
)
# --8<-- [end:chat-completion-extra-params]
def build_chat_params(
@@ -490,7 +501,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA
if self.stream
else RequestOutputKind.FINAL_ONLY,
@@ -500,8 +510,37 @@ class ChatCompletionRequest(OpenAIBaseModel):
allowed_token_ids=self.allowed_token_ids,
extra_args=extra_args or None,
skip_clone=True, # Created fresh per request, safe to skip clone
repetition_detection=self.repetition_detection,
)
@model_validator(mode="before")
@classmethod
def validate_response_format(cls, data):
response_format = data.get("response_format")
if response_format is None:
return data
rf_type = (
response_format.get("type")
if isinstance(response_format, dict)
else getattr(response_format, "type", None)
)
if rf_type == "json_schema":
json_schema = (
response_format.get("json_schema")
if isinstance(response_format, dict)
else getattr(response_format, "json_schema", None)
)
if json_schema is None:
raise VLLMValidationError(
"When response_format type is 'json_schema', the "
"'json_schema' field must be provided.",
parameter="response_format",
)
return data
@model_validator(mode="before")
@classmethod
def validate_stream_options(cls, data):
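
The new `validate_response_format` check rejects `json_schema` response formats that omit the schema body before any sampling parameters are built; a hedged illustration of a payload that should now fail fast:

```python
# Hedged illustration: payload rejected by the validate_response_format check above.
bad_payload = {
    "model": "llm",  # placeholder model name
    "messages": [{"role": "user", "content": "Return JSON"}],
    "response_format": {"type": "json_schema"},  # missing the "json_schema" field
}
# Validating this with ChatCompletionRequest.model_validate(bad_payload) is expected
# to raise VLLMValidationError with parameter="response_format".
```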

View File

@@ -1249,13 +1249,23 @@ class OpenAIServingChat(OpenAIServing):
)
# get the expected call based on partial JSON
# parsing which "autocompletes" the JSON
expected_call = json.dumps(
tool_parser.prev_tool_call_arr[index].get(
# parsing which "autocompletes" the JSON.
# Tool parsers (e.g. Qwen3Coder) store
# arguments as a JSON string in
# prev_tool_call_arr. Calling json.dumps()
# on an already-serialized string would
# double-serialize it (e.g. '{"k":1}' becomes
# '"{\\"k\\":1}"'), which then causes the
# replace() below to fail and append the
# entire double-serialized string as a
# spurious final delta.
args = tool_parser.prev_tool_call_arr[index].get(
"arguments", {}
),
ensure_ascii=False,
)
if isinstance(args, str):
expected_call = args
else:
expected_call = json.dumps(args, ensure_ascii=False)
# get what we've streamed so far for arguments
# for the current tool

View File

@@ -143,7 +143,8 @@ class BaseFrontendArgs:
templates and other tokenizer configuration."""
enable_log_outputs: bool = False
"""If set to True, log model outputs (generations).
Requires --enable-log-requests."""
Requires `--enable-log-requests`. As with `--enable-log-requests`,
information is only logged at INFO level at maximum."""
enable_log_deltas: bool = True
"""If set to False, output deltas will not be logged. Relevant only if
--enable-log-outputs is set.
@@ -277,6 +278,10 @@ class FrontendArgs(BaseFrontendArgs):
Enable offline FastAPI documentation for air-gapped environments.
Uses vendored static assets bundled with vLLM.
"""
use_gpu_for_pooling_score: bool = False
"""If set, run pooling score MaxSim on GPU in the API server process.
Can significantly improve late-interaction scoring performance.
https://github.com/vllm-project/vllm/pull/35330"""
@classmethod
def _customize_cli_kwargs(

View File

@@ -26,6 +26,7 @@ from vllm.logprobs import Logprob
from vllm.renderers import TokenizeParams
from vllm.sampling_params import (
BeamSearchParams,
RepetitionDetectionParams,
RequestOutputKind,
SamplingParams,
StructuredOutputsParams,
@@ -166,6 +167,16 @@ class CompletionRequest(OpenAIBaseModel):
),
)
repetition_detection: RepetitionDetectionParams | None = Field(
default=None,
description="Parameters for detecting repetitive N-gram patterns "
"in output tokens. If such repetition is detected, generation will "
"be ended early. LLMs can sometimes generate repetitive, unhelpful "
"token patterns, stopping only when they hit the maximum output length "
"(e.g. 'abcdabcdabcd...' or '\emoji \emoji \emoji ...'). This feature "
"can detect such behavior and terminate early, saving time and tokens.",
)
# --8<-- [end:completion-extra-params]
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
@@ -259,7 +270,7 @@ class CompletionRequest(OpenAIBaseModel):
structured_outputs_kwargs["json"] = json_schema.json_schema
elif response_format.type == "structural_tag":
structural_tag = response_format
assert structural_tag is not None and isinstance(
assert isinstance(
structural_tag,
(
LegacyStructuralTagResponseFormat,
@@ -302,7 +313,6 @@ class CompletionRequest(OpenAIBaseModel):
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA
if self.stream
else RequestOutputKind.FINAL_ONLY,
@@ -311,8 +321,37 @@ class CompletionRequest(OpenAIBaseModel):
allowed_token_ids=self.allowed_token_ids,
extra_args=extra_args or None,
skip_clone=True, # Created fresh per request, safe to skip clone
repetition_detection=self.repetition_detection,
)
@model_validator(mode="before")
@classmethod
def validate_response_format(cls, data):
response_format = data.get("response_format")
if response_format is None:
return data
rf_type = (
response_format.get("type")
if isinstance(response_format, dict)
else getattr(response_format, "type", None)
)
if rf_type == "json_schema":
json_schema = (
response_format.get("json_schema")
if isinstance(response_format, dict)
else getattr(response_format, "json_schema", None)
)
if json_schema is None:
raise VLLMValidationError(
"When response_format type is 'json_schema', the "
"'json_schema' field must be provided.",
parameter="response_format",
)
return data
@model_validator(mode="before")
@classmethod
def check_structured_outputs_count(cls, data):

View File

@@ -62,11 +62,6 @@ from vllm.entrypoints.openai.speech_to_text.protocol import (
TranscriptionResponse,
TranslationRequest,
)
from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest,
ClassificationCompletionRequest,
ClassificationResponse,
)
from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingBytesResponse,
EmbeddingChatRequest,
@@ -161,7 +156,6 @@ CompletionLikeRequest: TypeAlias = (
| TokenizeCompletionRequest
| DetokenizeRequest
| EmbeddingCompletionRequest
| ClassificationCompletionRequest
| RerankRequest
| ScoreRequest
| PoolingCompletionRequest
@@ -171,7 +165,6 @@ ChatLikeRequest: TypeAlias = (
ChatCompletionRequest
| TokenizeChatRequest
| EmbeddingChatRequest
| ClassificationChatRequest
| PoolingChatRequest
)
@@ -194,12 +187,10 @@ AnyResponse: TypeAlias = (
| TranscriptionResponse
| TokenizeResponse
| PoolingResponse
| ClassificationResponse
| ScoreResponse
| GenerateResponse
)
RequestT = TypeVar("RequestT", bound=AnyRequest)
@@ -223,8 +214,8 @@ class ServeContext(Generic[RequestT]):
class OpenAIServing:
request_id_prefix: ClassVar[str] = """
A short string prepended to every request's ID (e.g. "embd", "classify")
so you can easily tell “this ID came from Embedding vs Classification.”
A short string prepended to every request's ID (e.g. "embd")
so you can easily tell “this ID came from Embedding.”
"""
def __init__(
@@ -456,7 +447,7 @@ class OpenAIServing:
) -> ErrorResponse | None:
"""
Default preprocessing hook. Subclasses may override
to prepare `ctx` (classification, embedding, etc.).
to prepare `ctx` (embedding, etc.).
"""
return None
@@ -817,7 +808,7 @@ class OpenAIServing:
token_num = len(input_ids)
max_model_len = self.model_config.max_model_len
# Note: EmbeddingRequest, ClassificationRequest,
# Note: EmbeddingRequest,
# and ScoreRequest doesn't have max_tokens
if isinstance(
request,
@@ -828,8 +819,6 @@ class OpenAIServing:
ScoreTextRequest,
ScoreQueriesDocumentsRequest,
RerankRequest,
ClassificationCompletionRequest,
ClassificationChatRequest,
),
):
# Note: input length can be up to the entire model context length
@@ -839,8 +828,6 @@ class OpenAIServing:
ScoreDataRequest: "score",
ScoreTextRequest: "score",
ScoreQueriesDocumentsRequest: "score",
ClassificationCompletionRequest: "classification",
ClassificationChatRequest: "classification",
}
operation = operations.get(type(request), "embedding generation")
raise VLLMValidationError(

View File

@@ -328,8 +328,9 @@ class ResponsesRequest(OpenAIBaseModel):
# Also check text.format for OpenAI-style json_schema
if self.text is not None and self.text.format is not None:
if structured_outputs is not None:
raise ValueError(
"Cannot specify both structured_outputs and text.format"
raise VLLMValidationError(
"Cannot specify both structured_outputs and text.format",
parameter="structured_outputs",
)
response_format = self.text.format
if (
@@ -378,14 +379,19 @@ class ResponsesRequest(OpenAIBaseModel):
)
@model_validator(mode="before")
@classmethod
def validate_background(cls, data):
if not data.get("background"):
return data
if not data.get("store", True):
raise ValueError("background can only be used when `store` is true")
raise VLLMValidationError(
"background can only be used when `store` is true",
parameter="background",
)
return data
@model_validator(mode="before")
@classmethod
def validate_prompt(cls, data):
if data.get("prompt") is not None:
raise VLLMValidationError(
@@ -394,16 +400,19 @@ class ResponsesRequest(OpenAIBaseModel):
return data
@model_validator(mode="before")
@classmethod
def check_cache_salt_support(cls, data):
if data.get("cache_salt") is not None and (
not isinstance(data["cache_salt"], str) or not data["cache_salt"]
):
raise ValueError(
"Parameter 'cache_salt' must be a non-empty string if provided."
raise VLLMValidationError(
"Parameter 'cache_salt' must be a non-empty string if provided.",
parameter="cache_salt",
)
return data
@model_validator(mode="before")
@classmethod
def function_call_parsing(cls, data):
"""Parse function_call dictionaries into ResponseFunctionToolCall objects.
This ensures Pydantic can properly resolve union types in the input field.

View File

@@ -85,6 +85,8 @@ from vllm.entrypoints.openai.responses.protocol import (
ResponseCreatedEvent,
ResponseInProgressEvent,
ResponseInputOutputMessage,
ResponseReasoningPartAddedEvent,
ResponseReasoningPartDoneEvent,
ResponsesRequest,
ResponsesResponse,
ResponseUsage,
@@ -1339,6 +1341,19 @@ class OpenAIServingResponses(OpenAIServing):
),
)
)
yield _increment_sequence_number_and_return(
ResponseReasoningPartAddedEvent(
type="response.reasoning_part.added",
sequence_number=-1,
output_index=current_output_index,
item_id=current_item_id,
content_index=current_content_index,
part=ResponseReasoningTextContent(
text="",
type="reasoning_text",
),
)
)
else:
yield _increment_sequence_number_and_return(
ResponseOutputItemAddedEvent(
@@ -1369,7 +1384,6 @@ class OpenAIServingResponses(OpenAIServing):
),
)
)
current_content_index += 1
first_delta_sent = True
# todo(kebe7jun) tool call support
@@ -1397,6 +1411,19 @@ class OpenAIServingResponses(OpenAIServing):
text=reason_content,
)
)
yield _increment_sequence_number_and_return(
ResponseReasoningPartDoneEvent(
type="response.reasoning_part.done",
sequence_number=-1,
item_id=current_item_id,
output_index=current_output_index,
content_index=current_content_index,
part=ResponseReasoningTextContent(
text=reason_content,
type="reasoning_text",
),
)
)
current_content_index = 0
reasoning_item = ResponseReasoningItem(
type="reasoning",
@@ -1418,6 +1445,8 @@ class OpenAIServingResponses(OpenAIServing):
item=reasoning_item,
)
)
current_output_index += 1
current_item_id = str(uuid.uuid4())
yield _increment_sequence_number_and_return(
ResponseOutputItemAddedEvent(
type="response.output_item.added",
@@ -1432,8 +1461,6 @@ class OpenAIServingResponses(OpenAIServing):
),
)
)
current_output_index += 1
current_item_id = str(uuid.uuid4())
yield _increment_sequence_number_and_return(
ResponseContentPartAddedEvent(
type="response.content_part.added",
@@ -1449,7 +1476,6 @@ class OpenAIServingResponses(OpenAIServing):
),
)
)
current_content_index += 1
# reset previous delta messages
previous_delta_messages = []
@@ -1485,7 +1511,6 @@ class OpenAIServingResponses(OpenAIServing):
),
)
)
current_content_index += 1
previous_delta_messages.append(delta_message)
if previous_delta_messages:
@@ -1505,7 +1530,19 @@ class OpenAIServingResponses(OpenAIServing):
text=reason_content,
)
)
current_content_index += 1
yield _increment_sequence_number_and_return(
ResponseReasoningPartDoneEvent(
type="response.reasoning_part.done",
sequence_number=-1,
item_id=current_item_id,
output_index=current_output_index,
content_index=current_content_index,
part=ResponseReasoningTextContent(
text=reason_content,
type="reasoning_text",
),
)
)
reasoning_item = ResponseReasoningItem(
type="reasoning",
content=[
@@ -1543,7 +1580,6 @@ class OpenAIServingResponses(OpenAIServing):
item_id=current_item_id,
)
)
current_content_index += 1
part = ResponseOutputText(
text=final_content,
type="output_text",
@@ -1559,7 +1595,6 @@ class OpenAIServingResponses(OpenAIServing):
part=part,
)
)
current_content_index += 1
item = ResponseOutputMessage(
type="message",
role="assistant",

View File

@@ -11,6 +11,7 @@ from typing import Final, Literal, TypeAlias, TypeVar, cast
import numpy as np
from fastapi import Request
from soundfile import LibsndfileError
from transformers import PreTrainedTokenizerBase
import vllm.envs as envs
@@ -57,6 +58,14 @@ try:
except ImportError:
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
# Public libsndfile error codes exposed via `soundfile.LibsndfileError.code`, soundfile
# being librosa's main backend. Used to validate if an audio loading error is due to a
# server error vs a client error (invalid audio file).
# 1 = unrecognised format (file is not a supported audio container)
# 3 = malformed file (corrupt or structurally invalid audio)
# 4 = unsupported encoding (codec not supported by this libsndfile build)
_BAD_SF_CODES = {1, 3, 4}
SpeechToTextResponse: TypeAlias = TranscriptionResponse | TranslationResponse
SpeechToTextResponseVerbose: TypeAlias = (
TranscriptionResponseVerbose | TranslationResponseVerbose
@@ -315,9 +324,15 @@ class OpenAISpeechToText(OpenAIServing):
)
with io.BytesIO(audio_data) as bytes_:
try:
# NOTE resample to model SR here for efficiency. This is also a
# pre-requisite for chunking, as it assumes Whisper SR.
y, sr = librosa.load(bytes_, sr=self.asr_config.sample_rate)
except LibsndfileError as exc:
# Distinguish client errors (invalid audio) from server errors
if exc.code in _BAD_SF_CODES:
raise ValueError("Invalid or unsupported audio file.") from exc
raise
duration = librosa.get_duration(y=y, sr=sr)
do_split_audio = (

View File

@@ -1,12 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import warnings
warnings.warn(
"The 'vllm.entrypoints.openai.translations' module has been renamed to "
"'vllm.entrypoints.openai.speech_to_text'. Please update your imports. "
"This backward-compatible alias will be removed in version 0.17+.",
DeprecationWarning,
stacklevel=2,
)

Some files were not shown because too many files have changed in this diff.