Add minimal vLLM 0.16.1 build repo for BI-V150
vllm/.gitignore  (vendored, new file, +244 lines)
@@ -0,0 +1,244 @@
# version file generated by setuptools-scm
/vllm/_version.py

# vllm-flash-attn built from source
vllm/vllm_flash_attn/*
!vllm/vllm_flash_attn/__init__.py
!vllm/vllm_flash_attn/flash_attn_interface.py

# OpenAI triton kernels copied from source
vllm/third_party/triton_kernels/*

# FlashMLA interface copied from source
vllm/third_party/flashmla/flash_mla_interface.py

# triton jit
.triton

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
cmake-build-*/
CMakeUserPresets.json
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
/.deps/

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# generated files
**/generated/**

# uv
uv.lock

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site
docs/argparse
docs/examples/*
!docs/examples/README.md

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

# VSCode
.vscode/

# Claude
.claude/

# Codex
.codex/

# Cursor
.cursor/

# DS Store
.DS_Store

# Results
*.csv

# Python pickle files
*.pkl

# Sphinx documentation
_build/

# vim swap files
*.swo
*.swp

# hip files generated by PyTorch
*.hip
*_hip*
hip_compat.h

# Benchmark dataset
benchmarks/**/*.json

# Linting
actionlint
shellcheck*/

# Ignore moe/marlin_moe gen code
csrc/moe/marlin_moe_wna16/kernel_*

# Ignore ep_kernels_workspace folder
ep_kernels_workspace/

# Allow tracked library source folders under submodules (e.g., benchmarks/lib)
!vllm/benchmarks/lib/

# Generated gRPC protobuf files (compiled at build time from vllm_engine.proto)
vllm/grpc/vllm_engine_pb2.py
vllm/grpc/vllm_engine_pb2_grpc.py
vllm/grpc/vllm_engine_pb2.pyi

# Ignore generated cpu headers
csrc/cpu/cpu_attn_dispatch_generated.h
vllm/__init__.py  (new file, +107 lines)
@@ -0,0 +1,107 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""vLLM: a high-throughput and memory-efficient inference engine for LLMs"""

# The version.py should be independent library, and we always import the
# version library first. Such assumption is critical for some customization.
from .version import __version__, __version_tuple__  # isort:skip

import typing

# The environment variables override should be imported before any other
# modules to ensure that the environment variables are set before any
# other modules are imported.
import vllm.env_override  # noqa: F401

MODULE_ATTRS = {
    "bc_linter_skip": "._bc_linter:bc_linter_skip",
    "bc_linter_include": "._bc_linter:bc_linter_include",
    "AsyncEngineArgs": ".engine.arg_utils:AsyncEngineArgs",
    "EngineArgs": ".engine.arg_utils:EngineArgs",
    "AsyncLLMEngine": ".engine.async_llm_engine:AsyncLLMEngine",
    "LLMEngine": ".engine.llm_engine:LLMEngine",
    "LLM": ".entrypoints.llm:LLM",
    "initialize_ray_cluster": ".v1.executor.ray_utils:initialize_ray_cluster",
    "PromptType": ".inputs:PromptType",
    "TextPrompt": ".inputs:TextPrompt",
    "TokensPrompt": ".inputs:TokensPrompt",
    "ModelRegistry": ".model_executor.models:ModelRegistry",
    "SamplingParams": ".sampling_params:SamplingParams",
    "PoolingParams": ".pooling_params:PoolingParams",
    "ClassificationOutput": ".outputs:ClassificationOutput",
    "ClassificationRequestOutput": ".outputs:ClassificationRequestOutput",
    "CompletionOutput": ".outputs:CompletionOutput",
    "EmbeddingOutput": ".outputs:EmbeddingOutput",
    "EmbeddingRequestOutput": ".outputs:EmbeddingRequestOutput",
    "PoolingOutput": ".outputs:PoolingOutput",
    "PoolingRequestOutput": ".outputs:PoolingRequestOutput",
    "RequestOutput": ".outputs:RequestOutput",
    "ScoringOutput": ".outputs:ScoringOutput",
    "ScoringRequestOutput": ".outputs:ScoringRequestOutput",
}

if typing.TYPE_CHECKING:
    from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
    from vllm.engine.async_llm_engine import AsyncLLMEngine
    from vllm.engine.llm_engine import LLMEngine
    from vllm.entrypoints.llm import LLM
    from vllm.inputs import PromptType, TextPrompt, TokensPrompt
    from vllm.model_executor.models import ModelRegistry
    from vllm.outputs import (
        ClassificationOutput,
        ClassificationRequestOutput,
        CompletionOutput,
        EmbeddingOutput,
        EmbeddingRequestOutput,
        PoolingOutput,
        PoolingRequestOutput,
        RequestOutput,
        ScoringOutput,
        ScoringRequestOutput,
    )
    from vllm.pooling_params import PoolingParams
    from vllm.sampling_params import SamplingParams
    from vllm.v1.executor.ray_utils import initialize_ray_cluster

    from ._bc_linter import bc_linter_include, bc_linter_skip
else:

    def __getattr__(name: str) -> typing.Any:
        from importlib import import_module

        if name in MODULE_ATTRS:
            module_name, attr_name = MODULE_ATTRS[name].split(":")
            module = import_module(module_name, __package__)
            return getattr(module, attr_name)
        else:
            raise AttributeError(f"module {__package__} has no attribute {name}")


__all__ = [
    "__version__",
    "bc_linter_skip",
    "bc_linter_include",
    "__version_tuple__",
    "LLM",
    "ModelRegistry",
    "PromptType",
    "TextPrompt",
    "TokensPrompt",
    "SamplingParams",
    "RequestOutput",
    "CompletionOutput",
    "PoolingOutput",
    "PoolingRequestOutput",
    "EmbeddingOutput",
    "EmbeddingRequestOutput",
    "ClassificationOutput",
    "ClassificationRequestOutput",
    "ScoringOutput",
    "ScoringRequestOutput",
    "LLMEngine",
    "EngineArgs",
    "AsyncLLMEngine",
    "AsyncEngineArgs",
    "initialize_ray_cluster",
    "PoolingParams",
]
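Note: the MODULE_ATTRS table drives PEP 562 lazy imports; the heavy submodules are only imported when an attribute is first accessed, via the module-level __getattr__ above. A minimal usage sketch of that access path (illustrative only, not part of the commit):

# Importing vllm does not import vllm.entrypoints.llm or vllm.sampling_params yet.
import vllm

llm_cls = vllm.LLM                          # resolved lazily via MODULE_ATTRS
params = vllm.SamplingParams(max_tokens=8)  # same lazy mechanism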
vllm/_aiter_ops.py  (new file, +1810 lines; diff not shown: file too large)
vllm/_bc_linter.py  (new file, +54 lines)
@@ -0,0 +1,54 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# vllm/_bc_linter.py
from collections.abc import Callable
from typing import Any, TypeVar, overload

T = TypeVar("T")


@overload
def bc_linter_skip(obj: T) -> T: ...


@overload
def bc_linter_skip(*, reason: str | None = ...) -> Callable[[T], T]: ...


def bc_linter_skip(obj: Any = None, *, reason: str | None = None):
    """
    No-op decorator to mark symbols/files for BC-linter suppression.

    Usage:
        @bc_linter_skip
        def legacy_api(...): ...
    """

    def _wrap(x: T) -> T:
        return x

    return _wrap if obj is None else obj


@overload
def bc_linter_include(obj: T) -> T: ...


@overload
def bc_linter_include(*, reason: str | None = ...) -> Callable[[T], T]: ...


def bc_linter_include(obj: Any = None, *, reason: str | None = None):
    """
    Usage:
        @bc_linter_include
        def public_api(...): ...
    """

    def _wrap(x: T) -> T:
        return x

    return _wrap if obj is None else obj


__all__ = ["bc_linter_skip", "bc_linter_include"]
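Note: both decorators are runtime no-ops and support a bare and a parametrized form. A small sketch of how they might be applied (the decorated functions are made up for illustration):

from vllm import bc_linter_include, bc_linter_skip

@bc_linter_skip(reason="scheduled for removal")  # parametrized form
def legacy_api() -> None: ...

@bc_linter_include  # bare form
def public_api() -> None: ...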
vllm/_custom_ops.py  (new file, +4238 lines; diff not shown: file too large)
vllm/_oink_ops.py  (new file, +96 lines)
@@ -0,0 +1,96 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Small helper wrappers for external Oink Blackwell custom ops.

vLLM does not depend on the external Oink repository/package. When an external
plugin registers torch.library.custom_op entrypoints under the `oink::`
namespace (e.g. via vLLM's general_plugins mechanism) and
`VLLM_USE_OINK_OPS=1` is set, vLLM can route eligible calls to those ops.

This module provides:
- A single place to probe Oink op availability at module init time
  (outside torch.compile tracing), and
- Thin wrappers around the torch.ops entrypoints for use in CUDA fast paths,
  without introducing graph breaks.

Important:
    Do not call the availability helpers in a compiled region. They may call
    functions decorated with `torch._dynamo.disable` to safely check
    conditions that should not be traced.
"""

from __future__ import annotations

from collections.abc import Callable

import torch

try:
    from torch._dynamo import disable as _dynamo_disable  # type: ignore[attr-defined]
except Exception:  # pragma: no cover

    def _dynamo_disable(fn: Callable):  # type: ignore[misc]
        return fn


def _has_oink_op(op_name: str) -> bool:
    """Check if a specific oink op is registered."""
    return hasattr(torch.ops, "oink") and hasattr(torch.ops.oink, op_name)


@_dynamo_disable
def is_oink_available_for_device(device_index: int) -> bool:
    """Return True if Oink ops are registered and device is SM100+.

    This function is intended to be called during module initialization
    (e.g., in RMSNorm.__init__), not in the forward path.

    External plugins are expected to gate registration on SM100+ and
    VLLM_USE_OINK_OPS=1, so if the ops are present they should be usable.
    """
    if not torch.cuda.is_available():
        return False

    try:
        major, minor = torch.cuda.get_device_capability(device_index)
        sm = 10 * major + minor
        if sm < 100:
            return False
    except Exception:
        return False

    return _has_oink_op("rmsnorm")


def has_fused_add_rms_norm() -> bool:
    """Return True if the in-place fused op is registered."""
    return _has_oink_op("fused_add_rms_norm")


def rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    """Call `torch.ops.oink.rmsnorm`.

    This wrapper is safe to call in torch.compile regions.
    """
    return torch.ops.oink.rmsnorm(x, weight, eps)


def fused_add_rms_norm_(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
) -> None:
    """Call `torch.ops.oink.fused_add_rms_norm` (mutates x and residual)."""
    torch.ops.oink.fused_add_rms_norm(x, residual, weight, eps)


def fused_add_rms_norm(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Convenience wrapper returning (x, residual) after in-place mutation."""
    fused_add_rms_norm_(x, residual, weight, eps)
    return x, residual
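Note: per the module docstring, availability is probed once at layer construction time (outside torch.compile tracing) and only the thin wrappers are touched on the hot path. A minimal gating sketch; the layer below is hypothetical and not part of this commit:

# Hypothetical gating sketch; only vllm._oink_ops is from the commit.
import torch
import vllm._oink_ops as _oink_ops


class TinyRMSNorm(torch.nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.eps = eps
        # Probe once, outside any compiled region.
        self.use_oink = _oink_ops.is_oink_available_for_device(0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_oink:
            return _oink_ops.rmsnorm(x, self.weight, self.eps)
        # Reference fallback when the plugin ops are not registered.
        var = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(var + self.eps) * self.weight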
vllm/_xpu_ops.py  (new file, +159 lines)
@@ -0,0 +1,159 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import TYPE_CHECKING

import torch
from vllm_xpu_kernels.flash_attn_interface import flash_attn_varlen_func

from vllm.logger import init_logger

logger = init_logger(__name__)

if TYPE_CHECKING:

    def register_fake(fn):
        return lambda name: fn
else:
    try:
        from torch.library import register_fake
    except ImportError:
        from torch.library import impl_abstract as register_fake

if hasattr(torch.ops._xpu_C, "fp8_gemm_w8a16"):

    @register_fake("_xpu_C::fp8_gemm_w8a16")
    def _fp8_gemm_w8a16_fake(
        input: torch.Tensor,
        q_weight: torch.Tensor,
        weight_scale: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        input_2d = input.view(-1, input.shape[-1])
        M = input_2d.size(0)
        N = q_weight.size(1)
        return torch.empty((M, N), dtype=input.dtype, device=input.device)


if hasattr(torch.ops._xpu_C, "int4_gemm_w4a16"):

    @register_fake("_xpu_C::int4_gemm_w4a16")
    def _int4_gemm_w4a16_fake(
        input: torch.Tensor,
        q_weight: torch.Tensor,
        bias: torch.Tensor | None,
        weight_scale: torch.Tensor,
        qzeros: torch.Tensor,
        group_size: int,
        group_idx: torch.Tensor | None = None,
    ) -> torch.Tensor:
        input_2d = input.view(-1, input.shape[-1])
        M = input_2d.size(0)
        N = q_weight.size(1)
        return torch.empty((M, N), dtype=input.dtype, device=input.device)


class xpu_ops:
    @staticmethod
    def flash_attn_varlen_func(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        cu_seqlens_q: torch.Tensor,
        max_seqlen_q: int,
        max_seqlen_k: int,
        softmax_scale: float | None = None,
        causal: bool = False,
        out: torch.Tensor | None = None,
        block_table: torch.Tensor | None = None,
        alibi_slopes: torch.Tensor | None = None,
        window_size: list[int] | None = None,
        softcap: float | None = 0.0,
        seqused_k: torch.Tensor | None = None,
        cu_seqlens_k: torch.Tensor | None = None,
        # passed in qwen vl
        dropout_p: float = 0.0,
        # The following parameters are not used in xpu kernel currently,
        # we keep API compatible to CUDA's.
        scheduler_metadata=None,
        fa_version: int = 2,
        q_descale=None,
        k_descale=None,
        v_descale=None,
        num_splits=0,
        return_softmax_lse: bool | None = False,
        s_aux: torch.Tensor | None = None,
    ):
        assert cu_seqlens_k is not None or seqused_k is not None, (
            "cu_seqlens_k or seqused_k must be provided"
        )
        assert cu_seqlens_k is None or seqused_k is None, (
            "cu_seqlens_k and seqused_k cannot be provided at the same time"
        )
        assert block_table is None or seqused_k is not None, (
            "when enable block_table, seqused_k is needed"
        )
        assert block_table is not None or cu_seqlens_k is not None, (
            "when block_table is disabled, cu_seqlens_k is needed"
        )
        if out is None:
            out = torch.empty(q.shape, dtype=q.dtype, device=q.device)
        real_window_size: tuple[int, int]
        if window_size is None:
            real_window_size = (-1, -1)
        else:
            assert len(window_size) == 2
            real_window_size = (window_size[0], window_size[1])  # noqa: F841

        # In encode attention, k and v maybe not contiguous and current
        # kernel can't handle it
        if block_table is None:
            k = k.contiguous()
            v = v.contiguous()
        return flash_attn_varlen_func(
            out=out,
            q=q.contiguous(),
            k=k,
            v=v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            seqused_k=seqused_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            softmax_scale=softmax_scale,
            causal=causal,
            block_table=block_table,
            s_aux=s_aux,
            window_size=real_window_size,
            # alibi_slopes = alibi_slopes,
            # softcap=softcap,
            return_softmax_lse=return_softmax_lse,
        )

    @staticmethod
    def get_scheduler_metadata(
        batch_size,
        max_seqlen_q,
        max_seqlen_k,
        num_heads_q,
        num_heads_kv,
        headdim,
        cache_seqlens: torch.Tensor,
        qkv_dtype=torch.bfloat16,
        headdim_v=None,
        cu_seqlens_q: torch.Tensor | None = None,
        cu_seqlens_k_new: torch.Tensor | None = None,
        cache_leftpad: torch.Tensor | None = None,
        page_size: int | None = None,
        max_seqlen_k_new=0,
        causal=False,
        window_size=(-1, -1),  # -1 means infinite context window
        has_softcap=False,
        num_splits=0,  # Can be tuned for speed
        pack_gqa=None,  # Can be tuned for speed
        sm_margin=0,  # Can be tuned if some SMs are used for communication
    ) -> None:
        logger.warning_once(
            "get_scheduler_metadata is not implemented for xpu_ops, returning None."
        )
        return None
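Note: the register_fake implementations above only describe output shape and dtype so that fake-tensor tracing (e.g. under torch.compile) can propagate metadata without running the real XPU kernels. The contract they encode is simply (M, N) from the flattened activation and the weight's second dimension; a tiny illustration with assumed shapes, not taken from the commit:

import torch

# Assumed example shapes: a (2, 3, 64) activation and a (64, 128) quantized weight.
x = torch.randn(2, 3, 64)
q_weight = torch.empty(64, 128, dtype=torch.int8)

M = x.view(-1, x.shape[-1]).size(0)  # 2 * 3 = 6 flattened rows
N = q_weight.size(1)                 # 128 output features
fake_out = torch.empty((M, N), dtype=x.dtype)
assert fake_out.shape == (6, 128)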
vllm/assets/__init__.py  (new file, empty)
vllm/assets/audio.py  (new file, +43 lines)
@@ -0,0 +1,43 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from pathlib import Path
from typing import Literal
from urllib.parse import urljoin

import numpy.typing as npt

from vllm.utils.import_utils import PlaceholderModule

from .base import VLLM_S3_BUCKET_URL, get_vllm_public_assets

try:
    import librosa
except ImportError:
    librosa = PlaceholderModule("librosa")  # type: ignore[assignment]

ASSET_DIR = "multimodal_asset"

AudioAssetName = Literal["winning_call", "mary_had_lamb"]


@dataclass(frozen=True)
class AudioAsset:
    name: AudioAssetName

    @property
    def filename(self) -> str:
        return f"{self.name}.ogg"

    @property
    def audio_and_sample_rate(self) -> tuple[npt.NDArray, float]:
        audio_path = get_vllm_public_assets(filename=self.filename, s3_prefix=ASSET_DIR)
        return librosa.load(audio_path, sr=None)

    def get_local_path(self) -> Path:
        return get_vllm_public_assets(filename=self.filename, s3_prefix=ASSET_DIR)

    @property
    def url(self) -> str:
        return urljoin(VLLM_S3_BUCKET_URL, f"{ASSET_DIR}/{self.name}.ogg")
vllm/assets/base.py  (new file, +40 lines)
@@ -0,0 +1,40 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from functools import lru_cache
from pathlib import Path

import vllm.envs as envs
from vllm.connections import global_http_connection

VLLM_S3_BUCKET_URL = "https://vllm-public-assets.s3.us-west-2.amazonaws.com"


def get_cache_dir() -> Path:
    """Get the path to the cache for storing downloaded assets."""
    path = Path(envs.VLLM_ASSETS_CACHE)
    path.mkdir(parents=True, exist_ok=True)

    return path


@lru_cache
def get_vllm_public_assets(filename: str, s3_prefix: str | None = None) -> Path:
    """
    Download an asset file from `s3://vllm-public-assets`
    and return the path to the downloaded file.
    """
    asset_directory = get_cache_dir() / "vllm_public_assets"
    asset_directory.mkdir(parents=True, exist_ok=True)

    asset_path = asset_directory / filename
    if not asset_path.exists():
        if s3_prefix is not None:
            filename = s3_prefix + "/" + filename
        global_http_connection.download_file(
            f"{VLLM_S3_BUCKET_URL}/{filename}",
            asset_path,
            timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
        )

    return asset_path
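Note: the helper downloads into VLLM_ASSETS_CACHE on first use and memoizes the resulting path with lru_cache, so repeated lookups are cheap. A hedged usage sketch (the filename/prefix below come from image.py further down, shown here only as an example):

from vllm.assets.base import get_vllm_public_assets

path = get_vllm_public_assets("stop_sign.jpg", s3_prefix="vision_model_images")
print(path)           # e.g. <VLLM_ASSETS_CACHE>/vllm_public_assets/stop_sign.jpg
path2 = get_vllm_public_assets("stop_sign.jpg", s3_prefix="vision_model_images")
assert path2 == path  # served from the lru_cache, no second download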
vllm/assets/image.py  (new file, +62 lines)
@@ -0,0 +1,62 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from pathlib import Path
from typing import Literal

import torch
from PIL import Image

from .base import get_vllm_public_assets

VLM_IMAGES_DIR = "vision_model_images"

ImageAssetName = Literal[
    "stop_sign",
    "cherry_blossom",
    "hato",
    "2560px-Gfp-wisconsin-madison-the-nature-boardwalk",
    "Grayscale_8bits_palette_sample_image",
    "1280px-Venn_diagram_rgb",
    "RGBA_comp",
    "237-400x300",
    "231-200x300",
    "27-500x500",
    "17-150x600",
    "handelsblatt-preview",
    "paper-11",
]


@dataclass(frozen=True)
class ImageAsset:
    name: ImageAssetName

    def get_path(self, ext: str) -> Path:
        """
        Return s3 path for given image.
        """
        return get_vllm_public_assets(
            filename=f"{self.name}.{ext}", s3_prefix=VLM_IMAGES_DIR
        )

    @property
    def pil_image(self) -> Image.Image:
        return self.pil_image_ext(ext="jpg")

    def pil_image_ext(self, ext: str) -> Image.Image:
        image_path = self.get_path(ext=ext)
        return Image.open(image_path)

    @property
    def image_embeds(self) -> torch.Tensor:
        """
        Image embeddings, only used for testing purposes with llava 1.5.
        """
        image_path = self.get_path("pt")
        return torch.load(image_path, map_location="cpu", weights_only=True)

    def read_bytes(self, ext: str) -> bytes:
        p = Path(self.get_path(ext))
        return p.read_bytes()
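Note: a short usage sketch of the image asset helper above (illustrative only; the asset name is one of the literals listed in the file):

from vllm.assets.image import ImageAsset

asset = ImageAsset("cherry_blossom")
img = asset.pil_image          # downloads <name>.jpg on first use, then caches
print(img.size, img.mode)
raw = asset.read_bytes("jpg")  # same file as bytes, e.g. for base64 payloads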
vllm/assets/video.py  (new file, +149 lines)
@@ -0,0 +1,149 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from functools import lru_cache
from typing import Any, ClassVar, Literal

import numpy as np
import numpy.typing as npt
from huggingface_hub import hf_hub_download
from PIL import Image

from vllm.utils.import_utils import PlaceholderModule

from .base import get_cache_dir

try:
    import librosa
except ImportError:
    librosa = PlaceholderModule("librosa")  # type: ignore[assignment]


@lru_cache
def download_video_asset(filename: str) -> str:
    """
    Download and open an image from huggingface
    repo: raushan-testing-hf/videos-test
    """
    video_directory = get_cache_dir() / "video-example-data"
    video_directory.mkdir(parents=True, exist_ok=True)

    video_path = video_directory / filename
    video_path_str = str(video_path)
    if not video_path.exists():
        video_path_str = hf_hub_download(
            repo_id="raushan-testing-hf/videos-test",
            filename=filename,
            repo_type="dataset",
            cache_dir=video_directory,
        )
    return video_path_str


def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray:
    import cv2

    cap = cv2.VideoCapture(path)
    if not cap.isOpened():
        raise ValueError(f"Could not open video file {path}")

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frames = []

    num_frames = num_frames if num_frames > 0 else total_frames
    frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
    for idx in range(total_frames):
        ok = cap.grab()  # next img
        if not ok:
            break
        if idx in frame_indices:  # only decompress needed
            ret, frame = cap.retrieve()
            if ret:
                # OpenCV uses BGR format, we need to convert it to RGB
                # for PIL and transformers compatibility
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

    frames = np.stack(frames)
    if len(frames) < num_frames:
        raise ValueError(
            f"Could not read enough frames from video file {path}"
            f" (expected {num_frames} frames, got {len(frames)})"
        )
    return frames


def video_to_pil_images_list(path: str, num_frames: int = -1) -> list[Image.Image]:
    frames = video_to_ndarrays(path, num_frames)
    return [Image.fromarray(frame) for frame in frames]


def video_get_metadata(path: str, num_frames: int = -1) -> dict[str, Any]:
    import cv2

    cap = cv2.VideoCapture(path)
    if not cap.isOpened():
        raise ValueError(f"Could not open video file {path}")

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    duration = total_frames / fps if fps > 0 else 0

    if num_frames == -1 or num_frames > total_frames:
        num_frames = total_frames

    metadata = {
        "total_num_frames": num_frames,
        "fps": duration / num_frames,
        "duration": duration,
        "video_backend": "opencv",
        "frames_indices": list(range(num_frames)),
        # extra field used to control hf processor's video
        # sampling behavior
        "do_sample_frames": num_frames == total_frames,
    }
    return metadata


VideoAssetName = Literal["baby_reading"]


@dataclass(frozen=True)
class VideoAsset:
    name: VideoAssetName
    num_frames: int = -1

    _NAME_TO_FILE: ClassVar[dict[VideoAssetName, str]] = {
        "baby_reading": "sample_demo_1.mp4",
    }

    @property
    def filename(self) -> str:
        return self._NAME_TO_FILE[self.name]

    @property
    def video_path(self) -> str:
        return download_video_asset(self.filename)

    @property
    def pil_images(self) -> list[Image.Image]:
        ret = video_to_pil_images_list(self.video_path, self.num_frames)
        return ret

    @property
    def np_ndarrays(self) -> npt.NDArray:
        ret = video_to_ndarrays(self.video_path, self.num_frames)
        return ret

    @property
    def metadata(self) -> dict[str, Any]:
        ret = video_get_metadata(self.video_path, self.num_frames)
        return ret

    def get_audio(self, sampling_rate: float | None = None) -> npt.NDArray:
        """
        Read audio data from the video asset, used in Qwen2.5-Omni examples.

        See also: examples/offline_inference/qwen2_5_omni/only_thinker.py
        """
        return librosa.load(self.video_path, sr=sampling_rate)[0]
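Note: a brief sketch of how the video asset is typically consumed; the frame count is an arbitrary choice for illustration and librosa must be installed for the audio path:

from vllm.assets.video import VideoAsset

video = VideoAsset(name="baby_reading", num_frames=16)
frames = video.pil_images                     # 16 RGB PIL frames sampled uniformly
meta = video.metadata                         # duration, backend, frame indices, ...
audio = video.get_audio(sampling_rate=16000)  # mono waveform via librosa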
vllm/beam_search.py  (new file, +109 lines)
@@ -0,0 +1,109 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass

from vllm.inputs import TokenInputs, token_inputs
from vllm.logprobs import Logprob
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalInputs, mm_inputs


@dataclass
class BeamSearchSequence:
    """A sequence for beam search.
    It keeps track of the tokens and the log probability of the sequence.
    The text field is optional and will only be filled when the sequence is
    about to be returned to the user.
    """

    orig_prompt: TokenInputs | MultiModalInputs

    # The tokens include the prompt.
    tokens: list[int]
    logprobs: list[dict[int, Logprob]]
    lora_request: LoRARequest | None = None
    cum_logprob: float = 0.0
    text: str | None = None
    finish_reason: str | None = None
    stop_reason: int | str | None = None

    def get_prompt(self):
        prompt = self.orig_prompt

        prompt_text = prompt.get("prompt")
        cache_salt = prompt.get("cache_salt")

        if prompt["type"] == "token":
            return token_inputs(
                self.tokens,
                prompt=prompt_text,
                cache_salt=cache_salt,
            )

        return mm_inputs(
            prompt_token_ids=self.tokens,
            mm_kwargs=prompt["mm_kwargs"],
            mm_hashes=prompt["mm_hashes"],
            mm_placeholders=prompt["mm_placeholders"],
            prompt=prompt_text,
            cache_salt=cache_salt,
        )


@dataclass
class BeamSearchOutput:
    """The output of beam search.
    It contains the list of the best beam search sequences.
    The length of the list is equal to the beam width.
    """

    sequences: list[BeamSearchSequence]


class BeamSearchInstance:
    def __init__(
        self,
        prompt: TokenInputs | MultiModalInputs,
        lora_request: LoRARequest | None = None,
        logprobs: list[dict[int, Logprob]] | None = None,
        **kwargs,
    ):
        self.beams: list[BeamSearchSequence] = [
            BeamSearchSequence(
                orig_prompt=prompt,
                tokens=prompt["prompt_token_ids"],
                logprobs=[] if logprobs is None else list(logprobs),
                lora_request=lora_request,
                **kwargs,
            )
        ]
        self.completed: list[BeamSearchSequence] = []


def get_beam_search_score(
    tokens: list[int],
    cumulative_logprob: float,
    eos_token_id: int,
    length_penalty: float = 1.0,
) -> float:
    """Calculate the beam search score with length penalty.

    Adapted from

    https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
    """
    seq_len = len(tokens)
    if tokens[-1] == eos_token_id:
        seq_len -= 1

    return cumulative_logprob / (seq_len**length_penalty)


def create_sort_beams_key_function(eos_token_id: int, length_penalty: float):
    def sort_beams_key(x: BeamSearchSequence) -> float:
        return get_beam_search_score(
            x.tokens, x.cum_logprob, eos_token_id, length_penalty
        )

    return sort_beams_key
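Note: the score normalizes cumulative log-probability by seq_len ** length_penalty, with a trailing EOS excluded from the length, so a larger penalty makes longer sequences score higher. A tiny worked example with invented numbers:

# score = cum_logprob / (seq_len ** length_penalty); trailing EOS not counted.
tokens = [5, 17, 42, 2]  # assume eos_token_id == 2, so effective length is 3
cum_logprob = -6.0

score_neutral = -6.0 / (3 ** 1.0)  # -2.0 with length_penalty = 1.0
score_longer = -6.0 / (3 ** 1.5)   # about -1.15; penalties > 1.0 favor longer beams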
vllm/benchmarks/__init__.py  (new file, empty)
vllm/benchmarks/datasets.py  (new file, +3453 lines; diff not shown: file too large)
vllm/benchmarks/latency.py  (new file, +172 lines)
@@ -0,0 +1,172 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Benchmark the latency of processing a single batch of requests."""

import argparse
import dataclasses
import json
import os
import time
from typing import Any

import numpy as np
from tqdm import tqdm

from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType
from vllm.sampling_params import BeamSearchParams


def save_to_pytorch_benchmark_format(
    args: argparse.Namespace, results: dict[str, Any]
) -> None:
    pt_records = convert_to_pytorch_benchmark_format(
        args=args,
        metrics={"latency": results["latencies"]},
        extra_info={k: results[k] for k in ["avg_latency", "percentiles"]},
    )
    if pt_records:
        pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json"
        write_to_json(pt_file, pt_records)


def add_cli_args(parser: argparse.ArgumentParser):
    parser.add_argument("--input-len", type=int, default=32)
    parser.add_argument("--output-len", type=int, default=128)
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument(
        "--n",
        type=int,
        default=1,
        help="Number of generated sequences per prompt.",
    )
    parser.add_argument("--use-beam-search", action="store_true")
    parser.add_argument(
        "--num-iters-warmup",
        type=int,
        default=10,
        help="Number of iterations to run for warmup.",
    )
    parser.add_argument(
        "--num-iters", type=int, default=30, help="Number of iterations to run."
    )
    parser.add_argument(
        "--profile",
        action="store_true",
        help="profile the generation process of a single batch",
    )
    parser.add_argument(
        "--output-json",
        type=str,
        default=None,
        help="Path to save the latency results in JSON format.",
    )
    parser.add_argument(
        "--disable-detokenize",
        action="store_true",
        help=(
            "Do not detokenize responses (i.e. do not include "
            "detokenization time in the latency measurement)"
        ),
    )

    parser = EngineArgs.add_cli_args(parser)
    # V1 enables prefix caching by default which skews the latency
    # numbers. We need to disable prefix caching by default.
    parser.set_defaults(enable_prefix_caching=False)


def main(args: argparse.Namespace):
    engine_args = EngineArgs.from_cli_args(args)

    # Lazy import to avoid importing LLM when the bench command is not selected.
    from vllm import LLM, SamplingParams

    # NOTE(woosuk): If the request cannot be processed in a single batch,
    # the engine will automatically process the request in multiple batches.
    llm = LLM(**dataclasses.asdict(engine_args))
    assert llm.llm_engine.model_config.max_model_len >= (
        args.input_len + args.output_len
    ), (
        "Please ensure that max_model_len is greater than"
        " the sum of input_len and output_len."
    )

    sampling_params = SamplingParams(
        n=args.n,
        temperature=1.0,
        top_p=1.0,
        ignore_eos=True,
        max_tokens=args.output_len,
        detokenize=not args.disable_detokenize,
    )
    dummy_prompt_token_ids = np.random.randint(
        10000, size=(args.batch_size, args.input_len)
    )
    dummy_prompts: list[PromptType] = [
        {"prompt_token_ids": batch} for batch in dummy_prompt_token_ids.tolist()
    ]

    def llm_generate():
        if not args.use_beam_search:
            llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False)
        else:
            llm.beam_search(
                dummy_prompts,
                BeamSearchParams(
                    beam_width=args.n,
                    max_tokens=args.output_len,
                    ignore_eos=True,
                ),
            )

    def run_to_completion(do_profile: bool = False):
        if do_profile:
            llm.start_profile()
            llm_generate()
            llm.stop_profile()
        else:
            start_time = time.perf_counter()
            llm_generate()
            end_time = time.perf_counter()
            latency = end_time - start_time
            return latency

    print("Warming up...")
    for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
        run_to_completion(do_profile=False)

    if args.profile:
        profiler_config = engine_args.profiler_config
        if profiler_config.profiler == "torch":
            print(
                "Profiling with torch profiler (results will be saved to"
                f" {profiler_config.torch_profiler_dir})..."
            )
        elif profiler_config.profiler == "cuda":
            print("Profiling with cuda profiler ...")
        run_to_completion(do_profile=True)
        return

    # Benchmark.
    latencies = []
    for _ in tqdm(range(args.num_iters), desc="Bench iterations"):
        latencies.append(run_to_completion(do_profile=False))
    latencies = np.array(latencies)
    percentages = [10, 25, 50, 75, 90, 99]
    percentiles = np.percentile(latencies, percentages)
    print(f"Avg latency: {np.mean(latencies)} seconds")
    for percentage, percentile in zip(percentages, percentiles):
        print(f"{percentage}% percentile latency: {percentile} seconds")

    # Output JSON results if specified
    if args.output_json:
        results = {
            "avg_latency": np.mean(latencies),
            "latencies": latencies.tolist(),
            "percentiles": dict(zip(percentages, percentiles.tolist())),
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)
        save_to_pytorch_benchmark_format(args, results)
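Note: assuming this module is wired into the `vllm bench` CLI the same way as the mm-processor benchmark below, a typical invocation would look like the sketch here (the model name is a placeholder, flags follow add_cli_args above):

    vllm bench latency \
        --model <your_model> \
        --input-len 32 --output-len 128 --batch-size 8 \
        --num-iters-warmup 10 --num-iters 30 \
        --output-json latency.json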
vllm/benchmarks/mm_processor.py  (new file, +539 lines)
@@ -0,0 +1,539 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
r"""Benchmark multimodal processor latency.
|
||||
|
||||
This benchmark measures the latency of the mm processor module
|
||||
using multimodal prompts from datasets.
|
||||
MM processor stats are automatically enabled.
|
||||
|
||||
Run:
|
||||
vllm bench mm-processor \
|
||||
--model <your_model> \
|
||||
--dataset-name random-mm \
|
||||
--num-prompts 10 \
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import dataclasses
|
||||
import json
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
import numpy as np
|
||||
|
||||
from vllm.benchmarks.datasets import (
|
||||
MultiModalConversationDataset,
|
||||
VisionArenaDataset,
|
||||
)
|
||||
from vllm.benchmarks.throughput import get_requests
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.utils.gc_utils import freeze_gc_heap
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
pd = PlaceholderModule("pandas")
|
||||
|
||||
if TYPE_CHECKING: # Avoid having to mock during docs build
|
||||
from vllm.v1.engine.llm_engine import LLMEngine
|
||||
else:
|
||||
LLMEngine = object
|
||||
|
||||
|
||||
def get_timing_stats_from_engine(llm_engine: LLMEngine) -> dict[str, dict[str, float]]:
|
||||
"""
|
||||
Get all multimodal timing stats from the LLM engine.
|
||||
|
||||
Collects both preprocessing stats (HF processor, hashing, cache lookup,
|
||||
prompt update) and encoder forward pass timing, merged by request_id.
|
||||
|
||||
Args:
|
||||
llm_engine: The LLM engine (has input_processor and workers).
|
||||
|
||||
Returns:
|
||||
Dictionary mapping request_id to merged stats dict containing
|
||||
both preprocessing and encoder timing metrics.
|
||||
|
||||
Example:
|
||||
{
|
||||
'request-123': {
|
||||
'get_mm_hashes_secs': 0.02,
|
||||
'get_cache_missing_items_secs': 0.01,
|
||||
'apply_hf_processor_secs': 0.45,
|
||||
'merge_mm_kwargs_secs': 0.01,
|
||||
'apply_prompt_updates_secs': 0.03,
|
||||
'preprocessor_total_secs': 0.51,
|
||||
'encoder_forward_secs': 0.23,
|
||||
'num_encoder_calls': 1
|
||||
}
|
||||
}
|
||||
"""
|
||||
observability_config = llm_engine.vllm_config.observability_config
|
||||
if not observability_config or not observability_config.enable_mm_processor_stats:
|
||||
return {}
|
||||
|
||||
renderer = llm_engine.renderer
|
||||
mm_processor_stats = renderer._mm_timing_registry.stat()
|
||||
|
||||
encoder_stats = dict[str, dict[str, float]]()
|
||||
for worker_stats in llm_engine.collective_rpc("get_encoder_timing_stats"):
|
||||
if not worker_stats:
|
||||
continue
|
||||
|
||||
for request_id, stats_dict in worker_stats.items():
|
||||
if request_id not in encoder_stats:
|
||||
encoder_stats[request_id] = dict(stats_dict)
|
||||
else:
|
||||
# Aggregate timing metrics across workers
|
||||
current_time = encoder_stats[request_id].get(
|
||||
"encoder_forward_secs", 0.0
|
||||
)
|
||||
new_time = stats_dict.get("encoder_forward_secs", 0.0)
|
||||
encoder_stats[request_id]["encoder_forward_secs"] = max(
|
||||
current_time, new_time
|
||||
)
|
||||
|
||||
current_calls = encoder_stats[request_id].get("num_encoder_calls", 0)
|
||||
new_calls = stats_dict.get("num_encoder_calls", 0)
|
||||
encoder_stats[request_id]["num_encoder_calls"] = max(
|
||||
current_calls, new_calls
|
||||
)
|
||||
|
||||
merged_stats = dict[str, dict[str, float]]()
|
||||
|
||||
for request_id, prep_dict in mm_processor_stats.items():
|
||||
merged_stats[request_id] = dict(prep_dict)
|
||||
|
||||
for request_id, enc_dict in encoder_stats.items():
|
||||
if request_id in merged_stats:
|
||||
merged_stats[request_id].update(enc_dict)
|
||||
continue
|
||||
|
||||
# In V1 engine, the request_id in encoder_stats has a suffix
|
||||
# appended to the original request_id (which is used in
|
||||
# preprocessing_stats).
|
||||
# We try to strip the suffix to find the matching request.
|
||||
possible_original_id = request_id.rpartition("-")[0]
|
||||
if possible_original_id and possible_original_id in merged_stats:
|
||||
merged_stats[possible_original_id].update(enc_dict)
|
||||
else:
|
||||
merged_stats[request_id] = dict(enc_dict)
|
||||
|
||||
return merged_stats
|
||||
|
||||
|
||||
def collect_mm_processor_stats(llm_engine: LLMEngine) -> dict[str, list[float]]:
|
||||
"""
|
||||
Collect multimodal processor timing stats.
|
||||
Returns a dictionary mapping stage names to lists of timing values (in seconds).
|
||||
"""
|
||||
all_stats = get_timing_stats_from_engine(llm_engine)
|
||||
|
||||
stats_by_stage = defaultdict[str, list[float]](list)
|
||||
|
||||
for stats_dict in all_stats.values():
|
||||
for stat_key, stat_val in stats_dict.items():
|
||||
stats_by_stage[stat_key].append(stat_val)
|
||||
|
||||
return stats_by_stage
|
||||
|
||||
|
||||
def calculate_mm_processor_metrics(
|
||||
stats_by_stage: dict[str, list[float]],
|
||||
selected_percentiles: list[float],
|
||||
*,
|
||||
unit: Literal["us", "ms", "s"] = "ms",
|
||||
) -> dict[str, dict[str, float]]:
|
||||
"""
|
||||
Calculate aggregate metrics from stats by stage.
|
||||
"""
|
||||
unit2mult = {"us": 1000000, "ms": 1000, "s": 1}
|
||||
unit_mult = unit2mult[unit]
|
||||
|
||||
metrics = {}
|
||||
|
||||
for stage, times in stats_by_stage.items():
|
||||
stage_name = stage.replace("_secs", "_" + unit)
|
||||
|
||||
if not times:
|
||||
metrics[stage_name] = {
|
||||
"mean": 0.0,
|
||||
"median": 0.0,
|
||||
"std": 0.0,
|
||||
**{f"p{p}": 0.0 for p in selected_percentiles},
|
||||
}
|
||||
continue
|
||||
|
||||
is_count_metric = stage == "num_encoder_calls"
|
||||
values = times if is_count_metric else [t * unit_mult for t in times]
|
||||
|
||||
metrics[stage_name] = {
|
||||
"mean": float(np.mean(values)),
|
||||
"median": float(np.median(values)),
|
||||
"std": float(np.std(values)),
|
||||
**{f"p{p}": float(np.percentile(values, p)) for p in selected_percentiles},
|
||||
}
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def validate_args(args):
|
||||
"""
|
||||
Validate command-line arguments for mm_processor benchmark.
|
||||
"""
|
||||
if not getattr(args, "tokenizer", None):
|
||||
args.tokenizer = args.model
|
||||
if not hasattr(args, "dataset_path"):
|
||||
args.dataset_path = None
|
||||
if not hasattr(args, "lora_path"):
|
||||
args.lora_path = None
|
||||
if not hasattr(args, "max_loras"):
|
||||
args.max_loras = None
|
||||
|
||||
if args.dataset_name == "hf" and not args.dataset_path:
|
||||
raise ValueError(
|
||||
"--dataset-path is required when using --dataset-name hf. "
|
||||
"For multimodal benchmarking, specify a dataset like "
|
||||
"'lmarena-ai/VisionArena-Chat'."
|
||||
)
|
||||
if args.dataset_name == "hf":
|
||||
supported_mm_datasets = (
|
||||
VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
|
||||
| MultiModalConversationDataset.SUPPORTED_DATASET_PATHS
|
||||
)
|
||||
if args.dataset_path not in supported_mm_datasets:
|
||||
raise ValueError(
|
||||
f"{args.dataset_path} is not a supported multimodal dataset. "
|
||||
f"Supported multimodal datasets are: {sorted(supported_mm_datasets)}"
|
||||
)
|
||||
|
||||
|
||||
def benchmark_multimodal_processor(
|
||||
args: argparse.Namespace,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Run the multimodal processor benchmark.
|
||||
"""
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
validate_args(args)
|
||||
|
||||
if args.seed is None:
|
||||
args.seed = 0
|
||||
|
||||
engine_args = EngineArgs.from_cli_args(args)
|
||||
llm = LLM(**dataclasses.asdict(engine_args))
|
||||
|
||||
tokenizer = llm.get_tokenizer()
|
||||
requests = get_requests(args, tokenizer)
|
||||
|
||||
assert all(
|
||||
llm.llm_engine.model_config.max_model_len
|
||||
>= (request.prompt_len + request.expected_output_len)
|
||||
for request in requests
|
||||
), (
|
||||
"Please ensure that max_model_len is greater than the sum of "
|
||||
"prompt_len and expected_output_len for all requests."
|
||||
)
|
||||
|
||||
prompts = [request.prompt for request in requests]
|
||||
expected_output_lens = [request.expected_output_len for request in requests]
|
||||
|
||||
sampling_params = [
|
||||
SamplingParams(
|
||||
n=1,
|
||||
temperature=0.0,
|
||||
max_tokens=output_len,
|
||||
detokenize=True,
|
||||
)
|
||||
for output_len in expected_output_lens
|
||||
]
|
||||
|
||||
selected_percentiles = [
|
||||
float(p) for p in getattr(args, "metric_percentiles", "99").split(",")
|
||||
]
|
||||
|
||||
freeze_gc_heap()
|
||||
|
||||
num_warmups = getattr(args, "num_warmups", 0)
|
||||
if num_warmups > 0:
|
||||
print(f"Processing {num_warmups} warmup requests...")
|
||||
# Create a temporary args object for warmup requests
|
||||
warmup_args = argparse.Namespace(**vars(args))
|
||||
warmup_args.num_prompts = num_warmups
|
||||
warmup_args.seed += 1
|
||||
warmup_requests = get_requests(warmup_args, tokenizer)
|
||||
warmup_prompts = [req.prompt for req in warmup_requests]
|
||||
warmup_output_lens = [req.expected_output_len for req in warmup_requests]
|
||||
warmup_sampling_params = [
|
||||
SamplingParams(max_tokens=output_len) for output_len in warmup_output_lens
|
||||
]
|
||||
llm.chat(
|
||||
warmup_prompts,
|
||||
warmup_sampling_params,
|
||||
use_tqdm=not getattr(args, "disable_tqdm", False),
|
||||
)
|
||||
|
||||
# Clear stats from warmup requests
|
||||
collect_mm_processor_stats(llm.llm_engine)
|
||||
|
||||
print(f"Processing {len(prompts)} requests...")
|
||||
start_time = time.perf_counter()
|
||||
|
||||
outputs = llm.chat(
|
||||
prompts, sampling_params, use_tqdm=not getattr(args, "disable_tqdm", False)
|
||||
)
|
||||
|
||||
end_time = time.perf_counter()
|
||||
total_time = end_time - start_time
|
||||
|
||||
mm_stats_by_stage = collect_mm_processor_stats(llm.llm_engine)
|
||||
|
||||
if not any(mm_stats_by_stage.values()):
|
||||
print(
|
||||
"\n⚠️ Warning: No MM processor stats found in registry.\n"
|
||||
" This may indicate that:\n"
|
||||
" - No multimodal requests were processed\n"
|
||||
" - Stats were already retrieved (registry is cleared after retrieval)\n"
|
||||
)
|
||||
|
||||
mm_processor_metrics = calculate_mm_processor_metrics(
|
||||
mm_stats_by_stage, selected_percentiles
|
||||
)
|
||||
|
||||
completed = len([o for o in outputs if o.finished])
|
||||
failed = len(outputs) - completed
|
||||
|
||||
e2el_times = []
|
||||
for output in outputs:
|
||||
if not output.finished or output.metrics is None:
|
||||
continue
|
||||
metrics = output.metrics
|
||||
# Calculate E2E latency as: TTFT + (last_token_ts - first_token_ts)
|
||||
if (
|
||||
getattr(metrics, "first_token_latency", None) is not None
|
||||
and getattr(metrics, "last_token_ts", None) is not None
|
||||
and getattr(metrics, "first_token_ts", None) is not None
|
||||
):
|
||||
ttft = metrics.first_token_latency
|
||||
# Decode time is the duration between the first and last token generation
|
||||
decode_time = max(0.0, metrics.last_token_ts - metrics.first_token_ts)
|
||||
e2el_times.append((ttft + decode_time) * 1000)
|
||||
|
||||
if not e2el_times and completed > 0:
|
||||
print(
|
||||
"\n⚠️ Warning: Detailed end-to-end latency metrics not available.\n"
|
||||
" Falling back to average request latency "
|
||||
"(total_time / num_completed_requests).\n"
|
||||
)
|
||||
avg_time_per_request = total_time / completed
|
||||
e2el_times = [avg_time_per_request * 1000] * completed
|
||||
|
||||
if e2el_times:
|
||||
mean_e2el_ms = float(np.mean(e2el_times))
|
||||
median_e2el_ms = float(np.median(e2el_times))
|
||||
std_e2el_ms = float(np.std(e2el_times))
|
||||
percentiles_e2el_ms = [
|
||||
(p, float(np.percentile(e2el_times, p))) for p in selected_percentiles
|
||||
]
|
||||
else:
|
||||
mean_e2el_ms = 0.0
|
||||
median_e2el_ms = 0.0
|
||||
std_e2el_ms = 0.0
|
||||
percentiles_e2el_ms = [(p, 0.0) for p in selected_percentiles]
|
||||
|
||||
encoder_summary = {}
|
||||
if (
|
||||
"num_encoder_calls" in mm_stats_by_stage
|
||||
and mm_stats_by_stage["num_encoder_calls"]
|
||||
):
|
||||
encoder_calls = mm_stats_by_stage["num_encoder_calls"]
|
||||
encoder_summary = {
|
||||
"total_encoder_calls": int(sum(encoder_calls)),
|
||||
"num_requests_with_encoder_calls": len(encoder_calls),
|
||||
}
|
||||
|
||||
benchmark_result = {
|
||||
"completed": completed,
|
||||
"failed": failed,
|
||||
"mean_e2el_ms": mean_e2el_ms,
|
||||
"median_e2el_ms": median_e2el_ms,
|
||||
"std_e2el_ms": std_e2el_ms,
|
||||
"percentiles_e2el_ms": percentiles_e2el_ms,
|
||||
"mm_processor_stats": mm_processor_metrics,
|
||||
"encoder_summary": encoder_summary,
|
||||
}
|
||||
|
||||
return benchmark_result
|
||||
|
||||
|
||||
def add_cli_args(parser: argparse.ArgumentParser) -> None:
|
||||
"""Add CLI arguments for the multimodal processor benchmark."""
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
|
||||
EngineArgs.add_cli_args(parser)
|
||||
|
||||
parser.set_defaults(enable_mm_processor_stats=True)
|
||||
|
||||
parser.add_argument(
|
||||
"--dataset-name",
|
||||
type=str,
|
||||
default="random-mm",
|
||||
choices=["random-mm", "hf"],
|
||||
help="Name of the dataset to benchmark on. Defaults to 'random-mm'.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-prompts",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of prompts to process.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-warmups",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of warmup prompts to process.",
|
||||
)
|
||||
|
||||
from vllm.benchmarks.datasets import (
|
||||
add_random_dataset_base_args,
|
||||
add_random_multimodal_dataset_args,
|
||||
)
|
||||
|
||||
add_random_dataset_base_args(parser)
|
||||
add_random_multimodal_dataset_args(parser)
|
||||
|
||||
# HuggingFace dataset arguments
|
||||
parser.add_argument(
|
||||
"--dataset-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the dataset file or HuggingFace dataset name "
|
||||
"(e.g., 'yale-nlp/MMVU', 'lmarena-ai/VisionArena-Chat').",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf-subset",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Subset of the HuggingFace dataset (optional).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf-split",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Split of the HuggingFace dataset (e.g., 'train', 'test', 'validation').",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Output length for each request. "
|
||||
"Overrides the default output lengths from the dataset.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output-json",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to save the benchmark results in JSON format.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--metric-percentiles",
|
||||
type=str,
|
||||
default="99",
|
||||
help="Comma-separated list of percentiles to calculate (e.g., '50,90,99').",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-tqdm",
|
||||
action="store_true",
|
||||
help="Disable tqdm progress bar.",
|
||||
)
|
||||
|
||||
|
||||
def main(args: argparse.Namespace) -> None:
|
||||
"""Main entry point for the multimodal processor benchmark."""
|
||||
|
||||
print("Starting multimodal processor benchmark...")
|
||||
result = benchmark_multimodal_processor(args)
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("Multimodal Processor Benchmark Results")
|
||||
print("=" * 80)
|
||||
|
||||
if "mm_processor_stats" in result:
|
||||
print("\nMM Processor Metrics:")
|
||||
selected_percentiles = [
|
||||
float(p) for p in getattr(args, "metric_percentiles", "99").split(",")
|
||||
]
|
||||
mm_data = []
|
||||
for stage, metrics in result["mm_processor_stats"].items():
|
||||
row = {
|
||||
"Stage": stage,
|
||||
"Mean": f"{metrics['mean']:.2f}",
|
||||
"Median": f"{metrics['median']:.2f}",
|
||||
"Std": f"{metrics['std']:.2f}",
|
||||
}
|
||||
for p in selected_percentiles:
|
||||
row[f"P{p}"] = f"{metrics.get(f'p{p}', 0.0):.2f}"
|
||||
mm_data.append(row)
|
||||
|
||||
mm_df = pd.DataFrame(mm_data)
|
||||
print(mm_df.to_string(index=False))
|
||||
|
||||
if "encoder_summary" in result and result["encoder_summary"]:
|
||||
total_calls = result["encoder_summary"]["total_encoder_calls"]
|
||||
num_requests = result["encoder_summary"]["num_requests_with_encoder_calls"]
|
||||
print(
|
||||
f"\nSummary: {total_calls} total encoder calls "
|
||||
f"across {num_requests} requests."
|
||||
)
|
||||
|
||||
if "mean_e2el_ms" in result:
|
||||
print("\nEnd-to-End Latency (ms):")
|
||||
selected_percentiles = [
|
||||
float(p) for p in getattr(args, "metric_percentiles", "99").split(",")
|
||||
]
|
||||
|
||||
e2el_data = [
|
||||
{"Metric": "Mean", "Value (ms)": f"{result['mean_e2el_ms']:.2f}"},
|
||||
{"Metric": "Median", "Value (ms)": f"{result['median_e2el_ms']:.2f}"},
|
||||
{"Metric": "Std", "Value (ms)": f"{result['std_e2el_ms']:.2f}"},
|
||||
]
|
||||
|
||||
for p in selected_percentiles:
|
||||
percentile_value = next(
|
||||
(val for pct, val in result["percentiles_e2el_ms"] if pct == p),
|
||||
0.0,
|
||||
)
|
||||
e2el_data.append(
|
||||
{
|
||||
"Metric": f"P{p}",
|
||||
"Value (ms)": f"{percentile_value:.2f}",
|
||||
}
|
||||
)
|
||||
|
||||
e2el_df = pd.DataFrame(e2el_data)
|
||||
print(e2el_df.to_string(index=False))
|
||||
|
||||
if args.output_json:
|
||||
result["config"] = {
|
||||
"model": args.model,
|
||||
"num_prompts": args.num_prompts,
|
||||
"input_len": getattr(args, "random_input_len", None),
|
||||
"output_len": getattr(args, "random_output_len", None),
|
||||
}
|
||||
result["timestamp"] = datetime.now().isoformat()
|
||||
|
||||
with open(args.output_json, "w") as f:
|
||||
json.dump(result, f, indent=2)
|
||||
print(f"\nResults saved to {args.output_json}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Benchmark mm processor latency")
|
||||
add_cli_args(parser)
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
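# Hedged usage sketch: inspecting a results file saved via --output-json.
# "mm_results.json" is an assumed example path; the keys below are the ones
# written by the benchmark result dict above.
import json

with open("mm_results.json") as f:
    result = json.load(f)

print(f"completed={result['completed']} failed={result['failed']}")
print(f"mean e2e latency: {result['mean_e2el_ms']:.2f} ms")
if result.get("encoder_summary"):
    # keys: total_encoder_calls, num_requests_with_encoder_calls
    print(result["encoder_summary"])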
1816
vllm/benchmarks/serve.py
Normal file
File diff suppressed because it is too large
321
vllm/benchmarks/startup.py
Normal file
@@ -0,0 +1,321 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Benchmark the cold and warm startup time of vLLM models.
|
||||
|
||||
This script measures total startup time (including model loading, compilation,
|
||||
and cache operations) for both cold and warm scenarios:
|
||||
- Cold startup: Fresh start with no caches (temporary cache directories)
|
||||
- Warm startup: Using cached compilation and model info
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import dataclasses
|
||||
import json
|
||||
import multiprocessing
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from vllm.benchmarks.lib.utils import (
|
||||
convert_to_pytorch_benchmark_format,
|
||||
write_to_json,
|
||||
)
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
|
||||
|
||||
@contextmanager
|
||||
def cold_startup():
|
||||
"""
|
||||
Context manager to measure cold startup time:
|
||||
1. Uses a temporary directory for vLLM cache to avoid any pollution
|
||||
between cold startup iterations.
|
||||
2. Uses inductor's fresh_cache to clear torch.compile caches.
|
||||
"""
|
||||
from torch._inductor.utils import fresh_cache
|
||||
|
||||
# Use temporary directory for caching to avoid any pollution between cold startups
|
||||
original_cache_root = os.environ.get("VLLM_CACHE_ROOT")
|
||||
temp_cache_dir = tempfile.mkdtemp(prefix="vllm_startup_bench_cold_")
|
||||
try:
|
||||
os.environ["VLLM_CACHE_ROOT"] = temp_cache_dir
|
||||
with fresh_cache():
|
||||
yield
|
||||
finally:
|
||||
# Clean up temporary cache directory
|
||||
shutil.rmtree(temp_cache_dir, ignore_errors=True)
|
||||
if original_cache_root:
|
||||
os.environ["VLLM_CACHE_ROOT"] = original_cache_root
|
||||
else:
|
||||
os.environ.pop("VLLM_CACHE_ROOT", None)
|
||||
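# Hedged usage sketch: cold_startup() should leave VLLM_CACHE_ROOT exactly as
# it found it (torch is assumed to be installed so fresh_cache can be imported).
os.environ["VLLM_CACHE_ROOT"] = "/tmp/my_cache"  # assumed pre-existing value
with cold_startup():
    assert "vllm_startup_bench_cold_" in os.environ["VLLM_CACHE_ROOT"]
assert os.environ["VLLM_CACHE_ROOT"] == "/tmp/my_cache"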
|
||||
|
||||
def run_startup_in_subprocess(engine_args, result_queue):
|
||||
"""
|
||||
Run LLM startup in a subprocess and return timing metrics via a queue.
|
||||
This ensures complete isolation between iterations.
|
||||
"""
|
||||
try:
|
||||
# Import inside the subprocess to avoid issues with forking
|
||||
from vllm import LLM
|
||||
|
||||
# Measure total startup time
|
||||
start_time = time.perf_counter()
|
||||
|
||||
llm = LLM(**dataclasses.asdict(engine_args))
|
||||
|
||||
total_startup_time = time.perf_counter() - start_time
|
||||
|
||||
# Extract compilation time if available
|
||||
compilation_time = 0.0
|
||||
if hasattr(llm.llm_engine, "vllm_config"):
|
||||
vllm_config = llm.llm_engine.vllm_config
|
||||
if (
|
||||
hasattr(vllm_config, "compilation_config")
|
||||
and vllm_config.compilation_config is not None
|
||||
):
|
||||
compilation_time = vllm_config.compilation_config.compilation_time
|
||||
|
||||
result_queue.put(
|
||||
{
|
||||
"total_startup_time": total_startup_time,
|
||||
"compilation_time": compilation_time,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
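# Sentinel protocol (see the consumer later in this file): a leading None
# signals failure, and the error message is enqueued right after it.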
result_queue.put(None)
|
||||
result_queue.put(str(e))
|
||||
|
||||
|
||||
def save_to_pytorch_benchmark_format(
|
||||
args: argparse.Namespace, results: dict[str, Any]
|
||||
) -> None:
|
||||
base_name = os.path.splitext(args.output_json)[0]
|
||||
|
||||
cold_startup_records = convert_to_pytorch_benchmark_format(
|
||||
args=args,
|
||||
metrics={
|
||||
"avg_cold_startup_time": [results["avg_cold_startup_time"]],
|
||||
},
|
||||
extra_info={
|
||||
"cold_startup_times": results["cold_startup_times"],
|
||||
"cold_startup_percentiles": results["cold_startup_percentiles"],
|
||||
},
|
||||
)
|
||||
if cold_startup_records:
|
||||
write_to_json(f"{base_name}.cold_startup.pytorch.json", cold_startup_records)
|
||||
|
||||
cold_compilation_records = convert_to_pytorch_benchmark_format(
|
||||
args=args,
|
||||
metrics={
|
||||
"avg_cold_compilation_time": [results["avg_cold_compilation_time"]],
|
||||
},
|
||||
extra_info={
|
||||
"cold_compilation_times": results["cold_compilation_times"],
|
||||
"cold_compilation_percentiles": results["cold_compilation_percentiles"],
|
||||
},
|
||||
)
|
||||
if cold_compilation_records:
|
||||
write_to_json(
|
||||
f"{base_name}.cold_compilation.pytorch.json", cold_compilation_records
|
||||
)
|
||||
|
||||
warm_startup_records = convert_to_pytorch_benchmark_format(
|
||||
args=args,
|
||||
metrics={
|
||||
"avg_warm_startup_time": [results["avg_warm_startup_time"]],
|
||||
},
|
||||
extra_info={
|
||||
"warm_startup_times": results["warm_startup_times"],
|
||||
"warm_startup_percentiles": results["warm_startup_percentiles"],
|
||||
},
|
||||
)
|
||||
if warm_startup_records:
|
||||
write_to_json(f"{base_name}.warm_startup.pytorch.json", warm_startup_records)
|
||||
|
||||
warm_compilation_records = convert_to_pytorch_benchmark_format(
|
||||
args=args,
|
||||
metrics={
|
||||
"avg_warm_compilation_time": [results["avg_warm_compilation_time"]],
|
||||
},
|
||||
extra_info={
|
||||
"warm_compilation_times": results["warm_compilation_times"],
|
||||
"warm_compilation_percentiles": results["warm_compilation_percentiles"],
|
||||
},
|
||||
)
|
||||
if warm_compilation_records:
|
||||
write_to_json(
|
||||
f"{base_name}.warm_compilation.pytorch.json", warm_compilation_records
|
||||
)
|
||||
|
||||
|
||||
def add_cli_args(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--num-iters-cold",
|
||||
type=int,
|
||||
default=3,
|
||||
help="Number of cold startup iterations.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-iters-warmup",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of warmup iterations before benchmarking warm startups.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-iters-warm",
|
||||
type=int,
|
||||
default=3,
|
||||
help="Number of warm startup iterations.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-json",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to save the startup time results in JSON format.",
|
||||
)
|
||||
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
return parser
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
# Set multiprocessing start method to 'spawn' for clean process isolation
|
||||
# This ensures each subprocess starts fresh without inheriting state
|
||||
multiprocessing.set_start_method("spawn", force=True)
|
||||
|
||||
engine_args = EngineArgs.from_cli_args(args)
|
||||
|
||||
def create_llm_and_measure_startup():
|
||||
"""
|
||||
Create LLM instance in a subprocess and measure startup time.
|
||||
Returns timing metrics, using subprocess for complete isolation.
|
||||
"""
|
||||
|
||||
# Create a queue for inter-process communication
|
||||
result_queue = multiprocessing.Queue()
|
||||
process = multiprocessing.Process(
|
||||
target=run_startup_in_subprocess,
|
||||
args=(
|
||||
engine_args,
|
||||
result_queue,
|
||||
),
|
||||
)
|
||||
process.start()
|
||||
process.join()
|
||||
|
||||
if not result_queue.empty():
|
||||
result = result_queue.get()
|
||||
if result is None:
|
||||
if not result_queue.empty():
|
||||
error_msg = result_queue.get()
|
||||
raise RuntimeError(f"Subprocess failed: {error_msg}")
|
||||
else:
|
||||
raise RuntimeError("Subprocess failed with unknown error")
|
||||
return result
|
||||
else:
|
||||
raise RuntimeError("Subprocess did not return a result")
|
||||
|
||||
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
|
||||
print("Setting VLLM_ENABLE_V1_MULTIPROCESSING=0 to collect startup metrics.\n")
|
||||
|
||||
print("Measuring cold startup time...\n")
|
||||
cold_startup_times = []
|
||||
cold_compilation_times = []
|
||||
for i in tqdm(range(args.num_iters_cold), desc="Cold startup iterations"):
|
||||
with cold_startup():
|
||||
metrics = create_llm_and_measure_startup()
|
||||
cold_startup_times.append(metrics["total_startup_time"])
|
||||
cold_compilation_times.append(metrics["compilation_time"])
|
||||
|
||||
# Warmup for warm startup
|
||||
print("\nWarming up for warm startup measurement...\n")
|
||||
for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
|
||||
create_llm_and_measure_startup()
|
||||
|
||||
print("\nMeasuring warm startup time...\n")
|
||||
warm_startup_times = []
|
||||
warm_compilation_times = []
|
||||
for i in tqdm(range(args.num_iters_warm), desc="Warm startup iterations"):
|
||||
metrics = create_llm_and_measure_startup()
|
||||
warm_startup_times.append(metrics["total_startup_time"])
|
||||
warm_compilation_times.append(metrics["compilation_time"])
|
||||
|
||||
# Calculate statistics
|
||||
cold_startup_array = np.array(cold_startup_times)
|
||||
cold_compilation_array = np.array(cold_compilation_times)
|
||||
warm_startup_array = np.array(warm_startup_times)
|
||||
warm_compilation_array = np.array(warm_compilation_times)
|
||||
|
||||
avg_cold_startup = np.mean(cold_startup_array)
|
||||
avg_cold_compilation = np.mean(cold_compilation_array)
|
||||
avg_warm_startup = np.mean(warm_startup_array)
|
||||
avg_warm_compilation = np.mean(warm_compilation_array)
|
||||
|
||||
percentages = [10, 25, 50, 75, 90, 99]
|
||||
cold_startup_percentiles = np.percentile(cold_startup_array, percentages)
|
||||
cold_compilation_percentiles = np.percentile(cold_compilation_array, percentages)
|
||||
warm_startup_percentiles = np.percentile(warm_startup_array, percentages)
|
||||
warm_compilation_percentiles = np.percentile(warm_compilation_array, percentages)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("STARTUP TIME BENCHMARK RESULTS")
|
||||
print("=" * 60)
|
||||
|
||||
# Cold startup statistics
|
||||
print("\nCOLD STARTUP:")
|
||||
print(f"Avg total startup time: {avg_cold_startup:.2f} seconds")
|
||||
print(f"Avg compilation time: {avg_cold_compilation:.2f} seconds")
|
||||
print("Startup time percentiles:")
|
||||
for percentage, percentile in zip(percentages, cold_startup_percentiles):
|
||||
print(f" {percentage}%: {percentile:.2f} seconds")
|
||||
print("Compilation time percentiles:")
|
||||
for percentage, percentile in zip(percentages, cold_compilation_percentiles):
|
||||
print(f" {percentage}%: {percentile:.2f} seconds")
|
||||
|
||||
# Warm startup statistics
|
||||
print("\nWARM STARTUP:")
|
||||
print(f"Avg total startup time: {avg_warm_startup:.2f} seconds")
|
||||
print(f"Avg compilation time: {avg_warm_compilation:.2f} seconds")
|
||||
print("Startup time percentiles:")
|
||||
for percentage, percentile in zip(percentages, warm_startup_percentiles):
|
||||
print(f" {percentage}%: {percentile:.2f} seconds")
|
||||
print("Compilation time percentiles:")
|
||||
for percentage, percentile in zip(percentages, warm_compilation_percentiles):
|
||||
print(f" {percentage}%: {percentile:.2f} seconds")
|
||||
|
||||
print("=" * 60)
|
||||
|
||||
# Output JSON results if specified
|
||||
if args.output_json:
|
||||
results = {
|
||||
"avg_cold_startup_time": float(avg_cold_startup),
|
||||
"avg_cold_compilation_time": float(avg_cold_compilation),
|
||||
"cold_startup_times": cold_startup_times,
|
||||
"cold_compilation_times": cold_compilation_times,
|
||||
"cold_startup_percentiles": dict(
|
||||
zip(percentages, cold_startup_percentiles.tolist())
|
||||
),
|
||||
"cold_compilation_percentiles": dict(
|
||||
zip(percentages, cold_compilation_percentiles.tolist())
|
||||
),
|
||||
"avg_warm_startup_time": float(avg_warm_startup),
|
||||
"avg_warm_compilation_time": float(avg_warm_compilation),
|
||||
"warm_startup_times": warm_startup_times,
|
||||
"warm_compilation_times": warm_compilation_times,
|
||||
"warm_startup_percentiles": dict(
|
||||
zip(percentages, warm_startup_percentiles.tolist())
|
||||
),
|
||||
"warm_compilation_percentiles": dict(
|
||||
zip(percentages, warm_compilation_percentiles.tolist())
|
||||
),
|
||||
}
|
||||
with open(args.output_json, "w") as f:
|
||||
json.dump(results, f, indent=4)
|
||||
save_to_pytorch_benchmark_format(args, results)
|
||||
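# Hedged usage sketch: comparing cold vs. warm startup from a file saved via
# --output-json ("startup_results.json" is an assumed example path).
import json

with open("startup_results.json") as f:
    results = json.load(f)

speedup = results["avg_cold_startup_time"] / results["avg_warm_startup_time"]
print(f"warm startup is {speedup:.2f}x faster than cold startup")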
0
vllm/benchmarks/sweep/__init__.py
Normal file
44
vllm/benchmarks/sweep/cli.py
Normal file
@@ -0,0 +1,44 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
|
||||
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG
|
||||
|
||||
from .plot import SweepPlotArgs
|
||||
from .plot import main as plot_main
|
||||
from .plot_pareto import SweepPlotParetoArgs
|
||||
from .plot_pareto import main as plot_pareto_main
|
||||
from .serve import SweepServeArgs
|
||||
from .serve import main as serve_main
|
||||
from .serve_sla import SweepServeSLAArgs
|
||||
from .serve_sla import main as serve_sla_main
|
||||
from .startup import SweepStartupArgs
|
||||
from .startup import main as startup_main
|
||||
|
||||
SUBCOMMANDS = (
|
||||
(SweepServeArgs, serve_main),
|
||||
(SweepServeSLAArgs, serve_sla_main),
|
||||
(SweepStartupArgs, startup_main),
|
||||
(SweepPlotArgs, plot_main),
|
||||
(SweepPlotParetoArgs, plot_pareto_main),
|
||||
)
|
||||
|
||||
|
||||
def add_cli_args(parser: argparse.ArgumentParser):
|
||||
subparsers = parser.add_subparsers(required=True, dest="sweep_type")
|
||||
|
||||
for cmd, entrypoint in SUBCOMMANDS:
|
||||
cmd_subparser = subparsers.add_parser(
|
||||
cmd.parser_name,
|
||||
description=cmd.parser_help,
|
||||
usage=f"vllm bench sweep {cmd.parser_name} [options]",
|
||||
)
|
||||
cmd_subparser.set_defaults(dispatch_function=entrypoint)
|
||||
cmd.add_cli_args(cmd_subparser)
|
||||
cmd_subparser.epilog = VLLM_SUBCMD_PARSER_EPILOG.format(
|
||||
subcmd=f"sweep {cmd.parser_name}"
|
||||
)
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
args.dispatch_function(args)
|
||||
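# Hedged usage sketch: each subparser stores its entry point in
# `dispatch_function`, so main() only dispatches. Assumes a "results"
# directory containing sweep output already exists.
parser = argparse.ArgumentParser()
add_cli_args(parser)
args = parser.parse_args(["plot", "results", "--dry-run"])
main(args)  # dispatches to the plot entry point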
159
vllm/benchmarks/sweep/param_sweep.py
Normal file
@@ -0,0 +1,159 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import json
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ParameterSweep(list["ParameterSweepItem"]):
|
||||
@classmethod
|
||||
def read_json(cls, filepath: os.PathLike):
|
||||
with open(filepath, "rb") as f:
|
||||
data = json.load(f)
|
||||
|
||||
# Support both list and dict formats
|
||||
if isinstance(data, dict):
|
||||
return cls.read_from_dict(data)
|
||||
|
||||
return cls.from_records(data)
|
||||
|
||||
@classmethod
|
||||
def read_from_dict(cls, data: dict[str, dict[str, object]]):
|
||||
"""
|
||||
Read parameter sweep from a dict format where keys are names.
|
||||
|
||||
Example:
|
||||
{
|
||||
"experiment1": {"max_tokens": 100, "temperature": 0.7},
|
||||
"experiment2": {"max_tokens": 200, "temperature": 0.9}
|
||||
}
|
||||
"""
|
||||
records = [{"_benchmark_name": name, **params} for name, params in data.items()]
|
||||
return cls.from_records(records)
|
||||
|
||||
@classmethod
|
||||
def from_records(cls, records: list[dict[str, object]]):
|
||||
if not isinstance(records, list):
|
||||
raise TypeError(
|
||||
f"The parameter sweep should be a list of dictionaries, "
|
||||
f"but found type: {type(records)}"
|
||||
)
|
||||
|
||||
# Validate that all _benchmark_name values are unique if provided
|
||||
names = [r["_benchmark_name"] for r in records if "_benchmark_name" in r]
|
||||
if names and len(names) != len(set(names)):
|
||||
duplicates = [name for name in names if names.count(name) > 1]
|
||||
raise ValueError(
|
||||
f"Duplicate _benchmark_name values found: {set(duplicates)}. "
|
||||
f"All _benchmark_name values must be unique."
|
||||
)
|
||||
|
||||
return cls(ParameterSweepItem.from_record(record) for record in records)
|
||||
|
||||
|
||||
class ParameterSweepItem(dict[str, object]):
|
||||
@classmethod
|
||||
def from_record(cls, record: dict[str, object]):
|
||||
if not isinstance(record, dict):
|
||||
raise TypeError(
|
||||
f"Each item in the parameter sweep should be a dictionary, "
|
||||
f"but found type: {type(record)}"
|
||||
)
|
||||
|
||||
return cls(record)
|
||||
|
||||
def __or__(self, other: dict[str, Any]):
|
||||
return type(self)(super().__or__(other))
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""
|
||||
Get the name for this parameter sweep item.
|
||||
|
||||
Returns the '_benchmark_name' field if present, otherwise returns a text
|
||||
representation of all parameters.
|
||||
"""
|
||||
if "_benchmark_name" in self:
|
||||
return str(self["_benchmark_name"])
|
||||
|
||||
return self.as_text(sep="-")
|
||||
|
||||
# In JSON, we prefer "_"
|
||||
def _iter_param_key_candidates(self, param_key: str):
|
||||
# Inner config arguments are not converted by the CLI
|
||||
if "." in param_key:
|
||||
prefix, rest = param_key.split(".", 1)
|
||||
for prefix_candidate in self._iter_param_key_candidates(prefix):
|
||||
yield prefix_candidate + "." + rest
|
||||
|
||||
return
|
||||
|
||||
yield param_key
|
||||
yield param_key.replace("-", "_")
|
||||
yield param_key.replace("_", "-")
|
||||
|
||||
# In CLI, we prefer "-"
|
||||
def _iter_cmd_key_candidates(self, param_key: str):
|
||||
for k in reversed(tuple(self._iter_param_key_candidates(param_key))):
|
||||
yield "--" + k
|
||||
|
||||
def _normalize_cmd_key(self, param_key: str):
|
||||
return next(self._iter_cmd_key_candidates(param_key))
|
||||
|
||||
def has_param(self, param_key: str) -> bool:
|
||||
return any(k in self for k in self._iter_param_key_candidates(param_key))
|
||||
|
||||
def _normalize_cmd_kv_pair(self, k: str, v: object) -> list[str]:
|
||||
"""
|
||||
Normalize a key-value pair into command-line arguments.
|
||||
|
||||
Returns a list containing either:
|
||||
- A single element for boolean flags (e.g., ['--flag'] or ['--flag=true'])
|
||||
- Two elements for key-value pairs (e.g., ['--key', 'value'])
|
||||
"""
|
||||
if isinstance(v, bool):
|
||||
# For nested params (containing "."), use =true/false syntax
|
||||
if "." in k:
|
||||
return [f"{self._normalize_cmd_key(k)}={'true' if v else 'false'}"]
|
||||
else:
|
||||
return [self._normalize_cmd_key(k if v else "no-" + k)]
|
||||
else:
|
||||
return [self._normalize_cmd_key(k), str(v)]
|
||||
|
||||
def apply_to_cmd(self, cmd: list[str]) -> list[str]:
|
||||
cmd = list(cmd)
|
||||
|
||||
for k, v in self.items():
|
||||
# Skip the '_benchmark_name' field, not a parameter
|
||||
if k == "_benchmark_name":
|
||||
continue
|
||||
|
||||
# Serialize dict values as JSON
|
||||
if isinstance(v, dict):
|
||||
v = json.dumps(v)
|
||||
|
||||
for k_candidate in self._iter_cmd_key_candidates(k):
|
||||
try:
|
||||
k_idx = cmd.index(k_candidate)
|
||||
|
||||
# Replace existing parameter
|
||||
normalized = self._normalize_cmd_kv_pair(k, v)
|
||||
if len(normalized) == 1:
|
||||
# Boolean flag
|
||||
cmd[k_idx] = normalized[0]
|
||||
else:
|
||||
# Key-value pair
|
||||
cmd[k_idx] = normalized[0]
|
||||
cmd[k_idx + 1] = normalized[1]
|
||||
|
||||
break
|
||||
except ValueError:
|
||||
continue
|
||||
else:
|
||||
# Add new parameter
|
||||
cmd.extend(self._normalize_cmd_kv_pair(k, v))
|
||||
|
||||
return cmd
|
||||
|
||||
def as_text(self, sep: str = ", ") -> str:
|
||||
return sep.join(f"{k}={v}" for k, v in self.items() if k != "_benchmark_name")
|
||||
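# Hedged usage sketch: an existing flag is replaced in place, a boolean False
# becomes a --no-* flag, and unseen flags are appended ("my-model" is an
# assumed model name).
item = ParameterSweepItem.from_record(
    {"_benchmark_name": "exp1", "max-num-seqs": 128, "enable-prefix-caching": False}
)
cmd = ["vllm", "serve", "my-model", "--max-num-seqs", "64"]
print(item.apply_to_cmd(cmd))
# ['vllm', 'serve', 'my-model', '--max-num-seqs', '128', '--no-enable-prefix-caching']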
683
vllm/benchmarks/sweep/plot.py
Normal file
@@ -0,0 +1,683 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from types import TracebackType
|
||||
from typing import ClassVar
|
||||
|
||||
from typing_extensions import Self, override
|
||||
|
||||
from vllm.utils.collection_utils import full_groupby
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
from .utils import sanitize_filename
|
||||
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
except ImportError:
|
||||
plt = PlaceholderModule("matplotlib").placeholder_attr("pyplot")
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
pd = PlaceholderModule("pandas")
|
||||
|
||||
try:
|
||||
import seaborn as sns
|
||||
except ImportError:
|
||||
sns = PlaceholderModule("seaborn")
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlotFilterBase(ABC):
|
||||
var: str
|
||||
target: str
|
||||
|
||||
@classmethod
|
||||
def parse_str(cls, s: str):
|
||||
for op_key in PLOT_FILTERS:
|
||||
if op_key in s:
|
||||
key, value = s.split(op_key)
|
||||
return PLOT_FILTERS[op_key](
|
||||
key,
|
||||
value.removeprefix(op_key).strip("'").strip('"'),
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid operator for plot filter '{s}'. "
|
||||
f"Valid operators are: {sorted(PLOT_FILTERS)}",
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
|
||||
"""Applies this filter to a DataFrame."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlotEqualTo(PlotFilterBase):
|
||||
@override
|
||||
def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
|
||||
try:
|
||||
target = float(self.target)
|
||||
except ValueError:
|
||||
target = self.target
|
||||
|
||||
return df[df[self.var] == target]
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlotNotEqualTo(PlotFilterBase):
|
||||
@override
|
||||
def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
|
||||
try:
|
||||
target = float(self.target)
|
||||
except ValueError:
|
||||
target = self.target
|
||||
|
||||
return df[df[self.var] != target]
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlotLessThan(PlotFilterBase):
|
||||
@override
|
||||
def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
|
||||
return df[df[self.var] < float(self.target)]
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlotLessThanOrEqualTo(PlotFilterBase):
|
||||
@override
|
||||
def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
|
||||
return df[df[self.var] <= float(self.target)]
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlotGreaterThan(PlotFilterBase):
|
||||
@override
|
||||
def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
|
||||
return df[df[self.var] > float(self.target)]
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlotGreaterThanOrEqualTo(PlotFilterBase):
|
||||
@override
|
||||
def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
|
||||
return df[df[self.var] >= float(self.target)]
|
||||
|
||||
|
||||
# NOTE: The ordering is important! Match longer op_keys first
|
||||
PLOT_FILTERS: dict[str, type[PlotFilterBase]] = {
|
||||
"==": PlotEqualTo,
|
||||
"!=": PlotNotEqualTo,
|
||||
"<=": PlotLessThanOrEqualTo,
|
||||
">=": PlotGreaterThanOrEqualTo,
|
||||
"<": PlotLessThan,
|
||||
">": PlotGreaterThan,
|
||||
}
|
||||
|
||||
|
||||
class PlotFilters(list[PlotFilterBase]):
|
||||
@classmethod
|
||||
def parse_str(cls, s: str):
|
||||
if not s:
|
||||
return cls()
|
||||
|
||||
return cls(PlotFilterBase.parse_str(e) for e in s.split(","))
|
||||
|
||||
def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
|
||||
for item in self:
|
||||
df = item.apply(df)
|
||||
|
||||
return df
|
||||
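# Hedged usage sketch of the filter syntax ("<var><op><value>", comma-separated);
# pandas is assumed to be installed.
filters = PlotFilters.parse_str("max_concurrency<=64,backend==vllm")
df = pd.DataFrame(
    {"max_concurrency": [32, 128], "backend": ["vllm", "vllm"], "p99_ttft_ms": [10.0, 40.0]}
)
print(filters.apply(df))  # keeps only the max_concurrency == 32 row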
|
||||
|
||||
@dataclass
|
||||
class PlotBinner:
|
||||
var: str
|
||||
bin_size: float
|
||||
|
||||
@classmethod
|
||||
def parse_str(cls, s: str):
|
||||
for op_key in PLOT_BINNERS:
|
||||
if op_key in s:
|
||||
key, value = s.split(op_key)
|
||||
return PLOT_BINNERS[op_key](key, float(value.removeprefix(op_key)))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid operator for plot binner '{s}'. "
|
||||
f"Valid operators are: {sorted(PLOT_BINNERS)}",
|
||||
)
|
||||
|
||||
def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
|
||||
"""Applies this binner to a DataFrame."""
|
||||
df = df.copy()
|
||||
df[self.var] = df[self.var] // self.bin_size * self.bin_size
|
||||
return df
|
||||
|
||||
|
||||
PLOT_BINNERS: dict[str, type[PlotBinner]] = {
|
||||
"%": PlotBinner,
|
||||
}
|
||||
|
||||
|
||||
class PlotBinners(list[PlotBinner]):
|
||||
@classmethod
|
||||
def parse_str(cls, s: str):
|
||||
if not s:
|
||||
return cls()
|
||||
|
||||
return cls(PlotBinner.parse_str(e) for e in s.split(","))
|
||||
|
||||
def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
|
||||
for item in self:
|
||||
df = item.apply(df)
|
||||
|
||||
return df
|
||||
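# Hedged usage sketch of the binner syntax "<var>%<bin_size>" (the CLI help
# shows "%%" only because argparse help strings escape percent signs);
# pandas is assumed to be installed.
binners = PlotBinners.parse_str("request_throughput%5")
df = pd.DataFrame({"request_throughput": [3.2, 7.9, 11.0]})
print(binners.apply(df)["request_throughput"].tolist())  # [0.0, 5.0, 10.0]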
|
||||
|
||||
def _json_load_bytes(path: Path) -> list[dict[str, object]]:
|
||||
with path.open("rb") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def _convert_inf_nan_strings(data: list[dict[str, object]]) -> list[dict[str, object]]:
|
||||
"""
|
||||
Convert string values "inf", "-inf", and "nan" to their float equivalents.
|
||||
|
||||
This handles the case where JSON serialization represents inf/nan as strings.
|
||||
"""
|
||||
converted_data = []
|
||||
for record in data:
|
||||
converted_record = {}
|
||||
for key, value in record.items():
|
||||
if isinstance(value, str):
|
||||
if value in ["inf", "-inf", "nan"]:
|
||||
converted_record[key] = float(value)
|
||||
else:
|
||||
converted_record[key] = value
|
||||
else:
|
||||
converted_record[key] = value
|
||||
converted_data.append(converted_record)
|
||||
return converted_data
|
||||
|
||||
|
||||
def _get_metric(run_data: dict[str, object], metric_key: str):
|
||||
try:
|
||||
return run_data[metric_key]
|
||||
except KeyError as exc:
|
||||
raise ValueError(f"Cannot find metric {metric_key!r} in {run_data=}") from exc
|
||||
|
||||
|
||||
def _get_group(run_data: dict[str, object], group_keys: list[str]):
|
||||
return tuple((k, str(_get_metric(run_data, k))) for k in group_keys)
|
||||
|
||||
|
||||
def _get_fig_path(fig_dir: Path, group: tuple[tuple[str, str], ...], fig_name: str):
|
||||
parts = list[str]()
|
||||
|
||||
# Start with figure name (always provided, defaults to "FIGURE")
|
||||
parts.append(fig_name)
|
||||
|
||||
# Always append group data if present
|
||||
if group:
|
||||
parts.extend(f"{k}={v}" for k, v in group)
|
||||
|
||||
return fig_dir / sanitize_filename("-".join(parts) + ".png")
|
||||
|
||||
|
||||
class DummyExecutor:
|
||||
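# Runs the work inline in the current process: `map` is simply the builtin map().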
map = map
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
exc_traceback: TracebackType | None,
|
||||
) -> None:
|
||||
return None
|
||||
|
||||
|
||||
def _plot_fig(
|
||||
fig_dir: Path,
|
||||
fig_group_data: tuple[tuple[tuple[str, str], ...], list[dict[str, object]]],
|
||||
row_by: list[str],
|
||||
col_by: list[str],
|
||||
curve_by: list[str],
|
||||
*,
|
||||
var_x: str,
|
||||
var_y: str,
|
||||
filter_by: PlotFilters,
|
||||
bin_by: PlotBinners,
|
||||
scale_x: str | None,
|
||||
scale_y: str | None,
|
||||
dry_run: bool,
|
||||
fig_name: str,
|
||||
error_bars: bool,
|
||||
fig_height: float,
|
||||
fig_dpi: int,
|
||||
):
|
||||
fig_group, fig_data = fig_group_data
|
||||
|
||||
row_groups = full_groupby(
|
||||
fig_data,
|
||||
key=lambda item: _get_group(item, row_by),
|
||||
)
|
||||
num_rows = len(row_groups)
|
||||
num_cols = max(
|
||||
len(full_groupby(row_data, key=lambda item: _get_group(item, col_by)))
|
||||
for _, row_data in row_groups
|
||||
)
|
||||
|
||||
fig_path = _get_fig_path(fig_dir, fig_group, fig_name)
|
||||
|
||||
print("[BEGIN FIGURE]")
|
||||
print(f"Group: {dict(fig_group)}")
|
||||
print(f"Grid: {num_rows} rows x {num_cols} cols")
|
||||
print(f"Output file: {fig_path}")
|
||||
|
||||
if dry_run:
|
||||
print("[END FIGURE]")
|
||||
return
|
||||
|
||||
# Convert string "inf", "-inf", and "nan" to their float equivalents
|
||||
fig_data = _convert_inf_nan_strings(fig_data)
|
||||
df = pd.DataFrame.from_records(fig_data)
|
||||
|
||||
if var_x not in df.columns:
|
||||
raise ValueError(
|
||||
f"Cannot find {var_x=!r} in parameter sweep results. "
|
||||
f"Available variables: {df.columns.tolist()}"
|
||||
)
|
||||
if var_y not in df.columns:
|
||||
raise ValueError(
|
||||
f"Cannot find {var_y=!r} in parameter sweep results. "
|
||||
f"Available variables: {df.columns.tolist()}"
|
||||
)
|
||||
for k in row_by:
|
||||
if k not in df.columns:
|
||||
raise ValueError(
|
||||
f"Cannot find row_by={k!r} in parameter sweep results. "
|
||||
f"Available variables: {df.columns.tolist()}"
|
||||
)
|
||||
for k in col_by:
|
||||
if k not in df.columns:
|
||||
raise ValueError(
|
||||
f"Cannot find col_by={k!r} in parameter sweep results. "
|
||||
f"Available variables: {df.columns.tolist()}"
|
||||
)
|
||||
for k in curve_by:
|
||||
if k not in df.columns:
|
||||
raise ValueError(
|
||||
f"Cannot find curve_by={k!r} in parameter sweep results. "
|
||||
f"Available variables: {df.columns.tolist()}"
|
||||
)
|
||||
|
||||
df = filter_by.apply(df)
|
||||
df = bin_by.apply(df)
|
||||
|
||||
# Sort by curve_by columns alphabetically for consistent legend ordering
|
||||
if curve_by:
|
||||
df = df.sort_values(by=curve_by)
|
||||
|
||||
df["row_group"] = (
|
||||
pd.concat(
|
||||
[k + "=" + df[k].astype(str) for k in row_by],
|
||||
axis=1,
|
||||
).agg("\n".join, axis=1)
|
||||
if row_by
|
||||
else "(All)"
|
||||
)
|
||||
|
||||
df["col_group"] = (
|
||||
pd.concat(
|
||||
[k + "=" + df[k].astype(str) for k in col_by],
|
||||
axis=1,
|
||||
).agg("\n".join, axis=1)
|
||||
if col_by
|
||||
else "(All)"
|
||||
)
|
||||
|
||||
if len(curve_by) <= 3:
|
||||
hue, style, size, *_ = (*curve_by, None, None, None)
|
||||
|
||||
g = sns.relplot(
|
||||
df,
|
||||
x=var_x,
|
||||
y=var_y,
|
||||
hue=hue,
|
||||
style=style,
|
||||
size=size,
|
||||
markers=True,
|
||||
errorbar="sd" if error_bars else None,
|
||||
kind="line",
|
||||
row="row_group",
|
||||
col="col_group",
|
||||
height=fig_height,
|
||||
)
|
||||
else:
|
||||
df["curve_group"] = (
|
||||
pd.concat(
|
||||
[k + "=" + df[k].astype(str) for k in curve_by],
|
||||
axis=1,
|
||||
).agg("\n".join, axis=1)
|
||||
if curve_by
|
||||
else "(All)"
|
||||
)
|
||||
|
||||
g = sns.relplot(
|
||||
df,
|
||||
x=var_x,
|
||||
y=var_y,
|
||||
hue="curve_group",
|
||||
markers=True,
|
||||
errorbar="sd" if error_bars else None,
|
||||
kind="line",
|
||||
row="row_group",
|
||||
col="col_group",
|
||||
height=fig_height,
|
||||
)
|
||||
|
||||
if row_by and col_by:
|
||||
g.set_titles("{row_name}\n{col_name}")
|
||||
elif row_by:
|
||||
g.set_titles("{row_name}")
|
||||
elif col_by:
|
||||
g.set_titles("{col_name}")
|
||||
else:
|
||||
g.set_titles("")
|
||||
|
||||
if scale_x:
|
||||
g.set(xscale=scale_x)
|
||||
if scale_y:
|
||||
g.set(yscale=scale_y)
|
||||
|
||||
g.savefig(fig_path, dpi=fig_dpi)
|
||||
plt.close(g.figure)
|
||||
|
||||
print("[END FIGURE]")
|
||||
|
||||
|
||||
def plot(
|
||||
output_dir: Path,
|
||||
fig_dir: Path,
|
||||
fig_by: list[str],
|
||||
row_by: list[str],
|
||||
col_by: list[str],
|
||||
curve_by: list[str],
|
||||
*,
|
||||
var_x: str,
|
||||
var_y: str,
|
||||
filter_by: PlotFilters,
|
||||
bin_by: PlotBinners,
|
||||
scale_x: str | None,
|
||||
scale_y: str | None,
|
||||
dry_run: bool,
|
||||
fig_name: str = "FIGURE",
|
||||
error_bars: bool = True,
|
||||
fig_height: float = 6.4,
|
||||
fig_dpi: int = 300,
|
||||
):
|
||||
all_data = [
|
||||
run_data
|
||||
for path in output_dir.rglob("**/summary.json")
|
||||
for run_data in _json_load_bytes(path)
|
||||
]
|
||||
|
||||
if not all_data:
|
||||
raise ValueError(f"Did not find any parameter sweep results under {output_dir}")
|
||||
|
||||
fig_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
fig_groups = full_groupby(
|
||||
all_data,
|
||||
key=lambda item: _get_group(item, fig_by),
|
||||
)
|
||||
|
||||
with DummyExecutor() if len(fig_groups) <= 1 else ProcessPoolExecutor() as executor:
|
||||
# Resolve the iterable to ensure that the workers are run
|
||||
all(
|
||||
executor.map(
|
||||
partial(
|
||||
_plot_fig,
|
||||
fig_dir,
|
||||
row_by=row_by,
|
||||
col_by=col_by,
|
||||
curve_by=curve_by,
|
||||
var_x=var_x,
|
||||
var_y=var_y,
|
||||
filter_by=filter_by,
|
||||
bin_by=bin_by,
|
||||
scale_x=scale_x,
|
||||
scale_y=scale_y,
|
||||
dry_run=dry_run,
|
||||
fig_name=fig_name,
|
||||
error_bars=error_bars,
|
||||
fig_height=fig_height,
|
||||
fig_dpi=fig_dpi,
|
||||
),
|
||||
fig_groups,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SweepPlotArgs:
|
||||
output_dir: Path
|
||||
fig_dir: Path
|
||||
fig_by: list[str]
|
||||
row_by: list[str]
|
||||
col_by: list[str]
|
||||
curve_by: list[str]
|
||||
var_x: str
|
||||
var_y: str
|
||||
filter_by: PlotFilters
|
||||
bin_by: PlotBinners
|
||||
scale_x: str | None
|
||||
scale_y: str | None
|
||||
dry_run: bool
|
||||
fig_name: str = "FIGURE"
|
||||
error_bars: bool = True
|
||||
fig_height: float = 6.4
|
||||
fig_dpi: int = 300
|
||||
|
||||
parser_name: ClassVar[str] = "plot"
|
||||
parser_help: ClassVar[str] = "Plot performance curves from parameter sweep results."
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
output_dir = Path(args.OUTPUT_DIR)
|
||||
if not output_dir.exists():
|
||||
raise ValueError(f"No parameter sweep results under {output_dir}")
|
||||
|
||||
curve_by = [] if not args.curve_by else args.curve_by.split(",")
|
||||
row_by = [] if not args.row_by else args.row_by.split(",")
|
||||
col_by = [] if not args.col_by else args.col_by.split(",")
|
||||
fig_by = [] if not args.fig_by else args.fig_by.split(",")
|
||||
|
||||
return cls(
|
||||
output_dir=output_dir,
|
||||
fig_dir=output_dir / args.fig_dir,
|
||||
fig_by=fig_by,
|
||||
row_by=row_by,
|
||||
col_by=col_by,
|
||||
curve_by=curve_by,
|
||||
var_x=args.var_x,
|
||||
var_y=args.var_y,
|
||||
filter_by=PlotFilters.parse_str(args.filter_by),
|
||||
bin_by=PlotBinners.parse_str(args.bin_by),
|
||||
scale_x=args.scale_x,
|
||||
scale_y=args.scale_y,
|
||||
dry_run=args.dry_run,
|
||||
fig_name=args.fig_name,
|
||||
error_bars=not args.no_error_bars,
|
||||
fig_height=args.fig_height,
|
||||
fig_dpi=args.fig_dpi,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
|
||||
parser.add_argument(
|
||||
"OUTPUT_DIR",
|
||||
type=str,
|
||||
default="results",
|
||||
help="The directory containing the results to plot, "
|
||||
"i.e., the `--output-dir` argument to the parameter sweep script.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fig-dir",
|
||||
type=str,
|
||||
default="",
|
||||
help="The directory to save the figures, relative to `OUTPUT_DIR`. "
|
||||
"By default, the same directory is used.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fig-by",
|
||||
type=str,
|
||||
default="",
|
||||
help="A comma-separated list of variables, such that a separate figure "
|
||||
"is created for each combination of these variables.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--row-by",
|
||||
type=str,
|
||||
default="",
|
||||
help="A comma-separated list of variables, such that a separate row "
|
||||
"is created for each combination of these variables.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--col-by",
|
||||
type=str,
|
||||
default="",
|
||||
help="A comma-separated list of variables, such that a separate column "
|
||||
"is created for each combination of these variables.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--curve-by",
|
||||
type=str,
|
||||
default=None,
|
||||
help="A comma-separated list of variables, such that a separate curve "
|
||||
"is created for each combination of these variables.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--var-x",
|
||||
type=str,
|
||||
default="request_throughput",
|
||||
help="The variable for the x-axis.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--var-y",
|
||||
type=str,
|
||||
default="p99_ttft_ms",
|
||||
help="The variable for the y-axis",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--filter-by",
|
||||
type=str,
|
||||
default="",
|
||||
help="A comma-separated list of statements indicating values to filter by. "
|
||||
"This is useful to remove outliers. "
|
||||
"Example: `max_concurrency<1000,max_num_batched_tokens<=4096` means "
|
||||
"plot only the points where `max_concurrency` is less than 1000 and "
|
||||
"`max_num_batched_tokens` is no greater than 4096.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bin-by",
|
||||
type=str,
|
||||
default="",
|
||||
help="A comma-separated list of statements indicating values to bin by. "
|
||||
"This is useful to avoid plotting points that are too close together. "
|
||||
"Example: `request_throughput%%1` means "
|
||||
"use a bin size of 1 for the `request_throughput` variable.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scale-x",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The scale to use for the x-axis. "
|
||||
"Currently only accepts string values such as 'log' and 'sqrt'. "
|
||||
"See also: https://seaborn.pydata.org/generated/seaborn.objects.Plot.scale.html",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scale-y",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The scale to use for the y-axis. "
|
||||
"Currently only accepts string values such as 'log' and 'sqrt'. "
|
||||
"See also: https://seaborn.pydata.org/generated/seaborn.objects.Plot.scale.html",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fig-name",
|
||||
type=str,
|
||||
default="FIGURE",
|
||||
help="Name prefix for the output figure file. "
|
||||
"Group data is always appended when present. "
|
||||
"Default: 'FIGURE'. Example: --fig-name my_performance_plot",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-error-bars",
|
||||
action="store_true",
|
||||
help="If set, disables error bars on the plot. "
|
||||
"By default, error bars are shown.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fig-height",
|
||||
type=float,
|
||||
default=6.4,
|
||||
help="Height of each subplot in inches. Default: 6.4",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fig-dpi",
|
||||
type=int,
|
||||
default=300,
|
||||
help="Resolution of the output figure in dots per inch. Default: 300",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dry-run",
|
||||
action="store_true",
|
||||
help="If set, prints the information about each figure to plot, "
|
||||
"then exits without drawing them.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def run_main(args: SweepPlotArgs):
|
||||
return plot(
|
||||
output_dir=args.output_dir,
|
||||
fig_dir=args.fig_dir,
|
||||
fig_by=args.fig_by,
|
||||
row_by=args.row_by,
|
||||
col_by=args.col_by,
|
||||
curve_by=args.curve_by,
|
||||
var_x=args.var_x,
|
||||
var_y=args.var_y,
|
||||
filter_by=args.filter_by,
|
||||
bin_by=args.bin_by,
|
||||
scale_x=args.scale_x,
|
||||
scale_y=args.scale_y,
|
||||
dry_run=args.dry_run,
|
||||
fig_name=args.fig_name,
|
||||
error_bars=args.error_bars,
|
||||
fig_height=args.fig_height,
|
||||
fig_dpi=args.fig_dpi,
|
||||
)
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
run_main(SweepPlotArgs.from_cli_args(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description=SweepPlotArgs.parser_help)
|
||||
SweepPlotArgs.add_cli_args(parser)
|
||||
|
||||
main(parser.parse_args())
|
||||
399
vllm/benchmarks/sweep/plot_pareto.py
Normal file
@@ -0,0 +1,399 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
import math
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import ClassVar
|
||||
|
||||
from vllm.utils.collection_utils import full_groupby
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
from .plot import DummyExecutor, _json_load_bytes
|
||||
from .utils import sanitize_filename
|
||||
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
except ImportError:
|
||||
plt = PlaceholderModule("matplotlib").placeholder_attr("pyplot")
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
pd = PlaceholderModule("pandas")
|
||||
|
||||
try:
|
||||
import seaborn as sns
|
||||
except ImportError:
|
||||
sns = PlaceholderModule("seaborn")
|
||||
|
||||
|
||||
def _first_present(run_data: dict[str, object], keys: list[str]):
|
||||
for key in keys:
|
||||
for candidate in {key, key.replace("_", "-"), key.replace("-", "_")}:
|
||||
if candidate in run_data:
|
||||
return run_data[candidate]
|
||||
return None
|
||||
|
||||
|
||||
def _get_numeric(
|
||||
run_data: dict[str, object],
|
||||
keys: list[str],
|
||||
*,
|
||||
allow_zero: bool = True,
|
||||
) -> float | None:
|
||||
value = _first_present(run_data, keys)
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
numeric = float(value)
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise ValueError(
|
||||
f"Expected numeric value for one of {keys}, "
|
||||
f"but found {value!r} in {run_data=}"
|
||||
) from exc
|
||||
|
||||
if not allow_zero and numeric == 0:
|
||||
return None
|
||||
|
||||
return numeric
|
||||
|
||||
|
||||
def _infer_user_count(
|
||||
run_data: dict[str, object],
|
||||
user_count_var: str | None,
|
||||
) -> float | None:
|
||||
candidates = [user_count_var] if user_count_var else []
|
||||
candidates.extend(["request_rate"])
|
||||
user_count = _get_numeric(run_data, candidates, allow_zero=False)
|
||||
if user_count is not None:
|
||||
return user_count
|
||||
|
||||
# Fallback to the observed peak if configured value is missing.
|
||||
return _get_numeric(run_data, ["max_concurrent_requests"], allow_zero=False)
|
||||
|
||||
|
||||
def _infer_gpu_count(
|
||||
run_data: dict[str, object],
|
||||
gpu_count_var: str | None,
|
||||
) -> float:
|
||||
direct_candidates = [gpu_count_var] if gpu_count_var else []
|
||||
direct_gpu_count = _get_numeric(run_data, direct_candidates, allow_zero=False)
|
||||
if direct_gpu_count:
|
||||
return direct_gpu_count
|
||||
|
||||
tp_size = _get_numeric(run_data, ["tensor_parallel_size", "tp"])
|
||||
pp_size = _get_numeric(run_data, ["pipeline_parallel_size", "pp"])
|
||||
dp_size = _get_numeric(run_data, ["data_parallel_size", "dp"])
|
||||
world_size = 1.0
|
||||
if tp_size:
|
||||
world_size *= tp_size
|
||||
if pp_size:
|
||||
world_size *= pp_size
|
||||
if dp_size:
|
||||
world_size *= dp_size
|
||||
|
||||
return world_size
|
||||
|
||||
|
||||
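# Hedged usage sketch: the GPU count falls back to the product of the
# parallelism settings when no explicit gpu-count key is present.
run = {"tensor_parallel_size": 2, "pipeline_parallel_size": 2}
print(_infer_gpu_count(run, gpu_count_var=None))  # 4.0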
def _get_throughput(
|
||||
run_data: dict[str, object],
|
||||
throughput_var: str,
|
||||
) -> float:
|
||||
throughput = _get_numeric(run_data, [throughput_var])
|
||||
if throughput is None:
|
||||
raise ValueError(
|
||||
f"Cannot find throughput metric {throughput_var!r} in run data. "
|
||||
f"Available keys: {sorted(run_data)}"
|
||||
)
|
||||
|
||||
return throughput
|
||||
|
||||
|
||||
def _prepare_records(
|
||||
all_data: list[dict[str, object]],
|
||||
*,
|
||||
user_count_var: str | None,
|
||||
gpu_count_var: str | None,
|
||||
) -> tuple[list[dict[str, object]], int]:
|
||||
prepared = []
|
||||
skipped_missing_users = 0
|
||||
|
||||
for record in all_data:
|
||||
throughput = _get_throughput(record, "output_throughput")
|
||||
user_count = _infer_user_count(record, user_count_var)
|
||||
if user_count is None:
|
||||
skipped_missing_users += 1
|
||||
continue
|
||||
|
||||
gpu_count = _infer_gpu_count(record, gpu_count_var)
|
||||
tokens_per_user = throughput / user_count
|
||||
tokens_per_gpu = throughput / gpu_count
|
||||
|
||||
prepared.append(
|
||||
{
|
||||
**record,
|
||||
"tokens_per_user": tokens_per_user,
|
||||
"tokens_per_gpu": tokens_per_gpu,
|
||||
"user_count_estimate": user_count,
|
||||
"gpu_count": gpu_count,
|
||||
}
|
||||
)
|
||||
|
||||
return prepared, skipped_missing_users
|
||||
|
||||
|
||||
def _pareto_frontier(
|
||||
df: "pd.DataFrame",
|
||||
x_col: str,
|
||||
y_col: str,
|
||||
*,
|
||||
epsilon: float = 1e-9,
|
||||
) -> "pd.DataFrame":
|
||||
sorted_df = df.sort_values([x_col, y_col], ascending=[False, False])
|
||||
frontier_indices = []
|
||||
best_y = -math.inf
|
||||
|
||||
for idx, row in sorted_df.iterrows():
|
||||
y_val = row[y_col]
|
||||
if y_val >= best_y - epsilon:
|
||||
frontier_indices.append(idx)
|
||||
best_y = max(best_y, y_val)
|
||||
|
||||
return df.loc[frontier_indices]
|
||||
|
||||
|
||||
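# Hedged usage sketch: points are scanned in descending tokens_per_user order
# and kept while their tokens_per_gpu ties or improves the best seen so far,
# so dominated points drop out (pandas is assumed to be installed).
df = pd.DataFrame(
    {"tokens_per_user": [10.0, 20.0, 30.0], "tokens_per_gpu": [400.0, 450.0, 300.0]}
)
frontier = _pareto_frontier(df, "tokens_per_user", "tokens_per_gpu")
print(sorted(frontier["tokens_per_user"].tolist()))  # [20.0, 30.0]; (10, 400) is dominated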
def _get_fig_path(
|
||||
fig_dir: Path,
|
||||
fig_group: tuple[tuple[str, str], ...],
|
||||
) -> Path:
|
||||
parts = ["PARETO"]
|
||||
if fig_group:
|
||||
parts.extend(f"{k}={v}" for k, v in fig_group)
|
||||
filename = sanitize_filename("-".join(parts) + ".png")
|
||||
return fig_dir / filename
|
||||
|
||||
|
||||
def _plot_fig(
|
||||
fig_dir: Path,
|
||||
fig_group_data: tuple[tuple[tuple[str, str], ...], list[dict[str, object]]],
|
||||
label_by: list[str],
|
||||
*,
|
||||
dry_run: bool,
|
||||
):
|
||||
fig_group, fig_data = fig_group_data
|
||||
fig_path = _get_fig_path(fig_dir, fig_group)
|
||||
|
||||
print("[BEGIN FIGURE]")
|
||||
print(f"Group: {dict(fig_group)}")
|
||||
print(f"Output file: {fig_path}")
|
||||
|
||||
if dry_run:
|
||||
print("[END FIGURE]")
|
||||
return
|
||||
|
||||
df = pd.DataFrame.from_records(fig_data)
|
||||
df = df.dropna(subset=["tokens_per_user", "tokens_per_gpu"])
|
||||
|
||||
if df.empty:
|
||||
print("No data points available after filtering; skipping.")
|
||||
print("[END FIGURE]")
|
||||
return
|
||||
|
||||
frontier = _pareto_frontier(df, "tokens_per_user", "tokens_per_gpu")
|
||||
frontier = frontier.sort_values("tokens_per_user")
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
sns.scatterplot(
|
||||
data=df,
|
||||
x="tokens_per_user",
|
||||
y="tokens_per_gpu",
|
||||
color="0.5",
|
||||
alpha=0.6,
|
||||
ax=ax,
|
||||
label="All runs",
|
||||
)
|
||||
sns.lineplot(
|
||||
data=frontier,
|
||||
x="tokens_per_user",
|
||||
y="tokens_per_gpu",
|
||||
marker="o",
|
||||
ax=ax,
|
||||
label="Pareto frontier",
|
||||
)
|
||||
|
||||
if label_by:
|
||||
for _, row in frontier.iterrows():
|
||||
label_parts = []
|
||||
for key in label_by:
|
||||
if key in row:
|
||||
label_parts.append(f"{key}={row[key]}")
|
||||
if label_parts:
|
||||
ax.text(
|
||||
row["tokens_per_user"],
|
||||
row["tokens_per_gpu"],
|
||||
"\n".join(label_parts),
|
||||
fontsize=8,
|
||||
)
|
||||
|
||||
ax.set_xlabel("Tokens/s/user")
|
||||
ax.set_ylabel("Tokens/s/GPU")
|
||||
ax.grid(True, linestyle="--", linewidth=0.5, alpha=0.6)
|
||||
fig.tight_layout()
|
||||
fig.savefig(fig_path)
|
||||
plt.close(fig)
|
||||
|
||||
print(
|
||||
f"Plotted {len(df)} points; Pareto frontier size: {len(frontier)}.",
|
||||
)
|
||||
print("[END FIGURE]")
|
||||
|
||||
|
||||
def plot_pareto(
|
||||
output_dir: Path,
|
||||
user_count_var: str | None,
|
||||
gpu_count_var: str | None,
|
||||
label_by: list[str],
|
||||
*,
|
||||
dry_run: bool,
|
||||
):
|
||||
fig_dir = output_dir / "pareto"
|
||||
raw_data = [
|
||||
run_data
|
||||
for path in output_dir.rglob("**/summary.json")
|
||||
for run_data in _json_load_bytes(path)
|
||||
]
|
||||
|
||||
if not raw_data:
|
||||
raise ValueError(f"Did not find any parameter sweep results under {output_dir}")
|
||||
|
||||
fig_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
prepared_data, skipped_missing_users = _prepare_records(
|
||||
raw_data,
|
||||
user_count_var=user_count_var,
|
||||
gpu_count_var=gpu_count_var,
|
||||
)
|
||||
|
||||
if skipped_missing_users:
|
||||
print(
|
||||
f"Skipped {skipped_missing_users} runs without a user count "
|
||||
"(`max_concurrency` or `max_concurrent_requests`).",
|
||||
)
|
||||
|
||||
if not prepared_data:
|
||||
raise ValueError(
|
||||
"No data points with both throughput and user count available "
|
||||
"to plot Pareto frontier.",
|
||||
)
|
||||
|
||||
fig_groups = full_groupby(
|
||||
prepared_data,
|
||||
key=lambda item: tuple(),
|
||||
)
|
||||
|
||||
with DummyExecutor() if len(fig_groups) <= 1 else ProcessPoolExecutor() as executor:
|
||||
all(
|
||||
executor.map(
|
||||
partial(
|
||||
_plot_fig,
|
||||
fig_dir,
|
||||
label_by=label_by,
|
||||
dry_run=dry_run,
|
||||
),
|
||||
fig_groups,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SweepPlotParetoArgs:
|
||||
output_dir: Path
|
||||
user_count_var: str | None
|
||||
gpu_count_var: str | None
|
||||
label_by: list[str]
|
||||
dry_run: bool
|
||||
|
||||
parser_name: ClassVar[str] = "plot_pareto"
|
||||
parser_help: ClassVar[str] = (
|
||||
"Plot Pareto frontier between tokens/s/user and tokens/s/GPU "
|
||||
"from parameter sweep results."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
output_dir = Path(args.OUTPUT_DIR)
|
||||
if not output_dir.exists():
|
||||
raise ValueError(f"No parameter sweep results under {output_dir}")
|
||||
|
||||
label_by = [] if not args.label_by else args.label_by.split(",")
|
||||
|
||||
return cls(
|
||||
output_dir=output_dir,
|
||||
user_count_var=args.user_count_var,
|
||||
gpu_count_var=args.gpu_count_var,
|
||||
label_by=label_by,
|
||||
dry_run=args.dry_run,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def add_cli_args(cls, parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"OUTPUT_DIR",
|
||||
type=str,
|
||||
default="results",
|
||||
help="The directory containing the sweep results to plot.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--user-count-var",
|
||||
type=str,
|
||||
default="max_concurrency",
|
||||
help="Result key that stores concurrent user count. "
|
||||
"Falls back to max_concurrent_requests if missing.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gpu-count-var",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Result key that stores GPU count. "
|
||||
"If not provided, falls back to num_gpus/gpu_count "
|
||||
"or tensor_parallel_size * pipeline_parallel_size.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--label-by",
|
||||
type=str,
|
||||
default="max_concurrency,gpu_count",
|
||||
help="Comma-separated list of fields to annotate on Pareto frontier "
|
||||
"points.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dry-run",
|
||||
action="store_true",
|
||||
help="If set, prints the figures to plot without drawing them.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def run_main(args: SweepPlotParetoArgs):
|
||||
return plot_pareto(
|
||||
output_dir=args.output_dir,
|
||||
user_count_var=args.user_count_var,
|
||||
gpu_count_var=args.gpu_count_var,
|
||||
label_by=args.label_by,
|
||||
dry_run=args.dry_run,
|
||||
)
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
run_main(SweepPlotParetoArgs.from_cli_args(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description=SweepPlotParetoArgs.parser_help)
|
||||
SweepPlotParetoArgs.add_cli_args(parser)
|
||||
|
||||
main(parser.parse_args())
|
||||
498
vllm/benchmarks/sweep/serve.py
Normal file
@@ -0,0 +1,498 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import contextlib
import json
import shlex
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import ClassVar

from vllm.utils.import_utils import PlaceholderModule

from .param_sweep import ParameterSweep, ParameterSweepItem
from .server import ServerProcess
from .utils import sanitize_filename

try:
    import pandas as pd
except ImportError:
    pd = PlaceholderModule("pandas")


@contextlib.contextmanager
def run_server(
    serve_cmd: list[str],
    after_bench_cmd: list[str],
    *,
    show_stdout: bool,
    serve_overrides: ParameterSweepItem,
    dry_run: bool,
    server_ready_timeout: int = 300,
):
    server_cmd = serve_overrides.apply_to_cmd(serve_cmd)

    print("[BEGIN SERVER]")
    print(f"Server overrides: {serve_overrides}")
    print(f"Server command: {server_cmd}")

    if dry_run:
        yield None
        print("[END SERVER]")
        return

    with ServerProcess(server_cmd, after_bench_cmd, show_stdout=show_stdout) as server:
        server.wait_until_ready(timeout=server_ready_timeout)
        yield server

    print("[END SERVER]")


def _update_run_data(
    run_data: dict[str, object],
    serve_overrides: ParameterSweepItem,
    bench_overrides: ParameterSweepItem,
    run_number: int,
):
    run_data["run_number"] = run_number
    run_data.update(serve_overrides)
    run_data.update(bench_overrides)

    return run_data


def run_benchmark(
    server: ServerProcess | None,
    bench_cmd: list[str],
    *,
    serve_overrides: ParameterSweepItem,
    bench_overrides: ParameterSweepItem,
    run_number: int,
    output_path: Path,
    dry_run: bool,
):
    benchmark_cmd = [
        *bench_overrides.apply_to_cmd(bench_cmd),
        "--percentile-metrics",
        "ttft,tpot,itl,e2el",
        "--save-result",
        "--result-dir",
        str(output_path.parent),
        "--result-filename",
        output_path.name,
    ]

    print("[BEGIN BENCHMARK]")
    print(f"Benchmark overrides: {bench_overrides}")
    print(f"Run Number: {run_number}")
    print(f"Benchmark command: {benchmark_cmd}")
    print(f"Output file: {output_path}")

    run_data: dict[str, object]

    if output_path.exists():
        print("Found existing results.")
        print("[SKIPPED BENCHMARK]")

        with output_path.open("rb") as f:
            run_data = json.load(f)
        return _update_run_data(
            run_data,
            serve_overrides,
            bench_overrides,
            run_number,
        )

    if server is None:
        if not dry_run:
            raise ValueError(f"Cannot find results at {output_path}")

        print("[END BENCHMARK]")
        return None

    output_path.parent.mkdir(parents=True, exist_ok=True)

    server.run_subcommand(benchmark_cmd)
    server.after_bench()

    with output_path.open("rb") as f:
        run_data = json.load(f)

    run_data = _update_run_data(
        run_data,
        serve_overrides,
        bench_overrides,
        run_number,
    )

    with output_path.open("w") as f:
        json.dump(run_data, f, indent=4)

    print("[END BENCHMARK]")

    return run_data


def _get_comb_base_path(
    output_dir: Path,
    serve_comb: ParameterSweepItem,
    bench_comb: ParameterSweepItem,
):
    parts = list[str]()
    if serve_comb:
        parts.extend(("SERVE-", serve_comb.name))
    if bench_comb:
        parts.extend(("BENCH-", bench_comb.name))

    return output_dir / sanitize_filename("-".join(parts))


def _get_comb_run_path(base_path: Path, run_number: int | None):
    if run_number is None:
        return base_path / "summary.json"

    return base_path / f"run={run_number}.json"


def _comb_needs_server(
    serve_comb: ParameterSweepItem,
    bench_combs: ParameterSweep,
    output_dir: Path,
):
    for bench_comb in bench_combs:
        base_path = _get_comb_base_path(output_dir, serve_comb, bench_comb)
        if not _get_comb_run_path(base_path, run_number=None).exists():
            return True

    return False


def server_ctx(
    serve_cmd: list[str],
    after_bench_cmd: list[str],
    *,
    show_stdout: bool,
    serve_comb: ParameterSweepItem,
    bench_params: ParameterSweep,
    output_dir: Path,
    dry_run: bool,
    server_ready_timeout: int = 300,
):
    if not _comb_needs_server(serve_comb, bench_params, output_dir):
        return contextlib.nullcontext()

    return run_server(
        serve_cmd,
        after_bench_cmd,
        show_stdout=show_stdout,
        serve_overrides=serve_comb,
        dry_run=dry_run,
        server_ready_timeout=server_ready_timeout,
    )


def _comb_is_valid(
    serve_comb: ParameterSweepItem,
    bench_comb: ParameterSweepItem,
    link_vars: list[tuple[str, str]],
) -> bool:
    return all(
        serve_key in serve_comb
        and bench_key in bench_comb
        and serve_comb[serve_key] == bench_comb[bench_key]
        for serve_key, bench_key in link_vars
    )


def run_comb(
    server: ServerProcess | None,
    bench_cmd: list[str],
    *,
    serve_comb: ParameterSweepItem,
    bench_comb: ParameterSweepItem,
    base_path: Path,
    num_runs: int,
    dry_run: bool,
    link_vars: list[tuple[str, str]],
):
    if not _comb_is_valid(serve_comb, bench_comb, link_vars):
        return None

    comb_data = list[dict[str, object]]()

    for run_number in range(num_runs):
        run_data = run_benchmark(
            server,
            bench_cmd,
            serve_overrides=serve_comb,
            bench_overrides=bench_comb,
            run_number=run_number,
            output_path=_get_comb_run_path(base_path, run_number),
            dry_run=dry_run,
        )

        if run_data is not None:
            comb_data.append(run_data)

    if dry_run:
        return None

    with _get_comb_run_path(base_path, run_number=None).open("w") as f:
        json.dump(comb_data, f, indent=4)

    return comb_data


def run_combs(
    serve_cmd: list[str],
    bench_cmd: list[str],
    after_bench_cmd: list[str],
    *,
    show_stdout: bool,
    server_ready_timeout: int,
    serve_params: ParameterSweep,
    bench_params: ParameterSweep,
    output_dir: Path,
    num_runs: int,
    dry_run: bool,
    link_vars: list[tuple[str, str]],
):
    all_data = list[dict[str, object]]()
    for serve_comb in serve_params:
        with server_ctx(
            serve_cmd,
            after_bench_cmd,
            show_stdout=show_stdout,
            serve_comb=serve_comb,
            bench_params=bench_params,
            output_dir=output_dir,
            dry_run=dry_run,
            server_ready_timeout=server_ready_timeout,
        ) as server:
            for bench_comb in bench_params:
                base_path = _get_comb_base_path(output_dir, serve_comb, bench_comb)

                comb_data = run_comb(
                    server,
                    bench_cmd,
                    serve_comb=serve_comb,
                    bench_comb=bench_comb,
                    base_path=base_path,
                    num_runs=num_runs,
                    dry_run=dry_run,
                    link_vars=link_vars,
                )

                if comb_data is not None:
                    all_data.extend(comb_data)

    if dry_run:
        return None

    combined_df = pd.DataFrame.from_records(all_data)
    combined_df.to_csv(output_dir / "summary.csv")

    return combined_df


@dataclass
class SweepServeArgs:
    serve_cmd: list[str]
    bench_cmd: list[str]
    after_bench_cmd: list[str]
    show_stdout: bool
    server_ready_timeout: int
    serve_params: ParameterSweep
    bench_params: ParameterSweep
    output_dir: Path
    num_runs: int
    dry_run: bool
    resume: str | None
    link_vars: list[tuple[str, str]]

    parser_name: ClassVar[str] = "serve"
    parser_help: ClassVar[str] = "Run vLLM server benchmark under multiple settings."

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        serve_cmd = shlex.split(args.serve_cmd)
        bench_cmd = shlex.split(args.bench_cmd)
        after_bench_cmd = (
            [] if args.after_bench_cmd is None else shlex.split(args.after_bench_cmd)
        )

        if args.serve_params:
            serve_params = ParameterSweep.read_json(args.serve_params)
        else:
            # i.e.: run serve_cmd without any modification
            serve_params = ParameterSweep.from_records([{}])

        if args.bench_params:
            bench_params = ParameterSweep.read_json(args.bench_params)
        else:
            # i.e.: run bench_cmd without any modification
            bench_params = ParameterSweep.from_records([{}])

        link_vars = cls.parse_link_vars(args.link_vars)

        num_runs = args.num_runs
        if num_runs < 1:
            raise ValueError("`num_runs` should be at least 1.")

        return cls(
            serve_cmd=serve_cmd,
            bench_cmd=bench_cmd,
            after_bench_cmd=after_bench_cmd,
            show_stdout=args.show_stdout,
            serve_params=serve_params,
            bench_params=bench_params,
            output_dir=Path(args.output_dir),
            num_runs=num_runs,
            dry_run=args.dry_run,
            resume=args.resume,
            link_vars=link_vars,
            server_ready_timeout=args.server_ready_timeout,
        )

    @classmethod
    def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        parser.add_argument(
            "--serve-cmd",
            type=str,
            required=True,
            help="The command used to run the server: `vllm serve ...`",
        )
        parser.add_argument(
            "--bench-cmd",
            type=str,
            required=True,
            help="The command used to run the benchmark: `vllm bench serve ...`",
        )
        parser.add_argument(
            "--after-bench-cmd",
            type=str,
            default=None,
            help="After a benchmark run is complete, invoke this command instead of "
            "the default `ServerProcess.reset_caches()`.",
        )
        parser.add_argument(
            "--show-stdout",
            action="store_true",
            help="If set, logs the standard output of subcommands. "
            "Useful for debugging but can be quite spammy.",
        )
        parser.add_argument(
            "--server-ready-timeout",
            type=int,
            default=300,
            help="Timeout in seconds to wait for the server to become ready.",
        )
        parser.add_argument(
            "--serve-params",
            type=str,
            default=None,
            help="Path to JSON file containing parameter combinations "
            "for the `vllm serve` command. Can be either a list of dicts or a dict "
            "where keys are benchmark names. "
            "If both `serve_params` and `bench_params` are given, "
            "this script will iterate over their Cartesian product.",
        )
        parser.add_argument(
            "--bench-params",
            type=str,
            default=None,
            help="Path to JSON file containing parameter combinations "
            "for the `vllm bench serve` command. Can be either a list of dicts or "
            "a dict where keys are benchmark names. "
            "If both `serve_params` and `bench_params` are given, "
            "this script will iterate over their Cartesian product.",
        )
        parser.add_argument(
            "-o",
            "--output-dir",
            type=str,
            default="results",
            help="The directory to which results are written.",
        )
        parser.add_argument(
            "--num-runs",
            type=int,
            default=3,
            help="Number of runs per parameter combination.",
        )
        parser.add_argument(
            "--dry-run",
            action="store_true",
            help="If set, prints the commands to run, "
            "then exits without executing them.",
        )
        parser.add_argument(
            "--resume",
            type=str,
            default=None,
            help="Set this to the name of a directory under `output_dir` (which is a "
            "timestamp) to resume a previous execution of this script, i.e., only run "
            "parameter combinations for which there are still no output files.",
        )

        parser.add_argument(
            "--link-vars",
            type=str,
            default="",
            help=(
                "Comma-separated list of linked variables between serve and bench, "
                "e.g. max_num_seqs=max_concurrency,max_model_len=random_input_len"
            ),
        )

        return parser

    @staticmethod
    def parse_link_vars(s: str) -> list[tuple[str, str]]:
        if not s:
            return []
        pairs = []
        for item in s.split(","):
            a, b = item.split("=")
            pairs.append((a.strip(), b.strip()))
        return pairs


def run_main(args: SweepServeArgs):
    timestamp = args.resume or datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = args.output_dir / timestamp

    if args.resume and not output_dir.exists():
        raise ValueError(f"Cannot resume from non-existent directory ({output_dir})")

    try:
        return run_combs(
            serve_cmd=args.serve_cmd,
            bench_cmd=args.bench_cmd,
            after_bench_cmd=args.after_bench_cmd,
            show_stdout=args.show_stdout,
            server_ready_timeout=args.server_ready_timeout,
            serve_params=args.serve_params,
            bench_params=args.bench_params,
            output_dir=output_dir,
            num_runs=args.num_runs,
            dry_run=args.dry_run,
            link_vars=args.link_vars,
        )
    except BaseException as exc:
        raise RuntimeError(
            f"The script was terminated early. Use `--resume {timestamp}` "
            f"to continue the script from its last checkpoint."
        ) from exc


def main(args: argparse.Namespace):
    run_main(SweepServeArgs.from_cli_args(args))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description=SweepServeArgs.parser_help)
    SweepServeArgs.add_cli_args(parser)

    main(parser.parse_args())
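For orientation, here is a minimal sketch of how this sweep script might be driven, based on the `--serve-params`, `--bench-params`, and `--link-vars` help text above. The file names, model name, and parameter values are hypothetical, and the mapping from JSON keys to CLI flags is assumed to be handled by `ParameterSweepItem.apply_to_cmd`:

    # serve_params.json: a list of override dicts for `vllm serve`
    [{"max_num_seqs": 64}, {"max_num_seqs": 128}]

    # bench_params.json: a list of override dicts for `vllm bench serve`
    [{"max_concurrency": 64}, {"max_concurrency": 128}]

    python -m vllm.benchmarks.sweep.serve \
        --serve-cmd "vllm serve <model>" \
        --bench-cmd "vllm bench serve --model <model>" \
        --serve-params serve_params.json \
        --bench-params bench_params.json \
        --link-vars max_num_seqs=max_concurrency \
        -o results

With `--link-vars`, only combinations whose linked values match are run (see `_comb_is_valid`), so the example above produces two runs rather than the full 2x2 Cartesian product.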
305
vllm/benchmarks/sweep/serve_sla.py
Normal file
@@ -0,0 +1,305 @@
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
import math
|
||||
from dataclasses import asdict, dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import ClassVar, Literal, get_args
|
||||
|
||||
import numpy as np
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
from .param_sweep import ParameterSweep, ParameterSweepItem
|
||||
from .serve import (
|
||||
SweepServeArgs,
|
||||
_get_comb_base_path,
|
||||
run_comb,
|
||||
server_ctx,
|
||||
)
|
||||
from .server import ServerProcess
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
pd = PlaceholderModule("pandas")
|
||||
|
||||
|
||||
SLAVariable = Literal["request_rate", "max_concurrency"]
|
||||
|
||||
|
||||
def _estimate_sla_value(run_data: dict[str, object], sla_variable: SLAVariable):
|
||||
request_throughput = float(run_data["request_throughput"]) # type: ignore
|
||||
if sla_variable == "request_rate":
|
||||
return request_throughput
|
||||
if sla_variable == "max_concurrency":
|
||||
mean_latency_ms = float(run_data["mean_e2el_ms"]) # type: ignore
|
||||
return request_throughput * mean_latency_ms / 1000
|
||||
|
||||
assert_never(sla_variable)
|
||||
|
||||
|
||||
def _estimate_sla_avg(runs: list[dict[str, object]], sla_variable: SLAVariable):
|
||||
return sum(_estimate_sla_value(run, sla_variable) for run in runs) / len(runs)
|
||||
|
||||
|
||||
def run_comb_sla(
|
||||
server: ServerProcess | None,
|
||||
bench_cmd: list[str],
|
||||
*,
|
||||
serve_comb: ParameterSweepItem,
|
||||
bench_comb: ParameterSweepItem,
|
||||
output_dir: Path,
|
||||
num_runs: int,
|
||||
dry_run: bool,
|
||||
link_vars: list[tuple[str, str]],
|
||||
sla_variable: SLAVariable,
|
||||
sla_value: int,
|
||||
) -> list[dict[str, object]] | None:
|
||||
bench_comb_sla = bench_comb | {sla_variable: sla_value}
|
||||
|
||||
return run_comb(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb_sla,
|
||||
base_path=_get_comb_base_path(output_dir, serve_comb, bench_comb_sla),
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
link_vars=link_vars,
|
||||
)
|
||||
|
||||
|
||||
def explore_sla(
|
||||
server: ServerProcess | None,
|
||||
bench_cmd: list[str],
|
||||
*,
|
||||
serve_comb: ParameterSweepItem,
|
||||
bench_comb: ParameterSweepItem,
|
||||
sla_variable: SLAVariable,
|
||||
sla_iters: int,
|
||||
output_dir: Path,
|
||||
num_runs: int,
|
||||
dry_run: bool,
|
||||
link_vars: list[tuple[str, str]],
|
||||
):
|
||||
print("[SLA START]")
|
||||
print(f"Serve parameters: {serve_comb.as_text() or '(None)'}")
|
||||
print(f"Bench parameters: {bench_comb.as_text() or '(None)'}")
|
||||
print(f"Number of SLA iterations: {sla_iters}")
|
||||
|
||||
if sla_iters < 2:
|
||||
raise ValueError("`sla_iters` should be at least 2")
|
||||
|
||||
serial_comb_data = run_comb_sla(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb,
|
||||
output_dir=output_dir,
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
link_vars=link_vars,
|
||||
sla_variable=sla_variable,
|
||||
sla_value=1,
|
||||
)
|
||||
batch_comb_data = run_comb_sla(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb,
|
||||
output_dir=output_dir,
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
link_vars=link_vars,
|
||||
sla_variable=sla_variable,
|
||||
sla_value=int(bench_comb.get("num_prompts", 1000)), # type: ignore
|
||||
)
|
||||
|
||||
if serial_comb_data is None or batch_comb_data is None:
|
||||
if dry_run:
|
||||
print("Omitting intermediate SLA iterations.")
|
||||
print("[SLA END]")
|
||||
|
||||
return
|
||||
|
||||
serial_sla_value = math.ceil(_estimate_sla_avg(serial_comb_data, sla_variable))
|
||||
print(f"Serial inference: {sla_variable}={serial_sla_value}")
|
||||
|
||||
batch_sla_value = math.floor(_estimate_sla_avg(batch_comb_data, sla_variable))
|
||||
print(f"Batch inference: {sla_variable}={batch_sla_value}")
|
||||
|
||||
# Avoid duplicated runs for intermediate values if the range between
|
||||
# `serial_sla_value` and `batch_sla_value` is small
|
||||
inter_sla_values = np.linspace(serial_sla_value, batch_sla_value, sla_iters)[1:-1]
|
||||
inter_sla_values = sorted(set(map(round, inter_sla_values)))
|
||||
|
||||
inter_combs_data: list[dict[str, object]] = []
|
||||
for inter_sla_value in inter_sla_values:
|
||||
print(f"Exploring: {sla_variable}={inter_sla_value}")
|
||||
inter_comb_data = run_comb_sla(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb,
|
||||
output_dir=output_dir,
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
link_vars=link_vars,
|
||||
sla_variable=sla_variable,
|
||||
sla_value=inter_sla_value,
|
||||
)
|
||||
if inter_comb_data is not None:
|
||||
inter_combs_data.extend(inter_comb_data)
|
||||
|
||||
print("[SLA END]")
|
||||
|
||||
return serial_comb_data + inter_combs_data + batch_comb_data
|
||||
|
||||
|
||||
def run_slas(
|
||||
serve_cmd: list[str],
|
||||
bench_cmd: list[str],
|
||||
after_bench_cmd: list[str],
|
||||
*,
|
||||
show_stdout: bool,
|
||||
server_ready_timeout: int,
|
||||
serve_params: ParameterSweep,
|
||||
bench_params: ParameterSweep,
|
||||
sla_variable: SLAVariable,
|
||||
sla_iters: int,
|
||||
output_dir: Path,
|
||||
num_runs: int,
|
||||
dry_run: bool,
|
||||
link_vars: list[tuple[str, str]],
|
||||
):
|
||||
if any(bench_comb.has_param(sla_variable) for bench_comb in bench_params):
|
||||
raise ValueError(
|
||||
f"You should not override `{sla_variable}` in `bench_params` in SLA mode, "
|
||||
"since it is supposed to be determined automatically."
|
||||
)
|
||||
|
||||
all_data = list[dict[str, object]]()
|
||||
for serve_comb in serve_params:
|
||||
with server_ctx(
|
||||
serve_cmd,
|
||||
after_bench_cmd,
|
||||
show_stdout=show_stdout,
|
||||
server_ready_timeout=server_ready_timeout,
|
||||
serve_comb=serve_comb,
|
||||
bench_params=bench_params,
|
||||
output_dir=output_dir,
|
||||
dry_run=dry_run,
|
||||
) as server:
|
||||
for bench_comb in bench_params:
|
||||
comb_data = explore_sla(
|
||||
server,
|
||||
bench_cmd,
|
||||
serve_comb=serve_comb,
|
||||
bench_comb=bench_comb,
|
||||
sla_variable=sla_variable,
|
||||
sla_iters=sla_iters,
|
||||
output_dir=output_dir,
|
||||
num_runs=num_runs,
|
||||
dry_run=dry_run,
|
||||
link_vars=link_vars,
|
||||
)
|
||||
|
||||
if comb_data is not None:
|
||||
all_data.extend(comb_data)
|
||||
|
||||
if dry_run:
|
||||
return None
|
||||
|
||||
combined_df = pd.DataFrame.from_records(all_data)
|
||||
combined_df.to_csv(output_dir / "summary.csv")
|
||||
|
||||
return combined_df
|
||||
|
||||
|
||||
@dataclass
|
||||
class SweepServeSLAArgs(SweepServeArgs):
|
||||
sla_variable: SLAVariable
|
||||
sla_iters: int
|
||||
|
||||
parser_name: ClassVar[str] = "serve_sla"
|
||||
parser_help: ClassVar[str] = (
|
||||
"Explore the latency-throughput space for determining SLAs."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
# NOTE: Don't use super() as `from_cli_args` calls `cls()`
|
||||
base_args = SweepServeArgs.from_cli_args(args)
|
||||
|
||||
return cls(
|
||||
**asdict(base_args),
|
||||
sla_variable=args.sla_variable,
|
||||
sla_iters=args.sla_iters,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
|
||||
parser = super().add_cli_args(parser)
|
||||
|
||||
sla_group = parser.add_argument_group("sla options")
|
||||
sla_group.add_argument(
|
||||
"--sla-variable",
|
||||
type=str,
|
||||
choices=get_args(SLAVariable),
|
||||
default="request_rate",
|
||||
help="The variable to adjust in each iteration.",
|
||||
)
|
||||
sla_group.add_argument(
|
||||
"--sla-iters",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of iterations used to explore the latency-throughput space. "
|
||||
"This includes the first two iterations used to interpolate the value of "
|
||||
"`sla_variable` for remaining iterations.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def run_main(args: SweepServeSLAArgs):
|
||||
timestamp = args.resume or datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
output_dir = args.output_dir / timestamp
|
||||
|
||||
if args.resume and not output_dir.exists():
|
||||
raise ValueError(f"Cannot resume from non-existent directory ({output_dir})")
|
||||
|
||||
try:
|
||||
return run_slas(
|
||||
serve_cmd=args.serve_cmd,
|
||||
bench_cmd=args.bench_cmd,
|
||||
after_bench_cmd=args.after_bench_cmd,
|
||||
show_stdout=args.show_stdout,
|
||||
server_ready_timeout=args.server_ready_timeout,
|
||||
serve_params=args.serve_params,
|
||||
bench_params=args.bench_params,
|
||||
sla_variable=args.sla_variable,
|
||||
sla_iters=args.sla_iters,
|
||||
output_dir=output_dir,
|
||||
num_runs=args.num_runs,
|
||||
dry_run=args.dry_run,
|
||||
link_vars=args.link_vars,
|
||||
)
|
||||
except BaseException as exc:
|
||||
raise RuntimeError(
|
||||
f"The script was terminated early. Use `--resume {timestamp}` "
|
||||
f"to continue the script from its last checkpoint."
|
||||
) from exc
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
run_main(SweepServeSLAArgs.from_cli_args(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description=SweepServeSLAArgs.parser_help)
|
||||
SweepServeSLAArgs.add_cli_args(parser)
|
||||
|
||||
main(parser.parse_args())
|
||||
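The SLA variant above reuses the same sweep machinery but chooses the value of `--sla-variable` (request_rate or max_concurrency) automatically: it first measures a serial run and a fully batched run, then explores interpolated values in between, so that variable must be left out of `bench_params`. A hypothetical invocation (names and values are illustrative only):

    python -m vllm.benchmarks.sweep.serve_sla \
        --serve-cmd "vllm serve <model>" \
        --bench-cmd "vllm bench serve --model <model>" \
        --bench-params bench_params.json \
        --sla-variable max_concurrency \
        --sla-iters 10 \
        -o results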
142
vllm/benchmarks/sweep/server.py
Normal file
@@ -0,0 +1,142 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import os
import signal
import subprocess
import time
from types import TracebackType

import requests
from typing_extensions import Self


class ServerProcess:
    VLLM_RESET_CACHE_ENDPOINTS = [
        "/reset_prefix_cache",
        "/reset_mm_cache",
        "/reset_encoder_cache",
    ]

    def __init__(
        self,
        server_cmd: list[str],
        after_bench_cmd: list[str],
        *,
        show_stdout: bool,
    ) -> None:
        super().__init__()

        self.server_cmd = server_cmd
        self.after_bench_cmd = after_bench_cmd
        self.show_stdout = show_stdout

    def __enter__(self) -> Self:
        self.start()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        exc_traceback: TracebackType | None,
    ) -> None:
        self.stop()

    def start(self):
        # Create new process for clean termination
        self._server_process = subprocess.Popen(
            self.server_cmd,
            start_new_session=True,
            stdout=None if self.show_stdout else subprocess.DEVNULL,
            # Need `VLLM_SERVER_DEV_MODE=1` for `_reset_caches`
            env=os.environ | {"VLLM_SERVER_DEV_MODE": "1"},
        )

    def stop(self):
        server_process = self._server_process

        if server_process.poll() is None:
            # In case only some processes have been terminated
            with contextlib.suppress(ProcessLookupError):
                # We need to kill both API Server and Engine processes
                os.killpg(os.getpgid(server_process.pid), signal.SIGKILL)

    def run_subcommand(self, cmd: list[str]):
        return subprocess.run(
            cmd,
            stdout=None if self.show_stdout else subprocess.DEVNULL,
            check=True,
        )

    def after_bench(self) -> None:
        if not self.after_bench_cmd:
            self.reset_caches()
            return

        self.run_subcommand(self.after_bench_cmd)

    def _get_vllm_server_address(self) -> str:
        server_cmd = self.server_cmd

        for host_key in ("--host",):
            if host_key in server_cmd:
                host = server_cmd[server_cmd.index(host_key) + 1]
                break
        else:
            host = "localhost"

        for port_key in ("-p", "--port"):
            if port_key in server_cmd:
                port = int(server_cmd[server_cmd.index(port_key) + 1])
                break
        else:
            port = 8000  # The default value in vllm serve

        return f"http://{host}:{port}"

    def is_server_ready(self) -> bool:
        server_address = self._get_vllm_server_address()
        try:
            response = requests.get(f"{server_address}/health")
            return response.status_code == 200
        except requests.RequestException:
            return False

    def wait_until_ready(self, timeout: int) -> None:
        start_time = time.monotonic()
        while not self.is_server_ready():
            # Check if server process has crashed
            if self._server_process.poll() is not None:
                returncode = self._server_process.returncode
                raise RuntimeError(
                    f"Server process crashed with return code {returncode}"
                )
            if time.monotonic() - start_time > timeout:
                raise TimeoutError(
                    f"Server failed to become ready within {timeout} seconds."
                )
            time.sleep(1)

    def reset_caches(self) -> None:
        server_cmd = self.server_cmd

        # Use `.endswith()` to match `/bin/...`
        if server_cmd[0].endswith("vllm"):
            server_address = self._get_vllm_server_address()
            print(f"Resetting caches at {server_address}")

            for endpoint in self.VLLM_RESET_CACHE_ENDPOINTS:
                res = requests.post(server_address + endpoint)
                res.raise_for_status()
        elif server_cmd[0].endswith("infinity_emb"):
            if "--vector-disk-cache" in server_cmd:
                raise NotImplementedError(
                    "Infinity server uses caching but does not expose a method "
                    "to reset the cache"
                )
        else:
            raise NotImplementedError(
                f"No implementation of `reset_caches` for `{server_cmd[0]}` server. "
                "Please specify a custom command via `--after-bench-cmd`."
            )
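As a usage note, `ServerProcess` can in principle be driven on its own; the sweep scripts above do so through `run_server`/`server_ctx`. A minimal sketch under that assumption (the model name is a placeholder, and the import path simply mirrors the file layout of this commit):

    from vllm.benchmarks.sweep.server import ServerProcess

    serve_cmd = ["vllm", "serve", "<model>"]
    with ServerProcess(serve_cmd, after_bench_cmd=[], show_stdout=False) as server:
        # Block until GET /health returns 200, or raise on crash/timeout.
        server.wait_until_ready(timeout=300)
        # Run a benchmark as a subprocess against the live server.
        server.run_subcommand(["vllm", "bench", "serve", "--model", "<model>"])
        # An empty after_bench_cmd falls back to reset_caches() over HTTP.
        server.after_bench()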
406
vllm/benchmarks/sweep/startup.py
Normal file
@@ -0,0 +1,406 @@
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
import json
|
||||
import shlex
|
||||
import subprocess
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import ClassVar
|
||||
|
||||
from vllm.benchmarks.startup import add_cli_args as add_startup_cli_args
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
from .param_sweep import ParameterSweep, ParameterSweepItem
|
||||
from .utils import sanitize_filename
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
pd = PlaceholderModule("pandas")
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _get_supported_startup_keys() -> set[str]:
|
||||
parser = FlexibleArgumentParser(add_help=False)
|
||||
add_startup_cli_args(parser)
|
||||
|
||||
supported: set[str] = {"config"}
|
||||
for action in parser._actions:
|
||||
if action.dest and action.dest is not argparse.SUPPRESS:
|
||||
supported.add(action.dest)
|
||||
for option in action.option_strings:
|
||||
if option.startswith("--"):
|
||||
supported.add(option.lstrip("-").replace("-", "_"))
|
||||
|
||||
return supported
|
||||
|
||||
|
||||
def _is_supported_param(param_key: str, supported: set[str]) -> bool:
|
||||
if param_key == "_benchmark_name":
|
||||
return True
|
||||
prefix = param_key.split(".", 1)[0]
|
||||
normalized = prefix.replace("-", "_")
|
||||
return normalized in supported
|
||||
|
||||
|
||||
def _filter_params(
|
||||
params: ParameterSweep, *, supported: set[str], strict: bool
|
||||
) -> ParameterSweep:
|
||||
filtered = []
|
||||
for item in params:
|
||||
kept: dict[str, object] = {}
|
||||
dropped: list[str] = []
|
||||
for key, value in item.items():
|
||||
if _is_supported_param(key, supported):
|
||||
kept[key] = value
|
||||
else:
|
||||
dropped.append(key)
|
||||
|
||||
if dropped:
|
||||
label = item.get("_benchmark_name") or item.as_text()
|
||||
message = (
|
||||
"Ignoring unsupported startup params"
|
||||
f"{' for ' + str(label) if label else ''}: "
|
||||
f"{', '.join(sorted(dropped))}"
|
||||
)
|
||||
if strict:
|
||||
raise ValueError(message)
|
||||
print(message)
|
||||
|
||||
filtered.append(ParameterSweepItem.from_record(kept))
|
||||
|
||||
return ParameterSweep(filtered)
|
||||
|
||||
|
||||
def _update_run_data(
|
||||
run_data: dict[str, object],
|
||||
serve_overrides: ParameterSweepItem,
|
||||
startup_overrides: ParameterSweepItem,
|
||||
run_number: int,
|
||||
) -> dict[str, object]:
|
||||
run_data["run_number"] = run_number
|
||||
run_data.update(serve_overrides)
|
||||
run_data.update(startup_overrides)
|
||||
return run_data
|
||||
|
||||
|
||||
def _strip_arg(cmd: list[str], keys: tuple[str, ...]) -> list[str]:
|
||||
stripped: list[str] = []
|
||||
skip_next = False
|
||||
for arg in cmd:
|
||||
if skip_next:
|
||||
skip_next = False
|
||||
continue
|
||||
if arg in keys:
|
||||
skip_next = True
|
||||
continue
|
||||
if any(arg.startswith(f"{key}=") for key in keys):
|
||||
continue
|
||||
stripped.append(arg)
|
||||
return stripped
|
||||
|
||||
|
||||
def _apply_output_json(cmd: list[str], output_path: Path) -> list[str]:
|
||||
keys = ("--output-json", "--output_json")
|
||||
cmd = _strip_arg(cmd, keys)
|
||||
return [*cmd, keys[0], str(output_path)]
|
||||
|
||||
|
||||
def _get_comb_base_path(
|
||||
output_dir: Path,
|
||||
serve_comb: ParameterSweepItem,
|
||||
startup_comb: ParameterSweepItem,
|
||||
) -> Path:
|
||||
parts = list[str]()
|
||||
if serve_comb:
|
||||
parts.extend(("SERVE-", serve_comb.name))
|
||||
if startup_comb:
|
||||
parts.extend(("STARTUP-", startup_comb.name))
|
||||
return output_dir / sanitize_filename("-".join(parts))
|
||||
|
||||
|
||||
def _get_comb_run_path(base_path: Path, run_number: int | None) -> Path:
|
||||
if run_number is None:
|
||||
return base_path / "summary.json"
|
||||
return base_path / f"run={run_number}.json"
|
||||
|
||||
|
||||
def run_benchmark(
|
||||
startup_cmd: list[str],
|
||||
*,
|
||||
serve_overrides: ParameterSweepItem,
|
||||
startup_overrides: ParameterSweepItem,
|
||||
run_number: int,
|
||||
output_path: Path,
|
||||
show_stdout: bool,
|
||||
dry_run: bool,
|
||||
) -> dict[str, object] | None:
|
||||
cmd = serve_overrides.apply_to_cmd(startup_cmd)
|
||||
cmd = startup_overrides.apply_to_cmd(cmd)
|
||||
cmd = _apply_output_json(cmd, output_path)
|
||||
|
||||
print("[BEGIN BENCHMARK]")
|
||||
print(f"Serve overrides: {serve_overrides}")
|
||||
print(f"Startup overrides: {startup_overrides}")
|
||||
print(f"Run Number: {run_number}")
|
||||
print(f"Benchmark command: {cmd}")
|
||||
print(f"Output file: {output_path}")
|
||||
|
||||
if output_path.exists():
|
||||
print("Found existing results.")
|
||||
print("[SKIPPED BENCHMARK]")
|
||||
|
||||
with output_path.open("r", encoding="utf-8") as f:
|
||||
run_data = json.load(f)
|
||||
return _update_run_data(
|
||||
run_data, serve_overrides, startup_overrides, run_number
|
||||
)
|
||||
|
||||
if dry_run:
|
||||
print("[END BENCHMARK]")
|
||||
return None
|
||||
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
subprocess.run(
|
||||
cmd,
|
||||
stdout=None if show_stdout else subprocess.DEVNULL,
|
||||
check=True,
|
||||
)
|
||||
|
||||
with output_path.open("r", encoding="utf-8") as f:
|
||||
run_data = json.load(f)
|
||||
|
||||
run_data = _update_run_data(
|
||||
run_data, serve_overrides, startup_overrides, run_number
|
||||
)
|
||||
|
||||
with output_path.open("w", encoding="utf-8") as f:
|
||||
json.dump(run_data, f, indent=4)
|
||||
|
||||
print("[END BENCHMARK]")
|
||||
return run_data
|
||||
|
||||
|
||||
def run_comb(
|
||||
startup_cmd: list[str],
|
||||
*,
|
||||
serve_comb: ParameterSweepItem,
|
||||
startup_comb: ParameterSweepItem,
|
||||
base_path: Path,
|
||||
num_runs: int,
|
||||
show_stdout: bool,
|
||||
dry_run: bool,
|
||||
) -> list[dict[str, object]] | None:
|
||||
comb_data = list[dict[str, object]]()
|
||||
for run_number in range(num_runs):
|
||||
run_data = run_benchmark(
|
||||
startup_cmd,
|
||||
serve_overrides=serve_comb,
|
||||
startup_overrides=startup_comb,
|
||||
run_number=run_number,
|
||||
output_path=_get_comb_run_path(base_path, run_number),
|
||||
show_stdout=show_stdout,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
if run_data is not None:
|
||||
comb_data.append(run_data)
|
||||
|
||||
if dry_run:
|
||||
return None
|
||||
|
||||
with _get_comb_run_path(base_path, run_number=None).open(
|
||||
"w", encoding="utf-8"
|
||||
) as f:
|
||||
json.dump(comb_data, f, indent=4)
|
||||
|
||||
return comb_data
|
||||
|
||||
|
||||
def run_combs(
|
||||
startup_cmd: list[str],
|
||||
*,
|
||||
serve_params: ParameterSweep,
|
||||
startup_params: ParameterSweep,
|
||||
output_dir: Path,
|
||||
num_runs: int,
|
||||
show_stdout: bool,
|
||||
dry_run: bool,
|
||||
) -> "pd.DataFrame | None":
|
||||
all_data = list[dict[str, object]]()
|
||||
for serve_comb in serve_params:
|
||||
for startup_comb in startup_params:
|
||||
base_path = _get_comb_base_path(output_dir, serve_comb, startup_comb)
|
||||
comb_data = run_comb(
|
||||
startup_cmd,
|
||||
serve_comb=serve_comb,
|
||||
startup_comb=startup_comb,
|
||||
base_path=base_path,
|
||||
num_runs=num_runs,
|
||||
show_stdout=show_stdout,
|
||||
dry_run=dry_run,
|
||||
)
|
||||
if comb_data is not None:
|
||||
all_data.extend(comb_data)
|
||||
|
||||
if dry_run:
|
||||
return None
|
||||
|
||||
combined_df = pd.DataFrame.from_records(all_data)
|
||||
combined_df.to_csv(output_dir / "summary.csv")
|
||||
return combined_df
|
||||
|
||||
|
||||
@dataclass
|
||||
class SweepStartupArgs:
|
||||
startup_cmd: list[str]
|
||||
serve_params: ParameterSweep
|
||||
startup_params: ParameterSweep
|
||||
output_dir: Path
|
||||
num_runs: int
|
||||
show_stdout: bool
|
||||
dry_run: bool
|
||||
resume: str | None
|
||||
strict_params: bool
|
||||
|
||||
parser_name: ClassVar[str] = "startup"
|
||||
parser_help: ClassVar[str] = (
|
||||
"Benchmark vLLM startup time over parameter combinations."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
startup_cmd = shlex.split(args.startup_cmd)
|
||||
|
||||
if args.serve_params:
|
||||
serve_params = ParameterSweep.read_json(args.serve_params)
|
||||
else:
|
||||
serve_params = ParameterSweep.from_records([{}])
|
||||
|
||||
if args.startup_params:
|
||||
startup_params = ParameterSweep.read_json(args.startup_params)
|
||||
else:
|
||||
startup_params = ParameterSweep.from_records([{}])
|
||||
|
||||
supported = _get_supported_startup_keys()
|
||||
serve_params = _filter_params(
|
||||
serve_params, supported=supported, strict=args.strict_params
|
||||
)
|
||||
startup_params = _filter_params(
|
||||
startup_params, supported=supported, strict=args.strict_params
|
||||
)
|
||||
|
||||
if args.num_runs < 1:
|
||||
raise ValueError("`num_runs` should be at least 1.")
|
||||
|
||||
return cls(
|
||||
startup_cmd=startup_cmd,
|
||||
serve_params=serve_params,
|
||||
startup_params=startup_params,
|
||||
output_dir=Path(args.output_dir),
|
||||
num_runs=args.num_runs,
|
||||
show_stdout=args.show_stdout,
|
||||
dry_run=args.dry_run,
|
||||
resume=args.resume,
|
||||
strict_params=args.strict_params,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
|
||||
parser.add_argument(
|
||||
"--startup-cmd",
|
||||
type=str,
|
||||
default="vllm bench startup",
|
||||
help="The command used to run the startup benchmark.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--serve-params",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to JSON file containing parameter combinations "
|
||||
"for the `vllm serve` command. Only parameters supported by "
|
||||
"`vllm bench startup` will be applied.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--startup-params",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to JSON file containing parameter combinations "
|
||||
"for the `vllm bench startup` command.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--output-dir",
|
||||
type=str,
|
||||
default="results",
|
||||
help="The directory to which results are written.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-runs",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of runs per parameter combination.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--show-stdout",
|
||||
action="store_true",
|
||||
help="If set, logs the standard output of subcommands.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dry-run",
|
||||
action="store_true",
|
||||
help="If set, prints the commands to run, "
|
||||
"then exits without executing them.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Set this to the name of a directory under `output_dir` (which is a "
|
||||
"timestamp) to resume a previous execution of this script, i.e., only run "
|
||||
"parameter combinations for which there are still no output files.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--strict-params",
|
||||
action="store_true",
|
||||
help="If set, unknown parameters in sweep files raise an error "
|
||||
"instead of being ignored.",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def run_main(args: SweepStartupArgs):
|
||||
timestamp = args.resume or datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
output_dir = args.output_dir / timestamp
|
||||
|
||||
if args.resume and not output_dir.exists():
|
||||
raise ValueError(f"Cannot resume from non-existent directory ({output_dir})")
|
||||
|
||||
try:
|
||||
return run_combs(
|
||||
startup_cmd=args.startup_cmd,
|
||||
serve_params=args.serve_params,
|
||||
startup_params=args.startup_params,
|
||||
output_dir=output_dir,
|
||||
num_runs=args.num_runs,
|
||||
show_stdout=args.show_stdout,
|
||||
dry_run=args.dry_run,
|
||||
)
|
||||
except BaseException as exc:
|
||||
raise RuntimeError(
|
||||
f"The script was terminated early. Use `--resume {timestamp}` "
|
||||
f"to continue the script from its last checkpoint."
|
||||
) from exc
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
run_main(SweepStartupArgs.from_cli_args(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description=SweepStartupArgs.parser_help)
|
||||
SweepStartupArgs.add_cli_args(parser)
|
||||
main(parser.parse_args())
|
||||
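As with the serving sweep, a hypothetical invocation of the startup sweep (file names and values are illustrative only); per `_filter_params` above, keys in `serve_params.json` that `vllm bench startup` does not accept are dropped with a warning unless `--strict-params` is given:

    python -m vllm.benchmarks.sweep.startup \
        --serve-params serve_params.json \
        --num-runs 3 \
        -o results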
4
vllm/benchmarks/sweep/utils.py
Normal file
@@ -0,0 +1,4 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
def sanitize_filename(filename: str) -> str:
    return filename.replace("/", "_").replace("..", "__").strip("'").strip('"')
946
vllm/benchmarks/throughput.py
Normal file
@@ -0,0 +1,946 @@
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Benchmark offline inference throughput."""
|
||||
|
||||
import argparse
|
||||
import dataclasses
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import warnings
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import uvloop
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
|
||||
|
||||
from vllm.benchmarks.datasets import (
|
||||
AIMODataset,
|
||||
BurstGPTDataset,
|
||||
ConversationDataset,
|
||||
InstructCoderDataset,
|
||||
MultiModalConversationDataset,
|
||||
PrefixRepetitionRandomDataset,
|
||||
RandomDataset,
|
||||
RandomDatasetForReranking,
|
||||
RandomMultiModalDataset,
|
||||
SampleRequest,
|
||||
ShareGPTDataset,
|
||||
SonnetDataset,
|
||||
VisionArenaDataset,
|
||||
add_random_dataset_base_args,
|
||||
add_random_multimodal_dataset_args,
|
||||
)
|
||||
from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
|
||||
from vllm.inputs import TextPrompt, TokensPrompt
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import BeamSearchParams
|
||||
from vllm.tokenizers import TokenizerLike, get_tokenizer
|
||||
from vllm.utils.async_utils import merge_async_iterators
|
||||
|
||||
|
||||
def run_vllm(
|
||||
requests: list[SampleRequest],
|
||||
n: int,
|
||||
engine_args: EngineArgs,
|
||||
do_profile: bool,
|
||||
disable_detokenize: bool = False,
|
||||
) -> tuple[float, list[RequestOutput] | None]:
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
llm = LLM(**dataclasses.asdict(engine_args))
|
||||
assert all(
|
||||
llm.llm_engine.model_config.max_model_len
|
||||
>= (request.prompt_len + request.expected_output_len)
|
||||
for request in requests
|
||||
), (
|
||||
"Please ensure that max_model_len is greater than the sum of"
|
||||
" prompt_len and expected_output_len for all requests."
|
||||
)
|
||||
# Add the requests to the engine.
|
||||
prompts: list[TextPrompt | TokensPrompt] = []
|
||||
sampling_params: list[SamplingParams] = []
|
||||
for request in requests:
|
||||
prompt = (
|
||||
TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"])
|
||||
if "prompt_token_ids" in request.prompt
|
||||
else TextPrompt(prompt=request.prompt)
|
||||
)
|
||||
if request.multi_modal_data:
|
||||
assert isinstance(request.multi_modal_data, dict)
|
||||
prompt["multi_modal_data"] = request.multi_modal_data
|
||||
prompts.append(prompt)
|
||||
|
||||
sampling_params.append(
|
||||
SamplingParams(
|
||||
n=n,
|
||||
temperature=1.0,
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
max_tokens=request.expected_output_len,
|
||||
detokenize=not disable_detokenize,
|
||||
)
|
||||
)
|
||||
lora_requests: list[LoRARequest] | None = None
|
||||
if engine_args.enable_lora:
|
||||
lora_requests = [request.lora_request for request in requests]
|
||||
|
||||
use_beam_search = False
|
||||
|
||||
outputs = None
|
||||
if not use_beam_search:
|
||||
start = time.perf_counter()
|
||||
if do_profile:
|
||||
llm.start_profile()
|
||||
outputs = llm.generate(
|
||||
prompts, sampling_params, lora_request=lora_requests, use_tqdm=True
|
||||
)
|
||||
if do_profile:
|
||||
llm.stop_profile()
|
||||
end = time.perf_counter()
|
||||
else:
|
||||
assert lora_requests is None, "BeamSearch API does not support LoRA"
|
||||
prompts = [request.prompt for request in requests]
|
||||
# output_len should be the same for all requests.
|
||||
output_len = requests[0].expected_output_len
|
||||
for request in requests:
|
||||
assert request.expected_output_len == output_len
|
||||
start = time.perf_counter()
|
||||
if do_profile:
|
||||
llm.start_profile()
|
||||
llm.beam_search(
|
||||
prompts,
|
||||
BeamSearchParams(
|
||||
beam_width=n,
|
||||
max_tokens=output_len,
|
||||
ignore_eos=True,
|
||||
),
|
||||
)
|
||||
if do_profile:
|
||||
llm.stop_profile()
|
||||
end = time.perf_counter()
|
||||
return end - start, outputs
|
||||
|
||||
|
||||
def run_vllm_chat(
|
||||
requests: list[SampleRequest],
|
||||
n: int,
|
||||
engine_args: EngineArgs,
|
||||
do_profile: bool,
|
||||
disable_detokenize: bool = False,
|
||||
) -> tuple[float, list[RequestOutput]]:
|
||||
"""
|
||||
Run vLLM chat benchmark. This function is recommended ONLY for benchmarking
|
||||
multimodal models as it properly handles multimodal inputs and chat
|
||||
formatting. For non-multimodal models, use run_vllm() instead.
|
||||
"""
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
llm = LLM(**dataclasses.asdict(engine_args))
|
||||
|
||||
assert all(
|
||||
llm.llm_engine.model_config.max_model_len
|
||||
>= (request.prompt_len + request.expected_output_len)
|
||||
for request in requests
|
||||
), (
|
||||
"Please ensure that max_model_len is greater than the sum of "
|
||||
"prompt_len and expected_output_len for all requests."
|
||||
)
|
||||
|
||||
prompts = []
|
||||
sampling_params: list[SamplingParams] = []
|
||||
for request in requests:
|
||||
prompts.append(request.prompt)
|
||||
sampling_params.append(
|
||||
SamplingParams(
|
||||
n=n,
|
||||
temperature=1.0,
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
max_tokens=request.expected_output_len,
|
||||
detokenize=not disable_detokenize,
|
||||
)
|
||||
)
|
||||
start = time.perf_counter()
|
||||
if do_profile:
|
||||
llm.start_profile()
|
||||
outputs = llm.chat(prompts, sampling_params, use_tqdm=True)
|
||||
if do_profile:
|
||||
llm.stop_profile()
|
||||
end = time.perf_counter()
|
||||
return end - start, outputs
|
||||
|
||||
|
||||
async def run_vllm_async(
|
||||
requests: list[SampleRequest],
|
||||
n: int,
|
||||
engine_args: AsyncEngineArgs,
|
||||
do_profile: bool,
|
||||
disable_frontend_multiprocessing: bool = False,
|
||||
disable_detokenize: bool = False,
|
||||
) -> float:
|
||||
from vllm import SamplingParams
|
||||
from vllm.entrypoints.openai.api_server import (
|
||||
build_async_engine_client_from_engine_args,
|
||||
)
|
||||
|
||||
async with build_async_engine_client_from_engine_args(
|
||||
engine_args,
|
||||
disable_frontend_multiprocessing=disable_frontend_multiprocessing,
|
||||
) as llm:
|
||||
model_config = llm.model_config
|
||||
assert all(
|
||||
model_config.max_model_len
|
||||
>= (request.prompt_len + request.expected_output_len)
|
||||
for request in requests
|
||||
), (
|
||||
"Please ensure that max_model_len is greater than the sum of"
|
||||
" prompt_len and expected_output_len for all requests."
|
||||
)
|
||||
|
||||
# Add the requests to the engine.
|
||||
prompts: list[TextPrompt | TokensPrompt] = []
|
||||
sampling_params: list[SamplingParams] = []
|
||||
lora_requests: list[LoRARequest | None] = []
|
||||
for request in requests:
|
||||
prompt = (
|
||||
TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"])
|
||||
if "prompt_token_ids" in request.prompt
|
||||
else TextPrompt(prompt=request.prompt)
|
||||
)
|
||||
|
||||
if request.multi_modal_data:
|
||||
assert isinstance(request.multi_modal_data, dict)
|
||||
prompt["multi_modal_data"] = request.multi_modal_data
|
||||
|
||||
sampling_params.append(
|
||||
SamplingParams(
|
||||
n=n,
|
||||
temperature=1.0,
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
max_tokens=request.expected_output_len,
|
||||
detokenize=not disable_detokenize,
|
||||
)
|
||||
)
|
||||
prompts.append(prompt)
|
||||
lora_requests.append(request.lora_request)
|
||||
|
||||
generators = []
|
||||
start = time.perf_counter()
|
||||
if do_profile:
|
||||
await llm.start_profile()
|
||||
for i, (prompt, sp, lr) in enumerate(
|
||||
zip(prompts, sampling_params, lora_requests)
|
||||
):
|
||||
generator = llm.generate(prompt, sp, lora_request=lr, request_id=f"test{i}")
|
||||
generators.append(generator)
|
||||
all_gens = merge_async_iterators(*generators)
|
||||
async for i, res in all_gens:
|
||||
pass
|
||||
if do_profile:
|
||||
await llm.stop_profile()
|
||||
end = time.perf_counter()
|
||||
return end - start
|
||||
|
||||
|
||||
def run_hf(
|
||||
requests: list[SampleRequest],
|
||||
model: str,
|
||||
tokenizer: TokenizerLike,
|
||||
n: int,
|
||||
max_batch_size: int,
|
||||
trust_remote_code: bool,
|
||||
disable_detokenize: bool = False,
|
||||
) -> float:
|
||||
assert isinstance(tokenizer, PreTrainedTokenizerBase), (
|
||||
"the hf backend only supports HF tokenizers"
|
||||
)
|
||||
llm = AutoModelForCausalLM.from_pretrained(
|
||||
model, dtype=torch.float16, trust_remote_code=trust_remote_code
|
||||
)
|
||||
if llm.config.model_type == "llama":
|
||||
# To enable padding in the HF backend.
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
llm = llm.cuda()
|
||||
|
||||
pbar = tqdm(total=len(requests))
|
||||
start = time.perf_counter()
|
||||
batch: list[str] = []
|
||||
max_prompt_len = 0
|
||||
max_output_len = 0
|
||||
for i in range(len(requests)):
|
||||
prompt = requests[i].prompt
|
||||
prompt_len = requests[i].prompt_len
|
||||
output_len = requests[i].expected_output_len
|
||||
# Add the prompt to the batch.
|
||||
batch.append(prompt)
|
||||
max_prompt_len = max(max_prompt_len, prompt_len)
|
||||
max_output_len = max(max_output_len, output_len)
|
||||
if len(batch) < max_batch_size and i != len(requests) - 1:
|
||||
# Check if we can add more requests to the batch.
|
||||
next_prompt_len = requests[i + 1].prompt_len
|
||||
next_output_len = requests[i + 1].expected_output_len
|
||||
if (
|
||||
max(max_prompt_len, next_prompt_len)
|
||||
+ max(max_output_len, next_output_len)
|
||||
) <= 2048:
|
||||
# We can add more requests to the batch.
|
||||
continue
|
||||
|
||||
# Generate the sequences.
|
||||
input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids
|
||||
llm_outputs = llm.generate(
|
||||
input_ids=input_ids.cuda(),
|
||||
do_sample=True,
|
||||
num_return_sequences=n,
|
||||
temperature=1.0,
|
||||
top_p=1.0,
|
||||
use_cache=True,
|
||||
max_new_tokens=max_output_len,
|
||||
)
|
||||
if not disable_detokenize:
|
||||
# Include the decoding time.
|
||||
tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
|
||||
pbar.update(len(batch))
|
||||
|
||||
# Clear the batch.
|
||||
batch = []
|
||||
max_prompt_len = 0
|
||||
max_output_len = 0
|
||||
end = time.perf_counter()
|
||||
return end - start
|
||||
|
||||
|
||||
def save_to_pytorch_benchmark_format(
|
||||
args: argparse.Namespace, results: dict[str, Any]
|
||||
) -> None:
|
||||
pt_records = convert_to_pytorch_benchmark_format(
|
||||
args=args,
|
||||
metrics={
|
||||
"requests_per_second": [results["requests_per_second"]],
|
||||
"tokens_per_second": [results["tokens_per_second"]],
|
||||
},
|
||||
extra_info={
|
||||
k: results[k] for k in ["elapsed_time", "num_requests", "total_num_tokens"]
|
||||
},
|
||||
)
|
||||
if pt_records:
|
||||
# Don't use json suffix here as we don't want CI to pick it up
|
||||
pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json"
|
||||
write_to_json(pt_file, pt_records)
|
||||
|
||||
|
||||
def get_requests(args, tokenizer):
|
||||
# Common parameters for all dataset types.
|
||||
common_kwargs = {
|
||||
"dataset_path": args.dataset_path,
|
||||
"random_seed": args.seed,
|
||||
}
|
||||
sample_kwargs = {
|
||||
"tokenizer": tokenizer,
|
||||
"lora_path": args.lora_path,
|
||||
"max_loras": args.max_loras,
|
||||
"num_requests": args.num_prompts,
|
||||
}
|
||||
|
||||
if args.dataset_name == "random" or (
|
||||
args.dataset_path is None
|
||||
and args.dataset_name not in {"prefix_repetition", "random-mm", "random-rerank"}
|
||||
):
|
||||
sample_kwargs["range_ratio"] = args.random_range_ratio
|
||||
# prefer random_* arguments, fall back to regular arguments
|
||||
random_prefix_len = getattr(args, "random_prefix_len", None)
|
||||
sample_kwargs["prefix_len"] = (
|
||||
random_prefix_len if random_prefix_len is not None else args.prefix_len
|
||||
)
|
||||
random_input_len = getattr(args, "random_input_len", None)
|
||||
sample_kwargs["input_len"] = (
|
||||
random_input_len if random_input_len is not None else args.input_len
|
||||
)
|
||||
random_output_len = getattr(args, "random_output_len", None)
|
||||
sample_kwargs["output_len"] = (
|
||||
random_output_len if random_output_len is not None else args.output_len
|
||||
)
|
||||
dataset_cls = RandomDataset
|
||||
elif args.dataset_name == "sharegpt":
|
||||
dataset_cls = ShareGPTDataset
|
||||
if args.backend == "vllm-chat":
|
||||
sample_kwargs["enable_multimodal_chat"] = True
|
||||
if args.output_len is not None:
|
||||
sample_kwargs["output_len"] = args.output_len
|
||||
elif args.dataset_name == "sonnet":
|
||||
assert tokenizer.chat_template or tokenizer.default_chat_template, (
|
||||
"Tokenizer/model must have chat template for sonnet dataset."
|
||||
)
|
||||
dataset_cls = SonnetDataset
|
||||
sample_kwargs["prefix_len"] = args.prefix_len
|
||||
sample_kwargs["return_prompt_formatted"] = True
|
||||
if args.input_len is not None:
|
||||
sample_kwargs["input_len"] = args.input_len
|
||||
if args.output_len is not None:
|
||||
sample_kwargs["output_len"] = args.output_len
|
||||
elif args.dataset_name == "burstgpt":
|
||||
dataset_cls = BurstGPTDataset
|
||||
elif args.dataset_name == "hf":
|
||||
if args.output_len is not None:
|
||||
sample_kwargs["output_len"] = args.output_len
|
||||
if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_cls = VisionArenaDataset
|
||||
common_kwargs["dataset_subset"] = None
|
||||
common_kwargs["dataset_split"] = "train"
|
||||
sample_kwargs["enable_multimodal_chat"] = True
|
||||
elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_cls = InstructCoderDataset
|
||||
common_kwargs["dataset_split"] = "train"
|
||||
elif args.dataset_path in MultiModalConversationDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_cls = MultiModalConversationDataset
|
||||
common_kwargs["dataset_subset"] = args.hf_subset
|
||||
common_kwargs["dataset_split"] = args.hf_split
|
||||
sample_kwargs["enable_multimodal_chat"] = True
|
||||
elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_cls = ConversationDataset
|
||||
common_kwargs["dataset_subset"] = args.hf_subset
|
||||
common_kwargs["dataset_split"] = args.hf_split
|
||||
sample_kwargs["enable_multimodal_chat"] = True
|
||||
elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_cls = AIMODataset
|
||||
common_kwargs["dataset_subset"] = None
|
||||
common_kwargs["dataset_split"] = "train"
|
||||
elif args.dataset_name == "prefix_repetition":
|
||||
dataset_cls = PrefixRepetitionRandomDataset
|
||||
sample_kwargs["prefix_len"] = args.prefix_repetition_prefix_len
|
||||
sample_kwargs["suffix_len"] = args.prefix_repetition_suffix_len
|
||||
sample_kwargs["num_prefixes"] = args.prefix_repetition_num_prefixes
|
||||
sample_kwargs["output_len"] = args.prefix_repetition_output_len
|
||||
elif args.dataset_name == "random-mm":
|
||||
dataset_cls = RandomMultiModalDataset
|
||||
# prefer random_* arguments, fall back to regular arguments
|
||||
random_input_len = getattr(args, "random_input_len", None)
|
||||
sample_kwargs["input_len"] = (
|
||||
random_input_len
|
||||
if random_input_len is not None
|
||||
else getattr(args, "input_len", None)
|
||||
)
|
||||
random_output_len = getattr(args, "random_output_len", None)
|
||||
sample_kwargs["output_len"] = (
|
||||
random_output_len
|
||||
if random_output_len is not None
|
||||
else getattr(args, "output_len", None)
|
||||
)
|
||||
sample_kwargs["base_items_per_request"] = getattr(
|
||||
args, "random_mm_base_items_per_request", None
|
||||
)
|
||||
sample_kwargs["num_mm_items_range_ratio"] = getattr(
|
||||
args, "random_mm_num_mm_items_range_ratio", None
|
||||
)
|
||||
sample_kwargs["limit_mm_per_prompt"] = getattr(
|
||||
args, "random_mm_limit_mm_per_prompt", None
|
||||
)
|
||||
sample_kwargs["bucket_config"] = getattr(args, "random_mm_bucket_config", None)
|
||||
sample_kwargs["enable_multimodal_chat"] = True
|
||||
random_prefix_len = getattr(args, "random_prefix_len", None)
|
||||
prefix_len = getattr(args, "prefix_len", None)
|
||||
sample_kwargs["prefix_len"] = (
|
||||
random_prefix_len if random_prefix_len is not None else prefix_len
|
||||
)
|
||||
sample_kwargs["range_ratio"] = args.random_range_ratio
|
||||
elif args.dataset_name == "random-rerank":
|
||||
dataset_cls = RandomDatasetForReranking
|
||||
# prefer random_* arguments, fall back to regular arguments
|
||||
random_input_len = getattr(args, "random_input_len", None)
|
||||
sample_kwargs["input_len"] = (
|
||||
random_input_len
|
||||
if random_input_len is not None
|
||||
else getattr(args, "input_len", None)
|
||||
)
|
||||
random_output_len = getattr(args, "random_output_len", None)
|
||||
sample_kwargs["output_len"] = (
|
||||
random_output_len
|
||||
if random_output_len is not None
|
||||
else getattr(args, "output_len", None)
|
||||
)
|
||||
sample_kwargs["batchsize"] = getattr(args, "random_batch_size", 1)
|
||||
sample_kwargs["is_reranker"] = not getattr(args, "no_reranker", False)
|
||||
sample_kwargs["range_ratio"] = args.random_range_ratio
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset name: {args.dataset_name}")
|
||||
# Remove None values
|
||||
sample_kwargs = {k: v for k, v in sample_kwargs.items() if v is not None}
|
||||
requests = dataset_cls(**common_kwargs).sample(**sample_kwargs)
|
||||
requests = filter_requests_for_dp(requests, args.data_parallel_size)
|
||||
return requests
|
||||
|
||||
|
||||
def filter_requests_for_dp(requests, data_parallel_size):
    # Note(zhuohan): The way we get data_parallel_rank is hacky and only
    # works for external launcher mode. Should be cleaned up and deprecated
    # in the future with a better vLLM distributed process design.
    if data_parallel_size == 1:
        return requests

    global_rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    data_parallel_rank = global_rank // (world_size // data_parallel_size)
    return [
        r
        for i, r in enumerate(requests)
        if i % data_parallel_size == data_parallel_rank
    ]
|
||||
|
||||
|
||||
def validate_args(args):
|
||||
"""
|
||||
Validate command-line arguments.
|
||||
"""
|
||||
|
||||
# === Deprecation and Defaulting ===
|
||||
if args.dataset is not None:
|
||||
warnings.warn(
|
||||
"The '--dataset' argument will be deprecated in the next release. "
|
||||
"Please use '--dataset-name' and '--dataset-path' instead.",
|
||||
stacklevel=2,
|
||||
)
|
||||
args.dataset_path = args.dataset
|
||||
|
||||
if not getattr(args, "tokenizer", None):
|
||||
args.tokenizer = args.model
|
||||
|
||||
# === Backend Validation ===
|
||||
valid_backends = {"vllm", "hf", "mii", "vllm-chat"}
|
||||
if args.backend not in valid_backends:
|
||||
raise ValueError(f"Unsupported backend: {args.backend}")
|
||||
|
||||
# === Dataset Configuration ===
|
||||
if (
|
||||
not args.dataset
|
||||
and not args.dataset_path
|
||||
and args.dataset_name not in {"prefix_repetition"}
|
||||
):
|
||||
print("When dataset path is not set, it will default to random dataset")
|
||||
args.dataset_name = "random"
|
||||
random_input_len = getattr(args, "random_input_len", None)
|
||||
if args.input_len is None and random_input_len is None:
|
||||
raise ValueError(
|
||||
"Either --input-len or --random-input-len must be provided "
|
||||
"for a random dataset"
|
||||
)
|
||||
|
||||
# === Dataset Name Specific Checks ===
|
||||
# --hf-subset and --hf-split: only used
|
||||
# when dataset_name is 'hf'
|
||||
if args.dataset_name != "hf" and (
|
||||
getattr(args, "hf_subset", None) is not None
|
||||
or getattr(args, "hf_split", None) is not None
|
||||
):
|
||||
warnings.warn(
|
||||
"--hf-subset and --hf-split will be ignored \
|
||||
since --dataset-name is not 'hf'.",
|
||||
stacklevel=2,
|
||||
)
|
||||
elif args.dataset_name == "hf":
|
||||
if args.dataset_path in (
|
||||
VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
|
||||
| MultiModalConversationDataset.SUPPORTED_DATASET_PATHS
|
||||
| ConversationDataset.SUPPORTED_DATASET_PATHS
|
||||
):
|
||||
assert args.backend == "vllm-chat", (
|
||||
f"{args.dataset_path} needs to use vllm-chat as the backend."
|
||||
)
|
||||
elif args.dataset_path in (
|
||||
InstructCoderDataset.SUPPORTED_DATASET_PATHS
|
||||
| AIMODataset.SUPPORTED_DATASET_PATHS
|
||||
):
|
||||
assert args.backend == "vllm", (
|
||||
f"{args.dataset_path} needs to use vllm as the backend."
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"{args.dataset_path} is not supported by hf dataset.")
|
||||
|
||||
# --random-range-ratio: only used when dataset_name is 'random',
|
||||
# 'random-mm', or 'random-rerank'
|
||||
if (
|
||||
args.dataset_name not in {"random", "random-mm", "random-rerank"}
|
||||
and args.random_range_ratio is not None
|
||||
):
|
||||
warnings.warn(
|
||||
"--random-range-ratio will be ignored since \
|
||||
--dataset-name is not 'random', 'random-mm', or 'random-rerank'.",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
# --random-batch-size: only used when dataset_name is 'random-rerank'
|
||||
if (
|
||||
args.dataset_name != "random-rerank"
|
||||
and getattr(args, "random_batch_size", None) is not None
|
||||
) and args.random_batch_size != 1:
|
||||
warnings.warn(
|
||||
"--random-batch-size will be ignored since \
|
||||
--dataset-name is not 'random-rerank'.",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
# --no-reranker: only used when dataset_name is 'random-rerank'
|
||||
if args.dataset_name != "random-rerank" and getattr(args, "no_reranker", False):
|
||||
warnings.warn(
|
||||
"--no-reranker will be ignored since \
|
||||
--dataset-name is not 'random-rerank'.",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
# --prefix-len: only used when dataset_name is 'random', 'random-mm',
|
||||
# 'sonnet', or not set.
|
||||
if (
|
||||
args.dataset_name not in {"random", "random-mm", "sonnet", None}
|
||||
and args.prefix_len is not None
|
||||
):
|
||||
warnings.warn(
|
||||
"--prefix-len will be ignored since --dataset-name\
|
||||
is not 'random', 'random-mm', 'sonnet', or not set.",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
# === Random Dataset Argument Conflict Detection ===
|
||||
# Check for conflicts between regular and random arguments when using
|
||||
# random datasets
|
||||
if args.dataset_name in {"random", "random-mm", "random-rerank"}:
|
||||
random_input_len = getattr(args, "random_input_len", None)
|
||||
random_output_len = getattr(args, "random_output_len", None)
|
||||
random_prefix_len = getattr(args, "random_prefix_len", None)
|
||||
|
||||
if args.input_len is not None and random_input_len is not None:
|
||||
warnings.warn(
|
||||
"Both --input-len and --random-input-len are specified. "
|
||||
"The random version (--random-input-len) will be preferred "
|
||||
"in this run.",
|
||||
stacklevel=2,
|
||||
)
|
||||
if args.output_len is not None and random_output_len is not None:
|
||||
warnings.warn(
|
||||
"Both --output-len and --random-output-len are specified. "
|
||||
"The random version (--random-output-len) will be preferred "
|
||||
"in this run.",
|
||||
stacklevel=2,
|
||||
)
|
||||
if args.prefix_len is not None and random_prefix_len is not None:
|
||||
warnings.warn(
|
||||
"Both --prefix-len and --random-prefix-len are specified. "
|
||||
"The random version (--random-prefix-len) will be preferred "
|
||||
"in this run.",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
# === LoRA Settings ===
|
||||
if getattr(args, "enable_lora", False) and args.backend != "vllm":
|
||||
raise ValueError("LoRA benchmarking is only supported for vLLM backend")
|
||||
if getattr(args, "enable_lora", False) and args.lora_path is None:
|
||||
raise ValueError("LoRA path must be provided when enable_lora is True")
|
||||
|
||||
# === Backend-specific Validations ===
|
||||
if args.backend == "hf" and args.hf_max_batch_size is None:
|
||||
raise ValueError("HF max batch size is required for HF backend")
|
||||
if args.backend != "hf" and args.hf_max_batch_size is not None:
|
||||
raise ValueError("HF max batch size is only for HF backend.")
|
||||
|
||||
if (
|
||||
args.backend in {"hf", "mii"}
|
||||
and getattr(args, "quantization", None) is not None
|
||||
):
|
||||
raise ValueError("Quantization is only for vLLM backend.")
|
||||
|
||||
if args.backend == "mii" and args.dtype != "auto":
|
||||
raise ValueError("dtype must be auto for MII backend.")
|
||||
if args.backend == "mii" and args.n != 1:
|
||||
raise ValueError("n must be 1 for MII backend.")
|
||||
if args.backend == "mii" and args.tokenizer != args.model:
|
||||
raise ValueError("Tokenizer must be the same as the model for MII backend.")
|
||||
|
||||
if args.data_parallel_size > 1 and (
|
||||
args.distributed_executor_backend != "external_launcher" or args.async_engine
|
||||
):
|
||||
# --data-parallel is not supported fully.
|
||||
# Old issue: https://github.com/vllm-project/vllm/issues/16222
|
||||
# Currently we only support data parallel with external launcher
|
||||
# mode (i.e., launch with torchrun).
|
||||
raise ValueError(
|
||||
"Data parallel is only supported with external launcher mode "
|
||||
"with synchronous engine in offline benchmark, "
|
||||
"please use benchmark serving instead"
|
||||
)
|
||||
|
||||
|
||||
def add_cli_args(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
type=str,
|
||||
choices=["vllm", "hf", "mii", "vllm-chat"],
|
||||
default="vllm",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-name",
|
||||
type=str,
|
||||
choices=[
|
||||
"sharegpt",
|
||||
"random",
|
||||
"sonnet",
|
||||
"burstgpt",
|
||||
"hf",
|
||||
"prefix_repetition",
|
||||
"random-mm",
|
||||
"random-rerank",
|
||||
],
|
||||
help="Name of the dataset to benchmark on.",
|
||||
default="sharegpt",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the ShareGPT dataset, will be deprecated in\
|
||||
the next release. The dataset is expected to "
|
||||
"be a json in form of list[dict[..., conversations: "
|
||||
"list[dict[..., value: <prompt_or_response>]]]]",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-path", type=str, default=None, help="Path to the dataset"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Input prompt length for each request",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Output length for each request. Overrides the "
|
||||
"output length from the dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n", type=int, default=1, help="Number of generated sequences per prompt."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-prompts", type=int, default=1000, help="Number of prompts to process."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf-max-batch-size",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Maximum batch size for HF backend.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-json",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to save the throughput results in JSON format.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--async-engine",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use vLLM async engine rather than LLM class.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-frontend-multiprocessing",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Disable decoupled async engine frontend.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-detokenize",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Do not detokenize the response (i.e. do not include "
|
||||
"detokenization time in the measurement)"
|
||||
),
|
||||
)
|
||||
# LoRA
|
||||
parser.add_argument(
|
||||
"--lora-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the lora adapters to use. This can be an absolute path, "
|
||||
"a relative path, or a Hugging Face model identifier.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prefix-len",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Number of fixed prefix tokens before the random "
|
||||
"context in a request (default: 0).",
|
||||
)
|
||||
|
||||
# hf dataset
|
||||
parser.add_argument(
|
||||
"--hf-subset",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Subset of the HF dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf-split",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Split of the HF dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--profile",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use vLLM Profiling. --profiler-config must be provided on the server.",
|
||||
)
|
||||
|
||||
# prefix repetition dataset
|
||||
parser.add_argument(
|
||||
"--prefix-repetition-prefix-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Number of prefix tokens per request, used only for prefix "
|
||||
"repetition dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prefix-repetition-suffix-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Number of suffix tokens per request, used only for prefix "
|
||||
"repetition dataset. Total input length is prefix_len + suffix_len.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prefix-repetition-num-prefixes",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Number of prefixes to generate, used only for prefix repetition "
|
||||
"dataset. Prompts per prefix is num_requests // num_prefixes.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prefix-repetition-output-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Number of output tokens per request, used only for prefix "
|
||||
"repetition dataset.",
|
||||
)
|
||||
|
||||
# (random, random-mm, random-rerank)
|
||||
add_random_dataset_base_args(parser)
|
||||
add_random_multimodal_dataset_args(parser)
|
||||
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
validate_args(args)
|
||||
if args.seed is None:
|
||||
args.seed = 0
|
||||
random.seed(args.seed)
|
||||
# Sample the requests.
|
||||
if (
|
||||
args.backend == "hf" or args.backend == "mii"
|
||||
) and args.tokenizer_mode == "auto":
|
||||
# mistral_common tokenizer is only supported on vllm and vllm-chat backends;
|
||||
# for hf and mii backends, we use hf tokenizer
|
||||
args.tokenizer_mode = "hf"
|
||||
tokenizer = get_tokenizer(
|
||||
args.tokenizer,
|
||||
tokenizer_mode=args.tokenizer_mode,
|
||||
trust_remote_code=args.trust_remote_code,
|
||||
)
|
||||
requests = get_requests(args, tokenizer)
|
||||
is_multi_modal = any(request.multi_modal_data is not None for request in requests)
|
||||
request_outputs: list[RequestOutput] | None = None
|
||||
if args.backend == "vllm":
|
||||
if args.async_engine:
|
||||
elapsed_time = uvloop.run(
|
||||
run_vllm_async(
|
||||
requests,
|
||||
args.n,
|
||||
AsyncEngineArgs.from_cli_args(args),
|
||||
disable_frontend_multiprocessing=args.disable_frontend_multiprocessing,
|
||||
disable_detokenize=args.disable_detokenize,
|
||||
do_profile=args.profile,
|
||||
)
|
||||
)
|
||||
else:
|
||||
elapsed_time, request_outputs = run_vllm(
|
||||
requests,
|
||||
args.n,
|
||||
EngineArgs.from_cli_args(args),
|
||||
disable_detokenize=args.disable_detokenize,
|
||||
do_profile=args.profile,
|
||||
)
|
||||
elif args.backend == "hf":
|
||||
assert args.tensor_parallel_size == 1
|
||||
if args.profile:
|
||||
raise NotImplementedError("Profiling not implemented yet for backend='hf'.")
|
||||
elapsed_time = run_hf(
|
||||
requests,
|
||||
args.model,
|
||||
tokenizer,
|
||||
args.n,
|
||||
args.hf_max_batch_size,
|
||||
args.trust_remote_code,
|
||||
args.disable_detokenize,
|
||||
)
|
||||
elif args.backend == "vllm-chat":
|
||||
elapsed_time, request_outputs = run_vllm_chat(
|
||||
requests,
|
||||
args.n,
|
||||
EngineArgs.from_cli_args(args),
|
||||
disable_detokenize=args.disable_detokenize,
|
||||
do_profile=args.profile,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown backend: {args.backend}")
|
||||
|
||||
if request_outputs:
|
||||
# Note: with the vllm and vllm-chat backends,
|
||||
# we have request_outputs, which we use to count tokens.
|
||||
total_prompt_tokens = 0
|
||||
total_output_tokens = 0
|
||||
for ro in request_outputs:
|
||||
if not isinstance(ro, RequestOutput):
|
||||
continue
|
||||
total_prompt_tokens += (
|
||||
len(ro.prompt_token_ids) if ro.prompt_token_ids else 0
|
||||
)
|
||||
total_output_tokens += sum(len(o.token_ids) for o in ro.outputs if o)
|
||||
total_num_tokens = total_prompt_tokens + total_output_tokens
|
||||
else:
|
||||
total_num_tokens = sum(r.prompt_len + r.expected_output_len for r in requests)
|
||||
total_output_tokens = sum(r.expected_output_len for r in requests)
|
||||
total_prompt_tokens = total_num_tokens - total_output_tokens
|
||||
|
||||
if is_multi_modal and args.backend != "vllm-chat":
|
||||
print(
|
||||
"\033[91mWARNING\033[0m: Multi-modal request with "
|
||||
f"{args.backend} backend detected. The "
|
||||
"following metrics are not accurate because image tokens are not"
|
||||
" counted. See vllm-project/vllm/issues/9778 for details."
|
||||
)
|
||||
# TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
|
||||
# vllm-chat backend counts the image tokens now
|
||||
|
||||
print(
|
||||
f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
|
||||
f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
|
||||
f"{total_output_tokens / elapsed_time:.2f} output tokens/s"
|
||||
)
|
||||
print(f"Total num prompt tokens: {total_prompt_tokens}")
|
||||
print(f"Total num output tokens: {total_output_tokens}")
|
||||
|
||||
# Output JSON results if specified
|
||||
if args.output_json:
|
||||
results = {
|
||||
"elapsed_time": elapsed_time,
|
||||
"num_requests": len(requests),
|
||||
"total_num_tokens": total_num_tokens,
|
||||
"requests_per_second": len(requests) / elapsed_time,
|
||||
"tokens_per_second": total_num_tokens / elapsed_time,
|
||||
}
|
||||
with open(args.output_json, "w") as f:
|
||||
json.dump(results, f, indent=4)
|
||||
save_to_pytorch_benchmark_format(args, results)
|
||||
851
vllm/collect_env.py
Normal file
@@ -0,0 +1,851 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# ruff: noqa
|
||||
# code borrowed from https://github.com/pytorch/pytorch/blob/main/torch/utils/collect_env.py
|
||||
|
||||
import datetime
|
||||
import locale
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
# Unlike the rest of PyTorch, this file must be Python 2 compliant.
|
||||
# This script outputs relevant system environment info
|
||||
# Run it with `python collect_env.py` or `python -m torch.utils.collect_env`
|
||||
from collections import namedtuple
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.envs import environment_variables
|
||||
|
||||
try:
|
||||
import torch
|
||||
|
||||
TORCH_AVAILABLE = True
|
||||
except (ImportError, NameError, AttributeError, OSError):
|
||||
TORCH_AVAILABLE = False
|
||||
|
||||
# System Environment Information
|
||||
SystemEnv = namedtuple(
|
||||
"SystemEnv",
|
||||
[
|
||||
"torch_version",
|
||||
"is_debug_build",
|
||||
"cuda_compiled_version",
|
||||
"gcc_version",
|
||||
"clang_version",
|
||||
"cmake_version",
|
||||
"os",
|
||||
"libc_version",
|
||||
"python_version",
|
||||
"python_platform",
|
||||
"is_cuda_available",
|
||||
"cuda_runtime_version",
|
||||
"cuda_module_loading",
|
||||
"nvidia_driver_version",
|
||||
"nvidia_gpu_models",
|
||||
"cudnn_version",
|
||||
"pip_version", # 'pip' or 'pip3'
|
||||
"pip_packages",
|
||||
"conda_packages",
|
||||
"hip_compiled_version",
|
||||
"hip_runtime_version",
|
||||
"miopen_runtime_version",
|
||||
"caching_allocator_config",
|
||||
"is_xnnpack_available",
|
||||
"cpu_info",
|
||||
"rocm_version", # vllm specific field
|
||||
"vllm_version", # vllm specific field
|
||||
"vllm_build_flags", # vllm specific field
|
||||
"gpu_topo", # vllm specific field
|
||||
"env_vars",
|
||||
],
|
||||
)
|
||||
|
||||
DEFAULT_CONDA_PATTERNS = {
|
||||
"torch",
|
||||
"numpy",
|
||||
"cudatoolkit",
|
||||
"soumith",
|
||||
"mkl",
|
||||
"magma",
|
||||
"triton",
|
||||
"optree",
|
||||
"nccl",
|
||||
"transformers",
|
||||
"zmq",
|
||||
"nvidia",
|
||||
"pynvml",
|
||||
"flashinfer-python",
|
||||
"helion",
|
||||
}
|
||||
|
||||
DEFAULT_PIP_PATTERNS = {
|
||||
"torch",
|
||||
"numpy",
|
||||
"mypy",
|
||||
"flake8",
|
||||
"triton",
|
||||
"optree",
|
||||
"onnx",
|
||||
"nccl",
|
||||
"transformers",
|
||||
"zmq",
|
||||
"nvidia",
|
||||
"pynvml",
|
||||
"flashinfer-python",
|
||||
"helion",
|
||||
}
|
||||
|
||||
|
||||
def run(command):
|
||||
"""Return (return-code, stdout, stderr)."""
|
||||
shell = True if type(command) is str else False
|
||||
try:
|
||||
p = subprocess.Popen(
|
||||
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=shell
|
||||
)
|
||||
raw_output, raw_err = p.communicate()
|
||||
rc = p.returncode
|
||||
if get_platform() == "win32":
|
||||
enc = "oem"
|
||||
else:
|
||||
enc = locale.getpreferredencoding()
|
||||
output = raw_output.decode(enc)
|
||||
if command == "nvidia-smi topo -m":
|
||||
# don't remove the leading whitespace of `nvidia-smi topo -m`
|
||||
# because they are meaningful
|
||||
output = output.rstrip()
|
||||
else:
|
||||
output = output.strip()
|
||||
err = raw_err.decode(enc)
|
||||
return rc, output, err.strip()
|
||||
|
||||
except FileNotFoundError:
|
||||
cmd_str = command if isinstance(command, str) else command[0]
|
||||
return 127, "", f"Command not found: {cmd_str}"
|
||||
|
||||
|
||||
def run_and_read_all(run_lambda, command):
    """Run command using run_lambda; reads and returns entire output if rc is 0."""
    rc, out, _ = run_lambda(command)
    if rc != 0:
        return None
    return out
|
||||
|
||||
|
||||
def run_and_parse_first_match(run_lambda, command, regex):
    """Run command using run_lambda, returns the first regex match if it exists."""
    rc, out, _ = run_lambda(command)
    if rc != 0:
        return None
    match = re.search(regex, out)
    if match is None:
        return None
    return match.group(1)
|
||||
|
||||
|
||||
def get_conda_packages(run_lambda, patterns=None):
|
||||
if patterns is None:
|
||||
patterns = DEFAULT_CONDA_PATTERNS
|
||||
conda = os.environ.get("CONDA_EXE", "conda")
|
||||
out = run_and_read_all(run_lambda, [conda, "list"])
|
||||
if out is None:
|
||||
return out
|
||||
|
||||
return "\n".join(
|
||||
line
|
||||
for line in out.splitlines()
|
||||
if not line.startswith("#") and any(name in line for name in patterns)
|
||||
)
|
||||
|
||||
|
||||
def get_gcc_version(run_lambda):
    return run_and_parse_first_match(run_lambda, "gcc --version", r"gcc (.*)")
|
||||
|
||||
|
||||
def get_clang_version(run_lambda):
    return run_and_parse_first_match(
        run_lambda, "clang --version", r"clang version (.*)"
    )
|
||||
|
||||
|
||||
def get_cmake_version(run_lambda):
    return run_and_parse_first_match(run_lambda, "cmake --version", r"cmake (.*)")
|
||||
|
||||
|
||||
def get_nvidia_driver_version(run_lambda):
|
||||
if get_platform() == "darwin":
|
||||
cmd = "kextstat | grep -i cuda"
|
||||
return run_and_parse_first_match(
|
||||
run_lambda, cmd, r"com[.]nvidia[.]CUDA [(](.*?)[)]"
|
||||
)
|
||||
smi = get_nvidia_smi()
|
||||
return run_and_parse_first_match(run_lambda, smi, r"Driver Version: (.*?) ")
|
||||
|
||||
|
||||
def get_gpu_info(run_lambda):
|
||||
if get_platform() == "darwin" or (
|
||||
TORCH_AVAILABLE
|
||||
and hasattr(torch.version, "hip")
|
||||
and torch.version.hip is not None
|
||||
):
|
||||
if TORCH_AVAILABLE and torch.cuda.is_available():
|
||||
if torch.version.hip is not None:
|
||||
prop = torch.cuda.get_device_properties(0)
|
||||
if hasattr(prop, "gcnArchName"):
|
||||
gcnArch = " ({})".format(prop.gcnArchName)
|
||||
else:
|
||||
gcnArch = "NoGCNArchNameOnOldPyTorch"
|
||||
else:
|
||||
gcnArch = ""
|
||||
return torch.cuda.get_device_name(None) + gcnArch
|
||||
return None
|
||||
smi = get_nvidia_smi()
|
||||
uuid_regex = re.compile(r" \(UUID: .+?\)")
|
||||
rc, out, _ = run_lambda(smi + " -L")
|
||||
if rc != 0:
|
||||
return None
|
||||
# Anonymize GPUs by removing their UUID
|
||||
return re.sub(uuid_regex, "", out)
|
||||
|
||||
|
||||
def get_running_cuda_version(run_lambda):
|
||||
return run_and_parse_first_match(run_lambda, "nvcc --version", r"release .+ V(.*)")
|
||||
|
||||
|
||||
def get_cudnn_version(run_lambda):
|
||||
"""Return a list of libcudnn.so; it's hard to tell which one is being used."""
|
||||
if get_platform() == "win32":
|
||||
system_root = os.environ.get("SYSTEMROOT", "C:\\Windows")
|
||||
cuda_path = os.environ.get("CUDA_PATH", "%CUDA_PATH%")
|
||||
where_cmd = os.path.join(system_root, "System32", "where")
|
||||
cudnn_cmd = '{} /R "{}\\bin" cudnn*.dll'.format(where_cmd, cuda_path)
|
||||
elif get_platform() == "darwin":
|
||||
# CUDA libraries and drivers can be found in /usr/local/cuda/. See
|
||||
# https://docs.nvidia.com/cuda/cuda-installation-guide-mac-os-x/index.html#install
|
||||
# https://docs.nvidia.com/deeplearning/sdk/cudnn-install/index.html#installmac
|
||||
# Use CUDNN_LIBRARY when cudnn library is installed elsewhere.
|
||||
cudnn_cmd = "ls /usr/local/cuda/lib/libcudnn*"
|
||||
else:
|
||||
cudnn_cmd = 'ldconfig -p | grep libcudnn | rev | cut -d" " -f1 | rev'
|
||||
rc, out, _ = run_lambda(cudnn_cmd)
|
||||
# find will return 1 if there are permission errors or if not found
|
||||
if len(out) == 0 or (rc != 1 and rc != 0):
|
||||
l = os.environ.get("CUDNN_LIBRARY")
|
||||
if l is not None and os.path.isfile(l):
|
||||
return os.path.realpath(l)
|
||||
return None
|
||||
files_set = set()
|
||||
for fn in out.split("\n"):
|
||||
fn = os.path.realpath(fn) # eliminate symbolic links
|
||||
if os.path.isfile(fn):
|
||||
files_set.add(fn)
|
||||
if not files_set:
|
||||
return None
|
||||
# Alphabetize the result because the order is non-deterministic otherwise
|
||||
files = sorted(files_set)
|
||||
if len(files) == 1:
|
||||
return files[0]
|
||||
result = "\n".join(files)
|
||||
return "Probably one of the following:\n{}".format(result)
|
||||
|
||||
|
||||
def get_nvidia_smi():
|
||||
# Note: nvidia-smi is currently available only on Windows and Linux
|
||||
smi = "nvidia-smi"
|
||||
if get_platform() == "win32":
|
||||
system_root = os.environ.get("SYSTEMROOT", "C:\\Windows")
|
||||
program_files_root = os.environ.get("PROGRAMFILES", "C:\\Program Files")
|
||||
legacy_path = os.path.join(
|
||||
program_files_root, "NVIDIA Corporation", "NVSMI", smi
|
||||
)
|
||||
new_path = os.path.join(system_root, "System32", smi)
|
||||
smis = [new_path, legacy_path]
|
||||
for candidate_smi in smis:
|
||||
if os.path.exists(candidate_smi):
|
||||
smi = '"{}"'.format(candidate_smi)
|
||||
break
|
||||
return smi
|
||||
|
||||
|
||||
def get_rocm_version(run_lambda):
    """Returns the ROCm version if available, otherwise 'N/A'."""
    return run_and_parse_first_match(
        run_lambda, "hipcc --version", r"HIP version: (\S+)"
    )
|
||||
|
||||
|
||||
def get_vllm_version():
|
||||
from vllm import __version__, __version_tuple__
|
||||
|
||||
if __version__ == "dev":
|
||||
return "N/A (dev)"
|
||||
version_str = __version_tuple__[-1]
|
||||
if isinstance(version_str, str) and version_str.startswith("g"):
|
||||
# it's a dev build
|
||||
if "." in version_str:
|
||||
# it's a dev build containing local changes
|
||||
git_sha = version_str.split(".")[0][1:]
|
||||
date = version_str.split(".")[-1][1:]
|
||||
return f"{__version__} (git sha: {git_sha}, date: {date})"
|
||||
else:
|
||||
# it's a dev build without local changes
|
||||
git_sha = version_str[1:] # type: ignore
|
||||
return f"{__version__} (git sha: {git_sha})"
|
||||
return __version__
|
||||
|
||||
|
||||
def summarize_vllm_build_flags():
|
||||
# This could be a static method if the flags are constant, or dynamic if you need to check environment variables, etc.
|
||||
return "CUDA Archs: {}; ROCm: {}".format(
|
||||
os.environ.get("TORCH_CUDA_ARCH_LIST", "Not Set"),
|
||||
"Enabled" if os.environ.get("ROCM_HOME") else "Disabled",
|
||||
)
|
||||
|
||||
|
||||
def get_gpu_topo(run_lambda):
|
||||
output = None
|
||||
|
||||
if get_platform() == "linux":
|
||||
output = run_and_read_all(run_lambda, "nvidia-smi topo -m")
|
||||
if output is None:
|
||||
output = run_and_read_all(run_lambda, "rocm-smi --showtopo")
|
||||
|
||||
return output
|
||||
|
||||
|
||||
# example outputs of CPU infos
|
||||
# * linux
|
||||
# Architecture: x86_64
|
||||
# CPU op-mode(s): 32-bit, 64-bit
|
||||
# Address sizes: 46 bits physical, 48 bits virtual
|
||||
# Byte Order: Little Endian
|
||||
# CPU(s): 128
|
||||
# On-line CPU(s) list: 0-127
|
||||
# Vendor ID: GenuineIntel
|
||||
# Model name: Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz
|
||||
# CPU family: 6
|
||||
# Model: 106
|
||||
# Thread(s) per core: 2
|
||||
# Core(s) per socket: 32
|
||||
# Socket(s): 2
|
||||
# Stepping: 6
|
||||
# BogoMIPS: 5799.78
|
||||
# Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr
|
||||
# sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl
|
||||
# xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16
|
||||
# pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand
|
||||
# hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced
|
||||
# fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap
|
||||
# avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1
|
||||
# xsaves wbnoinvd ida arat avx512vbmi pku ospke avx512_vbmi2 gfni vaes vpclmulqdq
|
||||
# avx512_vnni avx512_bitalg tme avx512_vpopcntdq rdpid md_clear flush_l1d arch_capabilities
|
||||
# Virtualization features:
|
||||
# Hypervisor vendor: KVM
|
||||
# Virtualization type: full
|
||||
# Caches (sum of all):
|
||||
# L1d: 3 MiB (64 instances)
|
||||
# L1i: 2 MiB (64 instances)
|
||||
# L2: 80 MiB (64 instances)
|
||||
# L3: 108 MiB (2 instances)
|
||||
# NUMA:
|
||||
# NUMA node(s): 2
|
||||
# NUMA node0 CPU(s): 0-31,64-95
|
||||
# NUMA node1 CPU(s): 32-63,96-127
|
||||
# Vulnerabilities:
|
||||
# Itlb multihit: Not affected
|
||||
# L1tf: Not affected
|
||||
# Mds: Not affected
|
||||
# Meltdown: Not affected
|
||||
# Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
|
||||
# Retbleed: Not affected
|
||||
# Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
|
||||
# Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
|
||||
# Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
|
||||
# Srbds: Not affected
|
||||
# Tsx async abort: Not affected
|
||||
# * win32
|
||||
# Architecture=9
|
||||
# CurrentClockSpeed=2900
|
||||
# DeviceID=CPU0
|
||||
# Family=179
|
||||
# L2CacheSize=40960
|
||||
# L2CacheSpeed=
|
||||
# Manufacturer=GenuineIntel
|
||||
# MaxClockSpeed=2900
|
||||
# Name=Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz
|
||||
# ProcessorType=3
|
||||
# Revision=27142
|
||||
#
|
||||
# Architecture=9
|
||||
# CurrentClockSpeed=2900
|
||||
# DeviceID=CPU1
|
||||
# Family=179
|
||||
# L2CacheSize=40960
|
||||
# L2CacheSpeed=
|
||||
# Manufacturer=GenuineIntel
|
||||
# MaxClockSpeed=2900
|
||||
# Name=Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz
|
||||
# ProcessorType=3
|
||||
# Revision=27142
|
||||
|
||||
|
||||
def get_cpu_info(run_lambda):
|
||||
rc, out, err = 0, "", ""
|
||||
if get_platform() == "linux":
|
||||
rc, out, err = run_lambda("lscpu")
|
||||
elif get_platform() == "win32":
|
||||
rc, out, err = run_lambda(
|
||||
"wmic cpu get Name,Manufacturer,Family,Architecture,ProcessorType,DeviceID, \
|
||||
CurrentClockSpeed,MaxClockSpeed,L2CacheSize,L2CacheSpeed,Revision /VALUE"
|
||||
)
|
||||
elif get_platform() == "darwin":
|
||||
rc, out, err = run_lambda("sysctl -n machdep.cpu.brand_string")
|
||||
cpu_info = "None"
|
||||
if rc == 0:
|
||||
cpu_info = out
|
||||
else:
|
||||
cpu_info = err
|
||||
return cpu_info
|
||||
|
||||
|
||||
def get_platform():
    if sys.platform.startswith("linux"):
        return "linux"
    elif sys.platform.startswith("win32"):
        return "win32"
    elif sys.platform.startswith("cygwin"):
        return "cygwin"
    elif sys.platform.startswith("darwin"):
        return "darwin"
    else:
        return sys.platform
|
||||
|
||||
|
||||
def get_mac_version(run_lambda):
|
||||
return run_and_parse_first_match(run_lambda, "sw_vers -productVersion", r"(.*)")
|
||||
|
||||
|
||||
def get_windows_version(run_lambda):
|
||||
system_root = os.environ.get("SYSTEMROOT", "C:\\Windows")
|
||||
wmic_cmd = os.path.join(system_root, "System32", "Wbem", "wmic")
|
||||
findstr_cmd = os.path.join(system_root, "System32", "findstr")
|
||||
return run_and_read_all(
|
||||
run_lambda, "{} os get Caption | {} /v Caption".format(wmic_cmd, findstr_cmd)
|
||||
)
|
||||
|
||||
|
||||
def get_lsb_version(run_lambda):
|
||||
return run_and_parse_first_match(
|
||||
run_lambda, "lsb_release -a", r"Description:\t(.*)"
|
||||
)
|
||||
|
||||
|
||||
def check_release_file(run_lambda):
|
||||
return run_and_parse_first_match(
|
||||
run_lambda, "cat /etc/*-release", r'PRETTY_NAME="(.*)"'
|
||||
)
|
||||
|
||||
|
||||
def get_os(run_lambda):
|
||||
from platform import machine
|
||||
|
||||
platform = get_platform()
|
||||
|
||||
if platform == "win32" or platform == "cygwin":
|
||||
return get_windows_version(run_lambda)
|
||||
|
||||
if platform == "darwin":
|
||||
version = get_mac_version(run_lambda)
|
||||
if version is None:
|
||||
return None
|
||||
return "macOS {} ({})".format(version, machine())
|
||||
|
||||
if platform == "linux":
|
||||
# Ubuntu/Debian based
|
||||
desc = get_lsb_version(run_lambda)
|
||||
if desc is not None:
|
||||
return "{} ({})".format(desc, machine())
|
||||
|
||||
# Try reading /etc/*-release
|
||||
desc = check_release_file(run_lambda)
|
||||
if desc is not None:
|
||||
return "{} ({})".format(desc, machine())
|
||||
|
||||
return "{} ({})".format(platform, machine())
|
||||
|
||||
# Unknown platform
|
||||
return platform
|
||||
|
||||
|
||||
def get_python_platform():
    import platform

    return platform.platform()
|
||||
|
||||
|
||||
def get_libc_version():
    import platform

    if get_platform() != "linux":
        return "N/A"
    return "-".join(platform.libc_ver())
|
||||
|
||||
|
||||
def is_uv_venv():
    if os.environ.get("UV"):
        return True
    pyvenv_cfg_path = os.path.join(sys.prefix, "pyvenv.cfg")
    if os.path.exists(pyvenv_cfg_path):
        with open(pyvenv_cfg_path, "r") as f:
            return any(line.startswith("uv = ") for line in f)
    return False
|
||||
|
||||
|
||||
def get_pip_packages(run_lambda, patterns=None):
|
||||
"""Return `pip list` output. Note: will also find conda-installed pytorch and numpy packages."""
|
||||
if patterns is None:
|
||||
patterns = DEFAULT_PIP_PATTERNS
|
||||
|
||||
def run_with_pip():
|
||||
try:
|
||||
import importlib.util
|
||||
|
||||
pip_spec = importlib.util.find_spec("pip")
|
||||
pip_available = pip_spec is not None
|
||||
except ImportError:
|
||||
pip_available = False
|
||||
|
||||
if pip_available:
|
||||
cmd = [sys.executable, "-mpip", "list", "--format=freeze"]
|
||||
elif is_uv_venv():
|
||||
print("uv is set")
|
||||
cmd = ["uv", "pip", "list", "--format=freeze"]
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Could not collect pip list output (pip or uv module not available)"
|
||||
)
|
||||
|
||||
out = run_and_read_all(run_lambda, cmd)
|
||||
return "\n".join(
|
||||
line for line in out.splitlines() if any(name in line for name in patterns)
|
||||
)
|
||||
|
||||
pip_version = "pip3" if sys.version[0] == "3" else "pip"
|
||||
out = run_with_pip()
|
||||
return pip_version, out
|
||||
|
||||
|
||||
def get_cachingallocator_config():
    ca_config = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "")
    return ca_config
|
||||
|
||||
|
||||
def get_cuda_module_loading_config():
    if TORCH_AVAILABLE and torch.cuda.is_available():
        torch.cuda.init()
        config = os.environ.get("CUDA_MODULE_LOADING", "")
        return config
    else:
        return "N/A"
|
||||
|
||||
|
||||
def is_xnnpack_available():
    if TORCH_AVAILABLE:
        import torch.backends.xnnpack

        return str(torch.backends.xnnpack.enabled)  # type: ignore[attr-defined]
    else:
        return "N/A"
|
||||
|
||||
|
||||
def get_env_vars():
|
||||
env_vars = ""
|
||||
secret_terms = ("secret", "token", "api", "access", "password")
|
||||
report_prefix = (
|
||||
"TORCH",
|
||||
"NCCL",
|
||||
"PYTORCH",
|
||||
"CUDA",
|
||||
"CUBLAS",
|
||||
"CUDNN",
|
||||
"OMP_",
|
||||
"MKL_",
|
||||
"NVIDIA",
|
||||
)
|
||||
for k, v in os.environ.items():
|
||||
if any(term in k.lower() for term in secret_terms):
|
||||
continue
|
||||
if k in environment_variables:
|
||||
env_vars = env_vars + "{}={}".format(k, v) + "\n"
|
||||
if k.startswith(report_prefix):
|
||||
env_vars = env_vars + "{}={}".format(k, v) + "\n"
|
||||
|
||||
return env_vars
|
||||
|
||||
|
||||
def get_env_info():
|
||||
run_lambda = run
|
||||
pip_version, pip_list_output = get_pip_packages(run_lambda)
|
||||
|
||||
if TORCH_AVAILABLE:
|
||||
version_str = torch.__version__
|
||||
debug_mode_str = str(torch.version.debug)
|
||||
cuda_available_str = str(torch.cuda.is_available())
|
||||
cuda_version_str = torch.version.cuda
|
||||
if (
|
||||
not hasattr(torch.version, "hip") or torch.version.hip is None
|
||||
): # cuda version
|
||||
hip_compiled_version = hip_runtime_version = miopen_runtime_version = "N/A"
|
||||
else: # HIP version
|
||||
|
||||
def get_version_or_na(cfg, prefix):
|
||||
_lst = [s.rsplit(None, 1)[-1] for s in cfg if prefix in s]
|
||||
return _lst[0] if _lst else "N/A"
|
||||
|
||||
cfg = torch._C._show_config().split("\n")
|
||||
hip_runtime_version = get_version_or_na(cfg, "HIP Runtime")
|
||||
miopen_runtime_version = get_version_or_na(cfg, "MIOpen")
|
||||
cuda_version_str = "N/A"
|
||||
hip_compiled_version = torch.version.hip
|
||||
else:
|
||||
version_str = debug_mode_str = cuda_available_str = cuda_version_str = "N/A"
|
||||
hip_compiled_version = hip_runtime_version = miopen_runtime_version = "N/A"
|
||||
|
||||
sys_version = sys.version.replace("\n", " ")
|
||||
|
||||
conda_packages = get_conda_packages(run_lambda)
|
||||
|
||||
rocm_version = get_rocm_version(run_lambda)
|
||||
vllm_version = get_vllm_version()
|
||||
vllm_build_flags = summarize_vllm_build_flags()
|
||||
gpu_topo = get_gpu_topo(run_lambda)
|
||||
|
||||
return SystemEnv(
|
||||
torch_version=version_str,
|
||||
is_debug_build=debug_mode_str,
|
||||
python_version="{} ({}-bit runtime)".format(
|
||||
sys_version, sys.maxsize.bit_length() + 1
|
||||
),
|
||||
python_platform=get_python_platform(),
|
||||
is_cuda_available=cuda_available_str,
|
||||
cuda_compiled_version=cuda_version_str,
|
||||
cuda_runtime_version=get_running_cuda_version(run_lambda),
|
||||
cuda_module_loading=get_cuda_module_loading_config(),
|
||||
nvidia_gpu_models=get_gpu_info(run_lambda),
|
||||
nvidia_driver_version=get_nvidia_driver_version(run_lambda),
|
||||
cudnn_version=get_cudnn_version(run_lambda),
|
||||
hip_compiled_version=hip_compiled_version,
|
||||
hip_runtime_version=hip_runtime_version,
|
||||
miopen_runtime_version=miopen_runtime_version,
|
||||
pip_version=pip_version,
|
||||
pip_packages=pip_list_output,
|
||||
conda_packages=conda_packages,
|
||||
os=get_os(run_lambda),
|
||||
libc_version=get_libc_version(),
|
||||
gcc_version=get_gcc_version(run_lambda),
|
||||
clang_version=get_clang_version(run_lambda),
|
||||
cmake_version=get_cmake_version(run_lambda),
|
||||
caching_allocator_config=get_cachingallocator_config(),
|
||||
is_xnnpack_available=is_xnnpack_available(),
|
||||
cpu_info=get_cpu_info(run_lambda),
|
||||
rocm_version=rocm_version,
|
||||
vllm_version=vllm_version,
|
||||
vllm_build_flags=vllm_build_flags,
|
||||
gpu_topo=gpu_topo,
|
||||
env_vars=get_env_vars(),
|
||||
)
|
||||
|
||||
|
||||
env_info_fmt = """
|
||||
==============================
|
||||
System Info
|
||||
==============================
|
||||
OS : {os}
|
||||
GCC version : {gcc_version}
|
||||
Clang version : {clang_version}
|
||||
CMake version : {cmake_version}
|
||||
Libc version : {libc_version}
|
||||
|
||||
==============================
|
||||
PyTorch Info
|
||||
==============================
|
||||
PyTorch version : {torch_version}
|
||||
Is debug build : {is_debug_build}
|
||||
CUDA used to build PyTorch : {cuda_compiled_version}
|
||||
ROCM used to build PyTorch : {hip_compiled_version}
|
||||
|
||||
==============================
|
||||
Python Environment
|
||||
==============================
|
||||
Python version : {python_version}
|
||||
Python platform : {python_platform}
|
||||
|
||||
==============================
|
||||
CUDA / GPU Info
|
||||
==============================
|
||||
Is CUDA available : {is_cuda_available}
|
||||
CUDA runtime version : {cuda_runtime_version}
|
||||
CUDA_MODULE_LOADING set to : {cuda_module_loading}
|
||||
GPU models and configuration : {nvidia_gpu_models}
|
||||
Nvidia driver version : {nvidia_driver_version}
|
||||
cuDNN version : {cudnn_version}
|
||||
HIP runtime version : {hip_runtime_version}
|
||||
MIOpen runtime version : {miopen_runtime_version}
|
||||
Is XNNPACK available : {is_xnnpack_available}
|
||||
|
||||
==============================
|
||||
CPU Info
|
||||
==============================
|
||||
{cpu_info}
|
||||
|
||||
==============================
|
||||
Versions of relevant libraries
|
||||
==============================
|
||||
{pip_packages}
|
||||
{conda_packages}
|
||||
""".strip()
|
||||
|
||||
# both the above code and the following code use `strip()` to
|
||||
# remove leading/trailing whitespaces, so we need to add a newline
|
||||
# in between to separate the two sections
|
||||
env_info_fmt += "\n\n"
|
||||
|
||||
env_info_fmt += """
|
||||
==============================
|
||||
vLLM Info
|
||||
==============================
|
||||
ROCM Version : {rocm_version}
|
||||
vLLM Version : {vllm_version}
|
||||
vLLM Build Flags:
|
||||
{vllm_build_flags}
|
||||
GPU Topology:
|
||||
{gpu_topo}
|
||||
|
||||
==============================
|
||||
Environment Variables
|
||||
==============================
|
||||
{env_vars}
|
||||
""".strip()
|
||||
|
||||
|
||||
def pretty_str(envinfo):
|
||||
def replace_nones(dct, replacement="Could not collect"):
|
||||
for key in dct.keys():
|
||||
if dct[key] is not None:
|
||||
continue
|
||||
dct[key] = replacement
|
||||
return dct
|
||||
|
||||
def replace_bools(dct, true="Yes", false="No"):
|
||||
for key in dct.keys():
|
||||
if dct[key] is True:
|
||||
dct[key] = true
|
||||
elif dct[key] is False:
|
||||
dct[key] = false
|
||||
return dct
|
||||
|
||||
def prepend(text, tag="[prepend]"):
|
||||
lines = text.split("\n")
|
||||
updated_lines = [tag + line for line in lines]
|
||||
return "\n".join(updated_lines)
|
||||
|
||||
def replace_if_empty(text, replacement="No relevant packages"):
|
||||
if text is not None and len(text) == 0:
|
||||
return replacement
|
||||
return text
|
||||
|
||||
def maybe_start_on_next_line(string):
|
||||
# If `string` is multiline, prepend a \n to it.
|
||||
if string is not None and len(string.split("\n")) > 1:
|
||||
return "\n{}\n".format(string)
|
||||
return string
|
||||
|
||||
mutable_dict = envinfo._asdict()
|
||||
|
||||
# If nvidia_gpu_models is multiline, start on the next line
|
||||
mutable_dict["nvidia_gpu_models"] = maybe_start_on_next_line(
|
||||
envinfo.nvidia_gpu_models
|
||||
)
|
||||
|
||||
# If the machine doesn't have CUDA, report some fields as 'No CUDA'
|
||||
dynamic_cuda_fields = [
|
||||
"cuda_runtime_version",
|
||||
"nvidia_gpu_models",
|
||||
"nvidia_driver_version",
|
||||
]
|
||||
all_cuda_fields = dynamic_cuda_fields + ["cudnn_version"]
|
||||
all_dynamic_cuda_fields_missing = all(
|
||||
mutable_dict[field] is None for field in dynamic_cuda_fields
|
||||
)
|
||||
if (
|
||||
TORCH_AVAILABLE
|
||||
and not torch.cuda.is_available()
|
||||
and all_dynamic_cuda_fields_missing
|
||||
):
|
||||
for field in all_cuda_fields:
|
||||
mutable_dict[field] = "No CUDA"
|
||||
if envinfo.cuda_compiled_version is None:
|
||||
mutable_dict["cuda_compiled_version"] = "None"
|
||||
|
||||
# Replace True with Yes, False with No
|
||||
mutable_dict = replace_bools(mutable_dict)
|
||||
|
||||
# Replace all None objects with 'Could not collect'
|
||||
mutable_dict = replace_nones(mutable_dict)
|
||||
|
||||
# If either of these are '', replace with 'No relevant packages'
|
||||
mutable_dict["pip_packages"] = replace_if_empty(mutable_dict["pip_packages"])
|
||||
mutable_dict["conda_packages"] = replace_if_empty(mutable_dict["conda_packages"])
|
||||
|
||||
# Tag conda and pip packages with a prefix
|
||||
# If they were previously None, they'll show up as e.g. '[conda] Could not collect'
|
||||
if mutable_dict["pip_packages"]:
|
||||
mutable_dict["pip_packages"] = prepend(
|
||||
mutable_dict["pip_packages"], "[{}] ".format(envinfo.pip_version)
|
||||
)
|
||||
if mutable_dict["conda_packages"]:
|
||||
mutable_dict["conda_packages"] = prepend(
|
||||
mutable_dict["conda_packages"], "[conda] "
|
||||
)
|
||||
mutable_dict["cpu_info"] = envinfo.cpu_info
|
||||
return env_info_fmt.format(**mutable_dict)
|
||||
|
||||
|
||||
def get_pretty_env_info():
|
||||
return pretty_str(get_env_info())
|
||||
|
||||
|
||||
def main():
|
||||
print("Collecting environment information...")
|
||||
output = get_pretty_env_info()
|
||||
print(output)
|
||||
|
||||
if (
|
||||
TORCH_AVAILABLE
|
||||
and hasattr(torch, "utils")
|
||||
and hasattr(torch.utils, "_crash_handler")
|
||||
):
|
||||
minidump_dir = torch.utils._crash_handler.DEFAULT_MINIDUMP_DIR
|
||||
if sys.platform == "linux" and os.path.exists(minidump_dir):
|
||||
dumps = [
|
||||
os.path.join(minidump_dir, dump) for dump in os.listdir(minidump_dir)
|
||||
]
|
||||
latest = max(dumps, key=os.path.getctime)
|
||||
ctime = os.path.getctime(latest)
|
||||
creation_time = datetime.datetime.fromtimestamp(ctime).strftime(
|
||||
"%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
msg = (
|
||||
"\n*** Detected a minidump at {} created on {}, ".format(
|
||||
latest, creation_time
|
||||
)
|
||||
+ "if this is related to your bug please include it when you file a report ***"
|
||||
)
|
||||
print(msg, file=sys.stderr)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
0
vllm/compilation/__init__.py
Normal file
1131
vllm/compilation/backends.py
Normal file
File diff suppressed because it is too large
57
vllm/compilation/base_static_graph.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Protocol
|
||||
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
|
||||
|
||||
class AbstractStaticGraphWrapper(Protocol):
|
||||
"""
|
||||
StaticGraphWrapper interface that allows platforms to wrap a callable
|
||||
to be captured as a static graph.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
runnable: Callable[..., Any],
|
||||
vllm_config: VllmConfig,
|
||||
runtime_mode: CUDAGraphMode,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Initializes the StaticGraphWrapper class with graph capturing and
|
||||
execution-related configurations.
|
||||
|
||||
Args:
|
||||
runnable (Callable): The callable to be wrapped and captured.
|
||||
vllm_config (VllmConfig): Global configuration for vLLM.
|
||||
runtime_mode (CUDAGraphMode): The style of the static
|
||||
graph runtime. See CUDAGraphMode in vllm/config.py.
|
||||
Note that only the subset enum `NONE`, `PIECEWISE` and `FULL`
|
||||
are used as concrete runtime mode for cudagraph dispatching.
|
||||
Keyword Args:
|
||||
kwargs: Additional keyword arguments for platform-specific
|
||||
configurations.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
"""
|
||||
Executes the wrapped callable.
|
||||
|
||||
If the current runtime mode in the ForwardContext matches the runtime
|
||||
mode of this instance, it replays the CUDAGraph or captures it using
|
||||
the callable if it hasn't been captured yet. Otherwise, it calls the
|
||||
original callable directly.
|
||||
|
||||
Args:
|
||||
*args: Variable length input arguments to be passed into the
|
||||
callable.
|
||||
**kwargs: Keyword arguments to be passed into the callable.
|
||||
|
||||
Returns:
|
||||
Any: Output of the executed callable.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
516
vllm/compilation/caching.py
Normal file
@@ -0,0 +1,516 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import hashlib
|
||||
import inspect
|
||||
import os
|
||||
import pickle
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any, Literal
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
from torch.utils import _pytree as pytree
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.compiler_interface import get_inductor_factors
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.config.utils import hash_factors
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
try:
|
||||
from torch._dynamo.aot_compile import SerializableCallable
|
||||
except ImportError:
|
||||
SerializableCallable = object
|
||||
|
||||
assert isinstance(SerializableCallable, type)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class StandaloneCompiledArtifacts:
|
||||
"""Storage for standalone compiled artifacts with content-based deduplication.
|
||||
|
||||
Deduplication works via a two-level indirection:
|
||||
1. `submodule_bytes` maps "{submod_name}_{shape}" -> SHA256 hash
|
||||
2. `submodule_bytes_store` maps SHA256 hash -> actual bytes
|
||||
|
||||
When inserting, we compute the SHA256 hash of the bytes. If the hash
|
||||
already exists in `submodule_bytes_store`, we reuse the existing entry
|
||||
rather than storing duplicate bytes. This is common because submodules
|
||||
often compile to identical artifacts (e.g., identical transformer layers
|
||||
split on attn)
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# dict from submodule name to byte hash
|
||||
self.submodule_bytes: dict[str, str] = {}
|
||||
# dict from byte hash to bytes
|
||||
self.submodule_bytes_store: dict[str, bytes] = {}
|
||||
# dict from byte hash to loaded module
|
||||
self.loaded_submodule_store: dict[str, Any] = {}
|
||||
|
||||
def insert(self, submod_name: str, shape: str, entry: bytes) -> None:
|
||||
hasher = hashlib.sha256()
|
||||
hasher.update(entry)
|
||||
hex_digest = hasher.hexdigest()
|
||||
self.submodule_bytes[f"{submod_name}_{shape}"] = hex_digest
|
||||
if hex_digest not in self.submodule_bytes_store:
|
||||
self.submodule_bytes_store[hex_digest] = entry
|
||||
logger.debug(
|
||||
"inserting new artifact for submod %s with shape %s "
|
||||
"(%s bytes) at hash %s",
|
||||
submod_name,
|
||||
shape,
|
||||
len(entry),
|
||||
hex_digest,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"reusing existing cache artifact for submod %s "
|
||||
"with shape %s (%s bytes) at hash %s",
|
||||
submod_name,
|
||||
shape,
|
||||
len(entry),
|
||||
hex_digest,
|
||||
)
|
||||
|
||||
def get(self, submod_name: str, shape: str) -> bytes:
|
||||
logger.debug(
|
||||
"getting artifact for submod %s with shape %s",
|
||||
submod_name,
|
||||
shape,
|
||||
)
|
||||
return self.submodule_bytes_store[
|
||||
self.submodule_bytes[f"{submod_name}_{shape}"]
|
||||
]
|
||||
|
||||
def get_loaded(self, submod_name: str, shape: str) -> Any:
|
||||
logger.debug(
|
||||
"getting artifact for submod %s with shape %s",
|
||||
submod_name,
|
||||
shape,
|
||||
)
|
||||
return self.loaded_submodule_store[
|
||||
self.submodule_bytes[f"{submod_name}_{shape}"]
|
||||
]
|
||||
|
||||
def size_bytes(self) -> int:
|
||||
return sum(len(entry) for entry in self.submodule_bytes_store.values())
|
||||
|
||||
def num_artifacts(self) -> int:
|
||||
return len(self.submodule_bytes_store)
|
||||
|
||||
def num_entries(self) -> int:
|
||||
return len(self.submodule_bytes)
|
||||
|
||||
def submodule_names(self) -> list[str]:
|
||||
# get unique "{submod_name}" from "{submod_name}_{shape}", preserving order
|
||||
names = [cache_key.rsplit("_", 1)[0] for cache_key in self.submodule_bytes]
|
||||
return list(dict.fromkeys(names))
|
||||
|
||||
def load_all(self) -> None:
|
||||
import concurrent.futures
|
||||
|
||||
# check already loaded
|
||||
if len(self.loaded_submodule_store) == len(self.submodule_bytes_store):
|
||||
return
|
||||
|
||||
from torch._inductor.standalone_compile import AOTCompiledArtifact
|
||||
|
||||
def _load_entry(entry_bytes: bytes) -> AOTCompiledArtifact:
|
||||
entry = pickle.loads(entry_bytes)
|
||||
return AOTCompiledArtifact.deserialize(entry)
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
entries = list(self.submodule_bytes_store.values())
|
||||
loaded_entries = list(executor.map(_load_entry, entries))
|
||||
|
||||
for i, k in enumerate(self.submodule_bytes_store.keys()):
|
||||
self.loaded_submodule_store[k] = loaded_entries[i]
|
||||
|
||||
logger.debug("loaded all %s submodules", self.num_artifacts())
|
||||
|
||||
def __getstate__(self) -> dict[str, dict[str, str] | dict[str, bytes]]:
|
||||
return {
|
||||
"submodule_bytes": self.submodule_bytes,
|
||||
"submodule_bytes_store": self.submodule_bytes_store,
|
||||
}
|
||||
|
||||
def __setstate__(self, state: dict[str, dict[str, Any]]) -> None:
|
||||
self.submodule_bytes = state["submodule_bytes"]
|
||||
self.submodule_bytes_store = state["submodule_bytes_store"]
|
||||
self.loaded_submodule_store = {}
|
||||
|
||||
|
||||
class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
|
||||
"""
|
||||
A wrapper around a compiled function by vllm. It will forward the tensor
|
||||
inputs to the compiled function and return the result.
|
||||
It also implements a serialization interface to support PyTorch's precompile
|
||||
with custom backend, so that we can save and load the compiled function on
|
||||
disk. There's no need to wrap around the compiled function if we don't want
|
||||
to serialize them in particular cases.
|
||||
Right now serialization for the custom backend is done via
|
||||
serializing the Dynamo fx graph plus example inputs.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph_module: torch.fx.GraphModule,
|
||||
example_inputs: Sequence[Any],
|
||||
prefix: str,
|
||||
optimized_call: Callable[..., Any],
|
||||
is_encoder: bool = False,
|
||||
vllm_backend: Any | None = None,
|
||||
sym_tensor_indices: list[int] | None = None,
|
||||
) -> None:
|
||||
assert isinstance(graph_module, torch.fx.GraphModule)
|
||||
self.graph_module = graph_module
|
||||
self.example_inputs = example_inputs
|
||||
self.prefix = prefix
|
||||
self.optimized_call = optimized_call
|
||||
self.is_encoder = is_encoder
|
||||
self.shape_env = None
|
||||
self.vllm_backend = vllm_backend
|
||||
self.sym_tensor_indices = sym_tensor_indices
|
||||
sym_input = next(
|
||||
(i for i in self.example_inputs if isinstance(i, torch.SymInt)), None
|
||||
)
|
||||
if sym_input is not None:
|
||||
self.shape_env = sym_input.node.shape_env
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return self.optimized_call(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def serialize_compile_artifacts(
|
||||
cls, compiled_fn: "VllmSerializableFunction"
|
||||
) -> bytes:
|
||||
import sympy
|
||||
from torch._subclasses import FakeTensorMode
|
||||
from torch.fx._graph_pickler import GraphPickler, Options
|
||||
|
||||
state = compiled_fn.__dict__.copy()
|
||||
state.pop("optimized_call")
|
||||
state.pop("shape_env")
|
||||
state.pop("vllm_backend", None)
|
||||
for node in state["graph_module"].graph.nodes:
|
||||
node.meta.pop("source_fn_stack", None)
|
||||
node.meta.pop("nn_module_stack", None)
|
||||
for name, submod in state["graph_module"].named_children():
|
||||
if hasattr(submod, "graph"):
|
||||
for node in submod.graph.nodes:
|
||||
node.meta.pop("source_fn_stack", None)
|
||||
node.meta.pop("nn_module_stack", None)
|
||||
|
||||
graph_reducer_override = GraphPickler.reducer_override
|
||||
|
||||
def _graph_reducer_override(
|
||||
self: GraphPickler, obj: Any
|
||||
) -> tuple[Callable[..., Any], tuple[Any, ...]] | Any:
|
||||
if (
|
||||
inspect.isclass(obj)
|
||||
and issubclass(obj, sympy.Function)
|
||||
and hasattr(obj, "_torch_unpickler")
|
||||
):
|
||||
return obj._torch_unpickler, (obj._torch_handler_name,)
|
||||
if isinstance(obj, FakeTensorMode):
|
||||
return type(None), ()
|
||||
return graph_reducer_override(self, obj)
|
||||
|
||||
if state.get("sym_tensor_indices"):
|
||||
# put tensor inputs on meta device since their data
|
||||
# isn't needed, yet we need the meta for make_copy_and_call
|
||||
state["example_inputs"] = pytree.tree_map_only(
|
||||
torch.Tensor,
|
||||
lambda inp: torch.empty_like(inp, device="meta"),
|
||||
state["example_inputs"],
|
||||
)
|
||||
else:
|
||||
# mask off all tensor inputs since they are large and not needed.
|
||||
state["example_inputs"] = pytree.tree_map_only(
|
||||
torch.Tensor,
|
||||
lambda inp: torch.empty_like(inp, device="meta"),
|
||||
state["example_inputs"],
|
||||
)
|
||||
with patch.object(GraphPickler, "reducer_override", _graph_reducer_override):
|
||||
state["graph_module"] = GraphPickler.dumps(
|
||||
state["graph_module"], Options(ops_filter=None)
|
||||
)
|
||||
state["example_inputs"] = GraphPickler.dumps(state["example_inputs"])
|
||||
|
||||
if compiled_fn.vllm_backend:
|
||||
(
|
||||
standalone_compile_artifacts,
|
||||
sym_shape_indices_map,
|
||||
returns_tuple_map,
|
||||
) = compiled_fn.vllm_backend.collect_standalone_compile_artifacts()
|
||||
state["standalone_compile_artifacts"] = standalone_compile_artifacts
|
||||
state["sym_shape_indices_map"] = sym_shape_indices_map
|
||||
state["returns_tuple_map"] = returns_tuple_map
|
||||
return pickle.dumps(state)
|
||||
|
||||
@classmethod
|
||||
def deserialize_compile_artifacts(cls, data: bytes) -> "VllmSerializableFunction":
|
||||
from torch._guards import TracingContext, tracing
|
||||
from torch._subclasses import FakeTensorMode
|
||||
from torch.fx._graph_pickler import GraphPickler
|
||||
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
||||
|
||||
state = pickle.loads(data)
|
||||
fake_mode = FakeTensorMode(shape_env=ShapeEnv())
|
||||
state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode)
|
||||
state["graph_module"].recompile()
|
||||
state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode)
|
||||
|
||||
standalone_compile_artifacts = state.pop("standalone_compile_artifacts", None)
|
||||
sym_shape_indices_map = state.pop("sym_shape_indices_map", {})
|
||||
returns_tuple_map = state.pop("returns_tuple_map", {})
|
||||
|
||||
if envs.VLLM_USE_MEGA_AOT_ARTIFACT:
|
||||
assert standalone_compile_artifacts is not None
|
||||
submod_names = standalone_compile_artifacts.submodule_names()
|
||||
num_submods = len(submod_names)
|
||||
num_artifacts = standalone_compile_artifacts.num_artifacts()
|
||||
|
||||
logger.info(
|
||||
"reconstructing serializable fn from standalone compile "
|
||||
"artifacts. num_artifacts=%d num_submods=%d",
|
||||
num_artifacts,
|
||||
num_submods,
|
||||
)
|
||||
|
||||
fn = reconstruct_serializable_fn_from_mega_artifact(
|
||||
state=state,
|
||||
standalone_compile_artifacts=standalone_compile_artifacts,
|
||||
vllm_config=get_current_vllm_config(),
|
||||
sym_shape_indices_map=sym_shape_indices_map,
|
||||
returns_tuple_map=returns_tuple_map,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"reconstructed serializable fn from standalone compile artifacts"
|
||||
)
|
||||
|
||||
return fn
|
||||
|
||||
# Fall back to standard VllmBackend
|
||||
from vllm.compilation.backends import VllmBackend
|
||||
|
||||
is_encoder = state.get("is_encoder", False)
|
||||
vllm_backend: VllmBackend = VllmBackend(
|
||||
get_current_vllm_config(), state["prefix"], is_encoder
|
||||
)
|
||||
|
||||
def optimized_call(*example_inputs: Any) -> Any:
|
||||
"""
|
||||
On the first run of the optimized call, we rerun the compiler
|
||||
backend which should result in a cache hit. After the backend
|
||||
call returns, we just do a one-time replacement of the optimized
|
||||
call with the compiled function, so that subsequent calls are on
|
||||
the AOT compiled path.
|
||||
"""
|
||||
compile_inputs = [
|
||||
inp if inp is not None else example_inputs[i]
|
||||
for i, inp in enumerate(fn.example_inputs)
|
||||
]
|
||||
with tracing(TracingContext(fake_mode)):
|
||||
fn.optimized_call = vllm_backend(
|
||||
state["graph_module"], compile_inputs
|
||||
).optimized_call
|
||||
return fn.optimized_call(*example_inputs)
|
||||
|
||||
fn = cls(**state, optimized_call=optimized_call)
|
||||
return fn
|
||||
|
||||
@property
|
||||
def co_name(self) -> Literal["VllmSerializableFunction"]:
|
||||
"""
|
||||
Used for depyf debugging.
|
||||
"""
|
||||
return "VllmSerializableFunction"
|
||||
|
||||
|
||||
def reconstruct_serializable_fn_from_mega_artifact(
|
||||
state: dict[str, Any],
|
||||
standalone_compile_artifacts: "StandaloneCompiledArtifacts",
|
||||
vllm_config: VllmConfig,
|
||||
sym_shape_indices_map: dict[str, list[int]],
|
||||
returns_tuple_map: dict[str, bool],
|
||||
) -> "VllmSerializableFunction":
|
||||
"""Construct a VllmSerializableFunction from cached inductor artifacts.
|
||||
|
||||
This function reconstructs a callable model from pre-compiled inductor
|
||||
artifacts without re-running the compilation. It:
|
||||
1. Loads all cached artifacts
|
||||
2. Builds compiled callables for each submodule/shape
|
||||
3. Creates PiecewiseBackend instances that dispatch to cached artifacts
|
||||
4. Wraps with cudagraph if needed
|
||||
5. Returns the final VllmSerializableFunction
|
||||
|
||||
Note: This function shares similar logic with PiecewiseCompileInterpreter
|
||||
in backends.py. Both create PiecewiseBackend instances and wrap them with
|
||||
cudagraph. The key difference is:
|
||||
- this function: PiecewiseBackend receives pre-compiled runnables
|
||||
(compiled_runnables is set, graph is None)
|
||||
- PiecewiseCompileInterpreter: PiecewiseBackend receives the FX graph
|
||||
to compile (graph is set, compiled_runnables is None)
|
||||
|
||||
If modifying the backend creation/wrapping logic, consider updating both.
|
||||
|
||||
Args:
|
||||
state: Deserialized state dict containing graph_module, example_inputs,
|
||||
prefix, sym_tensor_indices, is_encoder, etc.
|
||||
standalone_compile_artifacts: The StandaloneCompiledArtifacts containing
|
||||
pre-compiled artifacts for each submodule/shape combination.
|
||||
vllm_config: The vLLM configuration.
|
||||
sym_shape_indices_map: Mapping from submod_name to sym_shape_indices.
|
||||
returns_tuple_map: Mapping from submod_name to returns_tuple.
|
||||
|
||||
Returns:
|
||||
A VllmSerializableFunction that can be called directly.
|
||||
"""
|
||||
from vllm.compilation.backends import (
|
||||
VllmBackend,
|
||||
make_copy_and_call,
|
||||
wrap_with_cudagraph_if_needed,
|
||||
)
|
||||
from vllm.compilation.piecewise_backend import PiecewiseBackend
|
||||
|
||||
prefix = state["prefix"]
|
||||
is_encoder = state.get("is_encoder", False)
|
||||
split_gm = state["graph_module"]
|
||||
compilation_config = vllm_config.compilation_config
|
||||
|
||||
standalone_compile_artifacts.load_all()
|
||||
|
||||
submod_names = standalone_compile_artifacts.submodule_names()
|
||||
compiled_callables: dict[str, dict[str, Callable[..., Any]]] = {}
|
||||
|
||||
for cache_key in standalone_compile_artifacts.submodule_bytes:
|
||||
submod_name, shape_str = cache_key.rsplit("_", 1)
|
||||
compiled_callables.setdefault(submod_name, {})[shape_str] = (
|
||||
standalone_compile_artifacts.get_loaded(submod_name, shape_str)
|
||||
)
|
||||
|
||||
vllm_backend = VllmBackend(vllm_config, prefix, is_encoder)
|
||||
dummy_cache_dir = os.path.join(envs.VLLM_CACHE_ROOT, "dummy_cache")
|
||||
os.makedirs(dummy_cache_dir, exist_ok=True)
|
||||
vllm_backend.compiler_manager.initialize_cache(
|
||||
cache_dir=dummy_cache_dir,
|
||||
disable_cache=True,
|
||||
prefix=prefix,
|
||||
)
|
||||
|
||||
# spot check that cached submodules exist in the graph structure
|
||||
graph_children = {name for name, _ in split_gm.named_children()}
|
||||
missing = set(submod_names) - graph_children
|
||||
assert not missing, (
|
||||
f"artifacts reference submodules not in graph: {missing}. "
|
||||
f"graph has: {sorted(graph_children)}"
|
||||
)
|
||||
|
||||
for i, submod_name in enumerate(submod_names):
|
||||
assert submod_name in sym_shape_indices_map and submod_name in returns_tuple_map
|
||||
|
||||
sym_shape_indices = sym_shape_indices_map[submod_name]
|
||||
returns_tuple = returns_tuple_map[submod_name]
|
||||
runnables = compiled_callables[submod_name]
|
||||
|
||||
piecewise_backend = PiecewiseBackend(
|
||||
graph=None, # not needed for cached artifacts
|
||||
vllm_config=vllm_config,
|
||||
piecewise_compile_index=i,
|
||||
total_piecewise_compiles=len(submod_names),
|
||||
sym_shape_indices=sym_shape_indices,
|
||||
vllm_backend=vllm_backend,
|
||||
returns_tuple=returns_tuple,
|
||||
compiled_runnables=runnables,
|
||||
)
|
||||
|
||||
is_first = i == 0
|
||||
is_last = i == len(submod_names) - 1
|
||||
wrapped_backend = wrap_with_cudagraph_if_needed(
|
||||
piecewise_backend,
|
||||
vllm_config,
|
||||
compilation_config,
|
||||
is_first,
|
||||
is_last,
|
||||
)
|
||||
|
||||
split_gm.__dict__[submod_name] = wrapped_backend
|
||||
logger.debug(
|
||||
"Replaced submodule %s with piecewise backend from cache",
|
||||
submod_name,
|
||||
)
|
||||
|
||||
if compilation_config.cudagraph_copy_inputs:
|
||||
sym_tensor_indices = state["sym_tensor_indices"]
|
||||
input_buffers = [
|
||||
torch.empty_like(
|
||||
state["example_inputs"][idx], device=vllm_config.device_config.device
|
||||
)
|
||||
for idx in sym_tensor_indices
|
||||
]
|
||||
optimized_call = make_copy_and_call(sym_tensor_indices, input_buffers, split_gm)
|
||||
else:
|
||||
optimized_call = split_gm
|
||||
|
||||
fn = VllmSerializableFunction(
|
||||
**state,
|
||||
optimized_call=optimized_call,
|
||||
vllm_backend=None,
|
||||
)
|
||||
return fn
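# Note: the cache keys consumed above are assumed to have the layout
# "<submod_name>_<shape>", inferred from the rsplit("_", 1) call, e.g.
#
#     submod_name, shape_str = "submod_0_8".rsplit("_", 1)  # "submod_0", "8"
#
# where the shape suffix is either a concrete size or a marker for the
# general (dynamic) compilation entry.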
|
||||
|
||||
|
||||
def aot_compile_hash_factors(vllm_config: VllmConfig) -> list[str]:
|
||||
factors = []
|
||||
# 0. factors come from the env, for example, The values of
|
||||
# VLLM_PP_LAYER_PARTITION will affect the computation graph.
|
||||
env_hash = hash_factors(envs.compile_factors())
|
||||
factors.append(env_hash)
|
||||
|
||||
# 1. factors come from the vllm_config (it mainly summarizes how the
|
||||
# model is created)
|
||||
config_hash = vllm_config.compute_hash()
|
||||
factors.append(config_hash)
|
||||
|
||||
# 2. inductor factors if applicable
|
||||
if envs.VLLM_USE_MEGA_AOT_ARTIFACT:
|
||||
factors.extend(get_inductor_factors())
|
||||
|
||||
return factors
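# Illustrative (a sketch, not how vLLM combines them): the factors above are
# order-sensitive strings, so a single cache key can be derived by hashing
# their concatenation.
#
#     import hashlib
#     factors = aot_compile_hash_factors(vllm_config)  # assumes a VllmConfig in scope
#     cache_key = hashlib.sha256("|".join(map(str, factors)).encode()).hexdigest()[:16]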
|
||||
|
||||
|
||||
def _compute_code_hash_with_content(file_contents: dict[str, str]) -> str:
|
||||
    items = sorted(file_contents.items(), key=lambda x: x[0])
|
||||
hash_content = []
|
||||
for filepath, content in items:
|
||||
hash_content.append(filepath)
|
||||
if filepath == "<string>":
|
||||
# This means the function was dynamically generated, with
|
||||
# e.g. exec(). We can't actually check these.
|
||||
continue
|
||||
hash_content.append(content)
|
||||
result: str = safe_hash(
|
||||
"\n".join(hash_content).encode(), usedforsecurity=False
|
||||
).hexdigest()
|
||||
return result
|
||||
|
||||
|
||||
def _compute_code_hash(files: set[str]) -> str:
|
||||
logger.debug(
|
||||
"Traced files (to be considered for compilation cache):\n%s", "\n".join(files)
|
||||
)
|
||||
file_contents = {}
|
||||
for filepath in files:
|
||||
# Skip files that don't exist (e.g., <string>, <frozen modules>, etc.)
|
||||
if not os.path.isfile(filepath):
|
||||
file_contents[filepath] = ""
|
||||
else:
|
||||
with open(filepath) as f:
|
||||
file_contents[filepath] = f.read()
|
||||
return _compute_code_hash_with_content(file_contents)
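# Example (illustrative paths, made up): the hash covers file paths plus file
# contents, so editing any traced model file invalidates the compilation
# cache, while dynamically generated code ("<string>") contributes only its
# name.
#
#     _compute_code_hash_with_content({
#         "/models/my_model.py": "def forward(...): ...",
#         "<string>": "",  # content intentionally ignored
#     })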
|
||||
660
vllm/compilation/compiler_interface.py
Normal file
@@ -0,0 +1,660 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
import copy
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from contextlib import ExitStack
|
||||
from typing import Any, Literal
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch._inductor.compile_fx
|
||||
import torch.fx as fx
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.utils import Range
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.hashing import safe_hash
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class CompilerInterface:
|
||||
"""
|
||||
The interface for a compiler that can be used by vLLM.
|
||||
"""
|
||||
|
||||
# The name of the compiler, e.g. inductor.
|
||||
# This is a class-level attribute.
|
||||
name: str
|
||||
|
||||
def initialize_cache(
|
||||
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
|
||||
) -> None:
|
||||
"""
|
||||
when the vLLM process uses `cache_dir` as the cache directory,
|
||||
the compiler should initialize itself with the cache directory,
|
||||
e.g. by re-directing its own cache directory to a sub-directory.
|
||||
|
||||
prefix can be used in combination with cache_dir to figure out the base
|
||||
        cache directory, e.g. when there are multiple parts of the model being compiled,
|
||||
but we want to share the same cache directory for all of them.
|
||||
|
||||
e.g.
|
||||
cache_dir = "/path/to/dir/backbone", prefix = "backbone"
|
||||
cache_dir = "/path/to/dir/eagle_head", prefix = "eagle_head"
|
||||
"""
|
||||
pass
|
||||
|
||||
def compute_hash(self, vllm_config: VllmConfig) -> str:
|
||||
"""
|
||||
Gather all the relevant information from the vLLM config,
|
||||
to compute a hash so that we can cache the compiled model.
|
||||
|
||||
See [`VllmConfig.compute_hash`][vllm.config.VllmConfig.compute_hash]
|
||||
to check what information
|
||||
is already considered by default. This function should only
|
||||
consider the information that is specific to the compiler.
|
||||
"""
|
||||
return ""
|
||||
|
||||
def compile(
|
||||
self,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
compiler_config: dict[str, Any],
|
||||
compile_range: Range,
|
||||
key: str | None = None,
|
||||
) -> tuple[Callable[..., Any] | None, Any | None]:
|
||||
"""
|
||||
Compile the graph with the given example inputs and compiler config,
|
||||
with a range. The `compile_range` specifies the range of the inputs,
|
||||
it could be concrete size (if compile_sizes is provided), e.g. [4, 4]
|
||||
or a range [5, 8].
|
||||
Right now we only support one variable in ranges for all inputs,
|
||||
which is the batchsize (number of tokens) during inference.
|
||||
|
||||
Dynamo will make sure `graph(*example_inputs)` is valid.
|
||||
|
||||
The function should return a compiled callable function, as well as
|
||||
a handle that can be used to directly load the compiled function.
|
||||
|
||||
The handle should be a plain Python object, preferably a string or a
|
||||
file path for readability.
|
||||
|
||||
If the compiler doesn't support caching, it should return None for the
|
||||
handle. If the compiler fails to compile the graph, it should return
|
||||
None for the compiled function as well.
|
||||
|
||||
        `key` is required for InductorStandaloneAdaptor; it specifies where to
|
||||
save the compiled artifact. The compiled artifact gets saved to
|
||||
`cache_dir/key`.
|
||||
"""
|
||||
return None, None
|
||||
|
||||
def load(
|
||||
self,
|
||||
handle: Any,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
compile_range: Range,
|
||||
) -> Callable[..., Any]:
|
||||
"""
|
||||
Load the compiled function from the handle.
|
||||
Raises an error if the handle is invalid.
|
||||
|
||||
The handle is the second return value of the `compile` function.
|
||||
"""
|
||||
raise NotImplementedError("caching is not supported")
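# Minimal sketch of the contract above (hypothetical compiler, not used by
# vLLM): a backend that performs no compilation returns the graph itself and
# a None handle, signalling that it does not support caching.
#
#     class NoOpCompiler(CompilerInterface):
#         name = "noop"
#
#         def compile(self, graph, example_inputs, compiler_config,
#                     compile_range, key=None):
#             return graph, None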
|
||||
|
||||
|
||||
class AlwaysHitShapeEnv:
|
||||
"""
|
||||
Why do we need this class:
|
||||
|
||||
For normal `torch.compile` usage, every compilation will have
|
||||
one Dynamo bytecode compilation and one Inductor compilation.
|
||||
The Inductor compilation happens under the context of the
|
||||
Dynamo bytecode compilation, and that context is used to
|
||||
determine the dynamic shape information, etc.
|
||||
|
||||
For our use case, we only run Dynamo bytecode compilation once,
|
||||
and run Inductor compilation multiple times with different shapes
|
||||
plus a general shape. The compilation for specific shapes happens
|
||||
outside of the context of the Dynamo bytecode compilation. At that
|
||||
time, we don't have shape environment to provide to Inductor, and
|
||||
it will fail the Inductor code cache lookup.
|
||||
|
||||
By providing a dummy shape environment that always hits, we can
|
||||
make the Inductor code cache lookup always hit, and we can
|
||||
compile the graph for different shapes as needed.
|
||||
|
||||
The following dummy methods are obtained by trial-and-error
|
||||
until it works.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.guards: list[Any] = []
|
||||
|
||||
def evaluate_guards_expression(self, *args: Any, **kwargs: Any) -> Literal[True]:
|
||||
return True
|
||||
|
||||
def get_pruned_guards(self, *args: Any, **kwargs: Any) -> list[Any]:
|
||||
return []
|
||||
|
||||
def produce_guards_expression(self, *args: Any, **kwargs: Any) -> Literal[""]:
|
||||
return ""
|
||||
|
||||
|
||||
def get_inductor_factors() -> list[Any]:
|
||||
factors: list[Any] = []
|
||||
# summarize system state
|
||||
from torch._inductor.codecache import CacheBase
|
||||
|
||||
system_factors = CacheBase.get_system()
|
||||
factors.append(system_factors)
|
||||
|
||||
# summarize pytorch state
|
||||
from torch._inductor.codecache import torch_key
|
||||
|
||||
torch_factors = torch_key()
|
||||
factors.append(torch_factors)
|
||||
return factors
|
||||
|
||||
|
||||
def is_compile_cache_enabled(
|
||||
vllm_additional_inductor_config: dict[str, Any],
|
||||
) -> bool:
|
||||
vllm_inductor_config_disable_cache = vllm_additional_inductor_config.get(
|
||||
"force_disable_caches", False
|
||||
)
|
||||
|
||||
# TODO(gmagogsfm): Replace torch._inductor.config.force_disable_caches
|
||||
# with torch.compiler.config.force_disable_caches when minimum PyTorch
|
||||
# version reaches 2.10
|
||||
return (
|
||||
not envs.VLLM_DISABLE_COMPILE_CACHE
|
||||
and not torch._inductor.config.force_disable_caches
|
||||
and not vllm_inductor_config_disable_cache
|
||||
)
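# Illustrative: any one of these three knobs disables vLLM's compile cache.
#
#     VLLM_DISABLE_COMPILE_CACHE=1 vllm serve ...                # environment toggle
#     torch._inductor.config.force_disable_caches = True         # inductor-wide switch
#     is_compile_cache_enabled({"force_disable_caches": True})   # -> False, per-config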
|
||||
|
||||
|
||||
class InductorStandaloneAdaptor(CompilerInterface):
|
||||
"""
|
||||
The adaptor for the Inductor compiler.
|
||||
Requires PyTorch 2.8+.
|
||||
This is not on by default yet, but we plan to turn it on by default for
|
||||
PyTorch 2.8.
|
||||
|
||||
Use VLLM_USE_STANDALONE_COMPILE to toggle this on or off.
|
||||
"""
|
||||
|
||||
name = "inductor_standalone"
|
||||
|
||||
def __init__(self, save_format: Literal["binary", "unpacked"]) -> None:
|
||||
self.save_format = save_format
|
||||
|
||||
def compute_hash(self, vllm_config: VllmConfig) -> str:
|
||||
factors = get_inductor_factors()
|
||||
hash_str: str = safe_hash(
|
||||
str(factors).encode(), usedforsecurity=False
|
||||
).hexdigest()[:10]
|
||||
return hash_str
|
||||
|
||||
def initialize_cache(
|
||||
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
|
||||
) -> None:
|
||||
self.cache_dir = cache_dir
|
||||
|
||||
def compile(
|
||||
self,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
compiler_config: dict[str, Any],
|
||||
compile_range: Range,
|
||||
key: str | None = None,
|
||||
) -> tuple[Callable[..., Any] | None, Any | None]:
|
||||
compilation_counter.num_inductor_compiles += 1
|
||||
current_config = {}
|
||||
if compiler_config is not None:
|
||||
current_config.update(compiler_config)
|
||||
set_inductor_config(current_config, compile_range)
|
||||
set_functorch_config()
|
||||
|
||||
if compile_range.is_single_size():
|
||||
dynamic_shapes = "from_example_inputs"
|
||||
else:
|
||||
dynamic_shapes = "from_graph"
|
||||
|
||||
from torch._inductor import standalone_compile
|
||||
|
||||
supports_aot = is_torch_equal_or_newer("2.10.0")
|
||||
|
||||
if not supports_aot and envs.VLLM_USE_MEGA_AOT_ARTIFACT:
|
||||
logger.error(
|
||||
"CRITICAL: VLLM_USE_MEGA_AOT_ARTIFACT "
|
||||
"is enabled but PyTorch version does not support 'aot' "
|
||||
"parameter in standalone_compile. This requires PyTorch "
|
||||
"2.10.0+. Falling back to non-AOT mode."
|
||||
)
|
||||
|
||||
compile_kwargs = {
|
||||
"dynamic_shapes": dynamic_shapes,
|
||||
"options": {
|
||||
"config_patches": current_config,
|
||||
},
|
||||
}
|
||||
|
||||
use_aot: bool = supports_aot and envs.VLLM_USE_MEGA_AOT_ARTIFACT
|
||||
# only add 'aot' parameter if both supported and enabled...
|
||||
# this will set bundled_autograd_cache
|
||||
# https://github.com/pytorch/pytorch/blob/9bbc5b2905c260adf41bc866a732f9c121a2828a/torch/_inductor/standalone_compile.py#L359 # noqa
|
||||
if use_aot:
|
||||
compile_kwargs["aot"] = True # type: ignore[assignment]
|
||||
|
||||
# Inductor's pre-grad passes don't do anything for vLLM.
|
||||
# The pre-grad passes get run even on cache-hit and negatively impact
|
||||
# vllm cold compile times by O(1s)
|
||||
# Can remove this after the following issue gets fixed
|
||||
# https://github.com/pytorch/pytorch/issues/174502
|
||||
if envs.VLLM_ENABLE_PREGRAD_PASSES:
|
||||
ctx: Any = contextlib.nullcontext()
|
||||
else:
|
||||
ctx = patch(
|
||||
"torch._inductor.compile_fx._recursive_pre_grad_passes",
|
||||
lambda gm, _: gm,
|
||||
)
|
||||
with ctx:
|
||||
compiled_graph = standalone_compile(graph, example_inputs, **compile_kwargs)
|
||||
|
||||
if use_aot:
|
||||
from torch._inductor.standalone_compile import AOTCompiledArtifact
|
||||
|
||||
assert isinstance(compiled_graph, AOTCompiledArtifact)
|
||||
assert hasattr(compiled_graph, "serialize")
|
||||
# just return the compiled graph and a key
|
||||
# since we can serialize the bytes using to_bytes
|
||||
# and reload it using the key when reading
|
||||
return compiled_graph, None
|
||||
|
||||
# Save the compiled artifact to disk in the specified path
|
||||
assert key is not None
|
||||
path = os.path.join(self.cache_dir, key)
|
||||
|
||||
        def is_saveable_2_10(compiled_artifact: Any) -> bool:
|
||||
# can just use compiled_artifact.is_saveable in 2.11
|
||||
if compiled_artifact._artifacts is None:
|
||||
return False
|
||||
_, cache_info = compiled_artifact._artifacts
|
||||
return len(cache_info.aot_autograd_artifacts) == 1
|
||||
|
||||
if is_compile_cache_enabled(compiler_config):
|
||||
if not is_saveable_2_10(compiled_graph):
|
||||
raise RuntimeError(
|
||||
"The compiled artifact is not serializable. This usually means "
|
||||
"that the model code has something that is not serializable "
|
||||
"by torch.compile in it. You can fix this by either "
|
||||
"figuring out what is not serializable and rewriting it, "
|
||||
"filing a bug report, "
|
||||
"or suppressing this error by "
|
||||
"disabling vLLM's compilation cache via "
|
||||
"VLLM_DISABLE_COMPILE_CACHE=1 "
|
||||
"(this will greatly increase vLLM server warm start times)."
|
||||
)
|
||||
compiled_graph.save(path=path, format=self.save_format)
|
||||
compilation_counter.num_compiled_artifacts_saved += 1
|
||||
return compiled_graph, (key, path)
|
||||
|
||||
def load(
|
||||
self,
|
||||
handle: Any,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
compile_range: Range,
|
||||
) -> Callable[..., Any]:
|
||||
assert isinstance(handle, tuple)
|
||||
assert isinstance(handle[0], str)
|
||||
assert isinstance(handle[1], str)
|
||||
path = handle[1]
|
||||
inductor_compiled_graph = torch._inductor.CompiledArtifact.load(
|
||||
path=path, format=self.save_format
|
||||
)
|
||||
from torch._inductor.compile_fx import graph_returns_tuple
|
||||
|
||||
returns_tuple = graph_returns_tuple(graph)
|
||||
|
||||
def compiled_graph_wrapper(*args: Any) -> tuple[Any, ...] | Any:
|
||||
graph_output = inductor_compiled_graph(*args)
|
||||
# unpack the tuple if needed
|
||||
# TODO(rzou): the implication is that we're not
|
||||
# reading the python bytecode correctly in vLLM?
|
||||
if returns_tuple:
|
||||
return graph_output
|
||||
else:
|
||||
return graph_output[0]
|
||||
|
||||
return compiled_graph_wrapper
|
||||
|
||||
|
||||
class InductorAdaptor(CompilerInterface):
|
||||
"""
|
||||
    The adaptor for the Inductor compiler on PyTorch 2.5, 2.6, and 2.7.
|
||||
"""
|
||||
|
||||
name = "inductor"
|
||||
|
||||
def compute_hash(self, vllm_config: VllmConfig) -> str:
|
||||
factors = get_inductor_factors()
|
||||
hash_str: str = safe_hash(
|
||||
str(factors).encode(), usedforsecurity=False
|
||||
).hexdigest()[:10]
|
||||
return hash_str
|
||||
|
||||
def initialize_cache(
|
||||
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
|
||||
) -> None:
|
||||
self.cache_dir = cache_dir
|
||||
self.prefix = prefix
|
||||
self.base_cache_dir = cache_dir[: -len(prefix)] if prefix else cache_dir
|
||||
if disable_cache:
|
||||
return
|
||||
# redirect the cache directory to a subdirectory
|
||||
# set flags so that Inductor and Triton store their cache
|
||||
# in the cache_dir, then users only need to copy the cache_dir
|
||||
# to another machine to reuse the cache.
|
||||
inductor_cache = os.path.join(self.base_cache_dir, "inductor_cache")
|
||||
os.makedirs(inductor_cache, exist_ok=True)
|
||||
os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
|
||||
triton_cache = os.path.join(self.base_cache_dir, "triton_cache")
|
||||
os.makedirs(triton_cache, exist_ok=True)
|
||||
os.environ["TRITON_CACHE_DIR"] = triton_cache
|
||||
|
||||
def compile(
|
||||
self,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
compiler_config: dict[str, Any],
|
||||
compile_range: Range,
|
||||
key: str | None = None,
|
||||
) -> tuple[Callable[..., Any] | None, Any | None]:
|
||||
compilation_counter.num_inductor_compiles += 1
|
||||
from torch._inductor.compile_fx import compile_fx
|
||||
|
||||
current_config = {}
|
||||
if compiler_config is not None:
|
||||
current_config.update(compiler_config)
|
||||
|
||||
# disable remote cache
|
||||
current_config["fx_graph_cache"] = True
|
||||
current_config["fx_graph_remote_cache"] = False
|
||||
|
||||
set_inductor_config(current_config, compile_range)
|
||||
set_functorch_config()
|
||||
|
||||
# inductor can inplace modify the graph, so we need to copy it
|
||||
# see https://github.com/pytorch/pytorch/issues/138980
|
||||
graph = copy.deepcopy(graph)
|
||||
|
||||
# it's the first time we compile this graph
|
||||
# the assumption is that we don't have nested Inductor compilation.
|
||||
# compiled_fx_graph_hash will only be called once, and we can hook
|
||||
# it to get the hash of the compiled graph directly.
|
||||
|
||||
hash_str, file_path = None, None
|
||||
from torch._inductor.codecache import compiled_fx_graph_hash
|
||||
|
||||
def hijacked_compile_fx_inner(*args: Any, **kwargs: Any) -> Any:
|
||||
output = torch._inductor.compile_fx.compile_fx_inner(*args, **kwargs)
|
||||
nonlocal hash_str
|
||||
inductor_compiled_graph = output
|
||||
if inductor_compiled_graph is not None:
|
||||
nonlocal file_path
|
||||
compiled_fn = inductor_compiled_graph.current_callable
|
||||
file_path = compiled_fn.__code__.co_filename # noqa
|
||||
if (
|
||||
not file_path.startswith(self.base_cache_dir)
|
||||
and compiled_fn.__closure__ is not None
|
||||
):
|
||||
# hooked in the align_inputs_from_check_idxs function
|
||||
# in torch/_inductor/utils.py
|
||||
for cell in compiled_fn.__closure__:
|
||||
if not callable(cell.cell_contents):
|
||||
continue
|
||||
code = cell.cell_contents.__code__
|
||||
if code.co_filename.startswith(self.base_cache_dir):
|
||||
# this is the real file path
|
||||
# compiled from Inductor
|
||||
file_path = code.co_filename
|
||||
break
|
||||
hash_str = inductor_compiled_graph._fx_graph_cache_key
|
||||
return output
|
||||
|
||||
def hijack_compiled_fx_graph_hash(*args: Any, **kwargs: Any) -> Any:
|
||||
out = compiled_fx_graph_hash(*args, **kwargs)
|
||||
nonlocal hash_str
|
||||
hash_str = out[0]
|
||||
return out
|
||||
|
||||
def _check_can_cache(*args: Any, **kwargs: Any) -> None:
|
||||
# no error means it can be cached.
|
||||
# Inductor refuses to cache the graph outside of Dynamo
|
||||
# tracing context, and also disables caching for graphs
|
||||
# with high-order ops.
|
||||
# For vLLM, in either case, we want to cache the graph.
|
||||
# see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa
|
||||
return
|
||||
|
||||
def _get_shape_env() -> AlwaysHitShapeEnv:
|
||||
return AlwaysHitShapeEnv()
|
||||
|
||||
with ExitStack() as stack:
|
||||
# for hijacking the hash of the compiled graph
|
||||
stack.enter_context(
|
||||
patch(
|
||||
"torch._inductor.codecache.compiled_fx_graph_hash",
|
||||
hijack_compiled_fx_graph_hash,
|
||||
)
|
||||
)
|
||||
|
||||
# for providing a dummy shape environment
|
||||
stack.enter_context(
|
||||
patch(
|
||||
"torch._inductor.codecache.FxGraphCache._get_shape_env",
|
||||
_get_shape_env,
|
||||
)
|
||||
)
|
||||
|
||||
from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
|
||||
|
||||
# torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
|
||||
if hasattr(AOTAutogradCache, "_get_shape_env"):
|
||||
stack.enter_context(
|
||||
patch(
|
||||
"torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
|
||||
_get_shape_env,
|
||||
)
|
||||
)
|
||||
|
||||
# for forcing the graph to be cached
|
||||
stack.enter_context(
|
||||
patch(
|
||||
"torch._inductor.codecache.FxGraphCache._check_can_cache",
|
||||
_check_can_cache,
|
||||
)
|
||||
)
|
||||
|
||||
# Dynamo metrics context, see method for more details.
|
||||
stack.enter_context(self.metrics_context())
|
||||
|
||||
# Disable remote caching. When these are on, on remote cache-hit,
|
||||
# the monkey-patched functions never actually get called.
|
||||
# vLLM today assumes and requires the monkey-patched functions to
|
||||
# get hit.
|
||||
# TODO(zou3519): we're going to replace this all with
|
||||
# standalone_compile sometime.
|
||||
stack.enter_context(
|
||||
torch._inductor.config.patch(fx_graph_remote_cache=False)
|
||||
)
|
||||
# InductorAdaptor (unfortunately) requires AOTAutogradCache
|
||||
# to be turned off to run. It will fail to acquire the hash_str
|
||||
# and error if not.
|
||||
            # InductorStandaloneAdaptor (PyTorch 2.8+) fixes this problem.
|
||||
stack.enter_context(
|
||||
torch._functorch.config.patch(enable_autograd_cache=False)
|
||||
)
|
||||
stack.enter_context(
|
||||
torch._functorch.config.patch(enable_remote_autograd_cache=False)
|
||||
)
|
||||
|
||||
compiled_graph = compile_fx(
|
||||
graph,
|
||||
example_inputs,
|
||||
inner_compile=hijacked_compile_fx_inner,
|
||||
config_patches=current_config,
|
||||
)
|
||||
|
||||
# Turn off the checks if we disable the compilation cache.
|
||||
if is_compile_cache_enabled(compiler_config):
|
||||
if hash_str is None:
|
||||
raise RuntimeError(
|
||||
"vLLM failed to compile the model. The most "
|
||||
"likely reason for this is that a previous compilation "
|
||||
"failed, leading to a corrupted compilation artifact. "
|
||||
"We recommend trying to "
|
||||
"remove ~/.cache/vllm/torch_compile_cache and try again "
|
||||
"to see the real issue. "
|
||||
)
|
||||
assert file_path is not None, (
|
||||
"failed to get the file path of the compiled graph"
|
||||
)
|
||||
return compiled_graph, (hash_str, file_path)
|
||||
|
||||
def load(
|
||||
self,
|
||||
handle: Any,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
compile_range: Range,
|
||||
) -> Callable[..., Any]:
|
||||
assert isinstance(handle, tuple)
|
||||
assert isinstance(handle[0], str)
|
||||
assert isinstance(handle[1], str)
|
||||
hash_str = handle[0]
|
||||
|
||||
from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
|
||||
from torch._inductor.codecache import FxGraphCache
|
||||
|
||||
with ExitStack() as exit_stack:
|
||||
exit_stack.enter_context(
|
||||
patch(
|
||||
"torch._inductor.codecache.FxGraphCache._get_shape_env",
|
||||
lambda *args, **kwargs: AlwaysHitShapeEnv(),
|
||||
)
|
||||
)
|
||||
# torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
|
||||
if hasattr(AOTAutogradCache, "_get_shape_env"):
|
||||
exit_stack.enter_context(
|
||||
patch(
|
||||
"torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
|
||||
lambda *args, **kwargs: AlwaysHitShapeEnv(),
|
||||
)
|
||||
)
|
||||
|
||||
# Dynamo metrics context, see method for more details.
|
||||
exit_stack.enter_context(self.metrics_context())
|
||||
|
||||
from torch._inductor.output_code import CompiledFxGraphConstantsWithGm
|
||||
|
||||
constants = CompiledFxGraphConstantsWithGm(graph)
|
||||
inductor_compiled_graph, _ = FxGraphCache._lookup_graph(
|
||||
hash_str, example_inputs, True, None, constants
|
||||
)
|
||||
assert inductor_compiled_graph is not None, (
|
||||
"Inductor cache lookup failed. Please remove "
|
||||
f"the cache directory and try again." # noqa
|
||||
)
|
||||
|
||||
# Inductor calling convention (function signature):
|
||||
# f(list) -> tuple
|
||||
# Dynamo calling convention (function signature):
|
||||
# f(*args) -> Any
|
||||
|
||||
# need to know if the graph returns a tuple
|
||||
from torch._inductor.compile_fx import graph_returns_tuple
|
||||
|
||||
returns_tuple = graph_returns_tuple(graph)
|
||||
|
||||
# this is the callable we return to Dynamo to run
|
||||
def compiled_graph(*args: Any) -> tuple[Any, ...] | Any:
|
||||
# convert args to list
|
||||
list_args = list(args)
|
||||
graph_output = inductor_compiled_graph(list_args)
|
||||
# unpack the tuple if needed
|
||||
if returns_tuple:
|
||||
return graph_output
|
||||
else:
|
||||
return graph_output[0]
|
||||
|
||||
return compiled_graph
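    # Illustrative (values made up): the wrapper above adapts calling
    # conventions. Inductor's artifact expects a single list argument and
    # returns a tuple, while Dynamo calls with positional args and may expect
    # a bare value.
    #
    #     inductor_compiled_graph([x, positions])   # -> (hidden_states,)
    #     compiled_graph(x, positions)              # -> hidden_states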
|
||||
|
||||
def metrics_context(self) -> contextlib.AbstractContextManager[Any]:
|
||||
"""
|
||||
This method returns the Dynamo metrics context (if it exists,
|
||||
otherwise a null context). It is used by various compile components.
|
||||
Present in torch>=2.6, it's used inside FxGraphCache in
|
||||
torch==2.6 (but not after). It might also be used in various other
|
||||
torch.compile internal functions.
|
||||
|
||||
Because it is re-entrant, we always set it (even if entering via Dynamo
|
||||
and the context was already entered). We might want to revisit if it
|
||||
should be set at a different mode of compilation.
|
||||
|
||||
This is likely a bug in PyTorch: public APIs should not rely on
|
||||
manually setting up internal contexts. But we also rely on non-public
|
||||
APIs which might not provide these guarantees.
|
||||
"""
|
||||
if is_torch_equal_or_newer("2.6"):
|
||||
import torch._dynamo.utils
|
||||
|
||||
return torch._dynamo.utils.get_metrics_context() # type: ignore[no-any-return]
|
||||
else:
|
||||
return contextlib.nullcontext()
|
||||
|
||||
|
||||
def set_inductor_config(config: dict[str, Any], compile_range: Range) -> None:
|
||||
if compile_range.is_single_size():
|
||||
# for a specific batch size, tuning triton kernel parameters
|
||||
# can be beneficial
|
||||
config["max_autotune"] = envs.VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE
|
||||
config["coordinate_descent_tuning"] = (
|
||||
envs.VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING
|
||||
)
|
||||
|
||||
|
||||
def set_functorch_config() -> None:
|
||||
if not envs.VLLM_USE_MEGA_AOT_ARTIFACT:
|
||||
torch._functorch.config.bundled_autograd_cache = False
|
||||
|
||||
|
||||
class EagerAdaptor(CompilerInterface):
|
||||
name = "eager"
|
||||
|
||||
def compile(
|
||||
self,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
compiler_config: dict[str, Any],
|
||||
compile_range: Range,
|
||||
key: str | None = None,
|
||||
) -> tuple[Callable[..., Any] | None, Any | None]:
|
||||
compilation_counter.num_eager_compiles += 1
|
||||
# we don't need to compile the graph, just return the graph itself.
|
||||
# It does not support caching, return None for the handle.
|
||||
return graph, None
|
||||
50
vllm/compilation/counter.py
Normal file
@@ -0,0 +1,50 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import copy
|
||||
import dataclasses
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CompilationCounter:
|
||||
num_models_seen: int = 0
|
||||
num_graphs_seen: int = 0
|
||||
# including the splitting ops
|
||||
num_piecewise_graphs_seen: int = 0
|
||||
# not including the splitting ops
|
||||
num_piecewise_capturable_graphs_seen: int = 0
|
||||
num_backend_compilations: int = 0
|
||||
# Number of gpu_model_runner attempts to trigger CUDAGraphs capture
|
||||
num_gpu_runner_capture_triggers: int = 0
|
||||
# Number of CUDAGraphs captured
|
||||
num_cudagraph_captured: int = 0
|
||||
    # InductorAdaptor.compile calls
|
||||
num_inductor_compiles: int = 0
|
||||
    # EagerAdaptor.compile calls
|
||||
num_eager_compiles: int = 0
|
||||
    # The number of times vLLM's compiler cache entry was updated
|
||||
num_cache_entries_updated: int = 0
|
||||
# The number of standalone_compile compiled artifacts saved
|
||||
num_compiled_artifacts_saved: int = 0
|
||||
# Number of times a model was loaded with CompilationMode.STOCK_TORCH_COMPILE
|
||||
stock_torch_compile_count: int = 0
|
||||
|
||||
def clone(self) -> "CompilationCounter":
|
||||
return copy.deepcopy(self)
|
||||
|
||||
@contextmanager
|
||||
def expect(self, **kwargs: Any) -> Generator[None, None, None]:
|
||||
old = self.clone()
|
||||
yield
|
||||
for k, v in kwargs.items():
|
||||
assert getattr(self, k) - getattr(old, k) == v, (
|
||||
f"{k} not as expected, before it is {getattr(old, k)}"
|
||||
f", after it is {getattr(self, k)}, "
|
||||
f"expected diff is {v}"
|
||||
)
|
||||
|
||||
|
||||
compilation_counter = CompilationCounter()
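# Illustrative usage (values made up): tests wrap a compilation in `expect`
# to assert exactly how much each counter advanced.
#
#     with compilation_counter.expect(num_graphs_seen=1, num_inductor_compiles=1):
#         ...  # trigger a single torch.compile of one graph here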
|
||||
332
vllm/compilation/cuda_graph.py
Normal file
@@ -0,0 +1,332 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import dataclasses
|
||||
from collections import Counter
|
||||
from collections.abc import Callable
|
||||
from contextlib import ExitStack
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.monitor import validate_cudagraph_capturing_enabled
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id
|
||||
from vllm.forward_context import BatchDescriptor, get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.offloader.base import get_offloader
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
from vllm.utils.torch_utils import current_stream, weak_ref_tensors

logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class CUDAGraphStat:
|
||||
num_unpadded_tokens: int
|
||||
num_padded_tokens: int
|
||||
num_paddings: int
|
||||
runtime_mode: str
|
||||
|
||||
|
||||
class CUDAGraphLogging:
|
||||
"""Aggregate and log cudagraph metrics"""
|
||||
|
||||
COLUMN_HEADERS = [
|
||||
"Unpadded Tokens",
|
||||
"Padded Tokens",
|
||||
"Num Paddings",
|
||||
"Runtime Mode",
|
||||
"Count",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self, cg_mode: CUDAGraphMode, cg_capture_sizes: list[int] | None
|
||||
) -> None:
|
||||
self.reset()
|
||||
self.cg_mode = str(cg_mode)
|
||||
self.cg_capture_sizes = str(cg_capture_sizes or [])
|
||||
|
||||
self.settings_header = (
|
||||
"**CUDAGraph Config Settings:**\n\n"
|
||||
f"- Mode: {self.cg_mode}\n"
|
||||
f"- Capture sizes: {self.cg_capture_sizes}\n\n"
|
||||
"**CUDAGraph Stats:**\n\n"
|
||||
)
|
||||
|
||||
def reset(self) -> None:
|
||||
self.stats: list[CUDAGraphStat] = []
|
||||
|
||||
def observe(self, cudagraph_stat: CUDAGraphStat) -> None:
|
||||
self.stats.append(cudagraph_stat)
|
||||
|
||||
def generate_metric_table(self) -> str:
|
||||
stats_counts = Counter(self.stats)
|
||||
|
||||
# Convert stats to rows of strings, in descending order of observed frequencies
|
||||
rows = []
|
||||
for stat, count in sorted(
|
||||
stats_counts.items(), key=lambda item: item[1], reverse=True
|
||||
):
|
||||
rows.append(
|
||||
[
|
||||
str(stat.num_unpadded_tokens),
|
||||
str(stat.num_padded_tokens),
|
||||
str(stat.num_paddings),
|
||||
stat.runtime_mode,
|
||||
str(count),
|
||||
]
|
||||
)
|
||||
|
||||
# Calculate column widths (max of header and data)
|
||||
col_widths = []
|
||||
for i, header_text in enumerate(self.COLUMN_HEADERS):
|
||||
max_width = len(header_text)
|
||||
for row in rows:
|
||||
max_width = max(max_width, len(row[i]))
|
||||
col_widths.append(max_width)
|
||||
|
||||
table_header_list = [
|
||||
h.ljust(w) for h, w in zip(self.COLUMN_HEADERS, col_widths)
|
||||
]
|
||||
table_header = "| " + " | ".join(table_header_list) + " |\n"
|
||||
|
||||
table_separator = "|" + "|".join("-" * (w + 2) for w in col_widths) + "|\n"
|
||||
|
||||
# Create data rows with proper alignment
|
||||
data_rows = []
|
||||
for row in rows:
|
||||
formatted_row = [
|
||||
str(val).ljust(width) for val, width in zip(row, col_widths)
|
||||
]
|
||||
data_rows.append("| " + " | ".join(formatted_row) + " |")
|
||||
|
||||
return (
|
||||
self.settings_header
|
||||
+ table_header
|
||||
+ table_separator
|
||||
+ "\n".join(data_rows)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
def log(self, log_fn: Callable[..., Any] = logger.info) -> None:
|
||||
if not self.stats:
|
||||
return
|
||||
log_fn(self.generate_metric_table())
|
||||
self.reset()
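# Illustrative usage (sizes are made up): observe one stat per forward pass
# and periodically emit the aggregated markdown table.
#
#     cg_log = CUDAGraphLogging(CUDAGraphMode.PIECEWISE, [1, 2, 4, 8])
#     cg_log.observe(CUDAGraphStat(3, 4, 1, "PIECEWISE"))
#     cg_log.observe(CUDAGraphStat(3, 4, 1, "PIECEWISE"))
#     cg_log.log()  # logs config settings plus rows sorted by observation count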
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CUDAGraphEntry:
|
||||
batch_descriptor: BatchDescriptor
|
||||
cudagraph: torch.cuda.CUDAGraph | None = None
|
||||
output: Any | None = None
|
||||
|
||||
# for cudagraph debugging, track the input addresses
|
||||
# during capture, and check if they are the same during replay
|
||||
input_addresses: list[int] | None = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CUDAGraphOptions:
|
||||
debug_log_enable: bool = True
|
||||
gc_disable: bool = False
|
||||
weak_ref_output: bool = True
|
||||
|
||||
|
||||
class CUDAGraphWrapper:
|
||||
"""Wraps a runnable to add CUDA graph capturing and replaying ability. And
|
||||
provide attribute access to the underlying `runnable` via `__getattr__`.
|
||||
|
||||
The workflow of this wrapper in the cudagraph dispatching is as follows:
|
||||
1. At initialization, a runtime mode is assigned to the wrapper (FULL or
|
||||
PIECEWISE).
|
||||
2. At runtime, the wrapper receives a runtime_mode and a
|
||||
       batch_descriptor (key) from the forward context and blindly trusts them
|
||||
for cudagraph dispatching.
|
||||
3. If runtime_mode is NONE or runtime_mode does not match the mode of the
|
||||
wrapper, just call the runnable directly.
|
||||
4. Otherwise, i.e., the runtime_mode matches the mode of the wrapper,
|
||||
the wrapper will perform cudagraph capture(if key does not exist, create
|
||||
a new entry and cache it) or replay (if key exists in the cache).
|
||||
|
||||
Note: CUDAGraphWrapper does not store persistent buffers or copy any
|
||||
    runtime inputs into those buffers for replay. We assume implementing them
|
||||
is done outside of the wrapper. That is because we do not make any
|
||||
assumption on the dynamic shape (batch size) of the runtime inputs, as a
|
||||
trade-off for staying orthogonal to compilation logic. Nevertheless,
|
||||
tracing and checking the input addresses to be consistent during replay is
|
||||
guaranteed when VLLM_LOGGING_LEVEL == "DEBUG".
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
runnable: Callable[..., Any],
|
||||
vllm_config: VllmConfig,
|
||||
runtime_mode: CUDAGraphMode,
|
||||
cudagraph_options: CUDAGraphOptions | None = None,
|
||||
) -> None:
|
||||
self.runnable = runnable
|
||||
self.vllm_config = vllm_config
|
||||
self.runtime_mode = runtime_mode
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
|
||||
self.first_run_finished = False
|
||||
self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
|
||||
|
||||
# assert runtime_mode is not NONE(no cudagraph), otherwise, we don't
|
||||
# need to initialize a CUDAGraphWrapper.
|
||||
assert self.runtime_mode != CUDAGraphMode.NONE
|
||||
# TODO: in the future, if we want to use multiple
|
||||
# streams, it might not be safe to share a global pool.
|
||||
# only investigate this when we use multiple streams
|
||||
self.graph_pool = current_platform.get_global_graph_pool()
|
||||
|
||||
if cudagraph_options is None:
|
||||
cudagraph_options = CUDAGraphOptions()
|
||||
self.cudagraph_options = cudagraph_options
|
||||
# the entries for different batch descriptors that we need to capture
|
||||
# cudagraphs for.
|
||||
self.concrete_cudagraph_entries: dict[BatchDescriptor, CUDAGraphEntry] = {}
|
||||
|
||||
def __getattr__(self, key: str) -> Any:
|
||||
# allow accessing the attributes of the runnable.
|
||||
if hasattr(self.runnable, key):
|
||||
return getattr(self.runnable, key)
|
||||
raise AttributeError(
|
||||
f"Attribute {key} not exists in the runnable of "
|
||||
f"cudagraph wrapper: {self.runnable}"
|
||||
)
|
||||
|
||||
def unwrap(self) -> Callable[..., Any]:
|
||||
# in case we need to access the original runnable.
|
||||
return self.runnable
|
||||
|
||||
    def weak_ref_tensors_with_intermediate(self, output: Any) -> Any:
        # Like weak_ref_tensors, but also weak-references every tensor held by
        # an IntermediateTensors (e.g. pipeline-parallel intermediate outputs).
        if isinstance(output, IntermediateTensors):
            return IntermediateTensors(
                tensors={
                    key: weak_ref_tensors(value)
                    for key, value in output.tensors.items()
                }
            )
        return weak_ref_tensors(output)
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any | None:
|
||||
forward_context = get_forward_context()
|
||||
batch_descriptor = forward_context.batch_descriptor
|
||||
cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode
|
||||
|
||||
if (
|
||||
cudagraph_runtime_mode == CUDAGraphMode.NONE
|
||||
or cudagraph_runtime_mode != self.runtime_mode
|
||||
):
|
||||
# CUDAGraphMode.NONE could mean the profile run, a warmup run, or
|
||||
# running without cudagraphs.
|
||||
            # We do not trigger capture/replay if the runtime mode does not
            # match. This enables properly dispatching to the correct
|
||||
# CUDAGraphWrapper when nesting multiple instances with different
|
||||
# runtime modes.
|
||||
return self.runnable(*args, **kwargs)
|
||||
|
||||
assert batch_descriptor is not None
|
||||
if batch_descriptor not in self.concrete_cudagraph_entries:
|
||||
# create a new entry for this batch descriptor
|
||||
self.concrete_cudagraph_entries[batch_descriptor] = CUDAGraphEntry(
|
||||
batch_descriptor=batch_descriptor
|
||||
)
|
||||
|
||||
entry = self.concrete_cudagraph_entries[batch_descriptor]
|
||||
|
||||
if entry.cudagraph is None:
|
||||
if self.cudagraph_options.debug_log_enable:
|
||||
# Since we capture cudagraph for many different shapes and
|
||||
# capturing is fast, we don't need to log it for every
|
||||
# shape. E.g. we only log it for the first subgraph in
|
||||
# piecewise mode.
|
||||
logger.debug(
|
||||
"Capturing a cudagraph on (%s,%s)",
|
||||
self.runtime_mode.name,
|
||||
entry.batch_descriptor,
|
||||
)
|
||||
# validate that cudagraph capturing is legal at this point.
|
||||
validate_cudagraph_capturing_enabled()
|
||||
|
||||
input_addresses = [
|
||||
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
|
||||
]
|
||||
entry.input_addresses = input_addresses
|
||||
cudagraph = torch.cuda.CUDAGraph()
|
||||
|
||||
with ExitStack() as stack:
|
||||
if self.cudagraph_options.gc_disable:
|
||||
# during every model forward for piecewise cudagraph
|
||||
# mode, we will capture many pieces of cudagraphs
|
||||
# (roughly one per layer). running gc again and again
|
||||
# across layers will make the cudagraph capture very slow.
|
||||
# therefore, we only run gc for the first graph,
|
||||
# and disable gc for the rest of the graphs.
|
||||
stack.enter_context(patch("gc.collect", lambda: None))
|
||||
stack.enter_context(patch("torch.cuda.empty_cache", lambda: None))
|
||||
|
||||
if self.graph_pool is not None:
|
||||
set_graph_pool_id(self.graph_pool)
|
||||
else:
|
||||
set_graph_pool_id(current_platform.graph_pool_handle())
|
||||
|
||||
# Sync offloader's copy stream before capture.
|
||||
# Ensure any pre-capture prefetches from offloader are complete.
|
||||
get_offloader().sync_prev_onload()
|
||||
|
||||
# mind-exploding: carefully manage the reference and memory.
|
||||
with torch.cuda.graph(
|
||||
cudagraph,
|
||||
pool=self.graph_pool,
|
||||
stream=current_stream(),
|
||||
):
|
||||
# `output` is managed by pytorch's cudagraph pool
|
||||
output = self.runnable(*args, **kwargs)
|
||||
# Join offloader's copy stream after forward to avoid
|
||||
# unjoined stream error. The last layer's start_prefetch
|
||||
# forks copy_stream, but wait_prefetch only happens in
|
||||
# the next forward pass.
|
||||
get_offloader().join_after_forward()
|
||||
if self.cudagraph_options.weak_ref_output:
|
||||
# by converting it to weak ref,
|
||||
# the original `output` will immediately be released
|
||||
# to save memory. It is only safe to do this for
|
||||
                        # the last graph in piecewise cudagraph mode, because
|
||||
# the output of the last graph will not be used by
|
||||
# any other cuda graph.
|
||||
# output = weak_ref_tensors(output)
|
||||
output = self.weak_ref_tensors_with_intermediate(output)
|
||||
|
||||
# here we always use weak ref for the output
|
||||
# to save memory
|
||||
# entry.output = weak_ref_tensors(output)
|
||||
entry.output = self.weak_ref_tensors_with_intermediate(output)
|
||||
entry.cudagraph = cudagraph
|
||||
|
||||
compilation_counter.num_cudagraph_captured += 1
|
||||
|
||||
# important: we need to return the output, rather than
|
||||
# the weak ref of the output, so that pytorch can correctly
|
||||
# manage the memory during cuda graph capture
|
||||
return output
|
||||
|
||||
if self.is_debugging_mode:
|
||||
# check if the input addresses are the same
|
||||
new_input_addresses = [
|
||||
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
|
||||
]
|
||||
assert new_input_addresses == entry.input_addresses, (
|
||||
f"Input addresses for cudagraphs are different "
|
||||
f"during replay. Expected {entry.input_addresses}, "
|
||||
f"got {new_input_addresses}"
|
||||
)
|
||||
|
||||
# Sync offloader before replay - ensures any external dependencies
|
||||
# from pre-capture prefetches are satisfied.
|
||||
get_offloader().sync_prev_onload()
|
||||
entry.cudagraph.replay()
|
||||
return entry.output
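# Illustrative sketch (hypothetical call site): wrapping a callable for full
# cudagraph capture/replay. Dispatch is driven entirely by the forward
# context; if its runtime mode does not match the wrapper's, it runs eagerly.
#
#     wrapper = CUDAGraphWrapper(model.forward, vllm_config, CUDAGraphMode.FULL)
#     out = wrapper(input_ids)  # first matching call captures, later calls replay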
|
||||
657
vllm/compilation/decorators.py
Normal file
@@ -0,0 +1,657 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import contextlib
|
||||
import hashlib
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
from collections.abc import Callable, Generator
|
||||
from typing import TYPE_CHECKING, Any, TypeVar, overload
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch._dynamo.symbolic_convert import InliningInstructionTranslator
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.wrapper import TorchCompileWithNoGuardsWrapper
|
||||
from vllm.config import (
|
||||
CompilationMode,
|
||||
VllmConfig,
|
||||
get_current_vllm_config,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.config.compilation import DynamicShapesType
|
||||
from vllm.forward_context import get_forward_context, is_forward_context_available
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
from .monitor import start_monitoring_torch_compile
|
||||
|
||||
if TYPE_CHECKING:
    # SourceInfo is only available on torch nightly / 2.10+, so guard the import.
    try:
        from torch._dynamo.package import SourceInfo
    except ImportError:
        # Fallback for older torch versions that do not provide SourceInfo.
        SourceInfo = Any
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
IGNORE_COMPILE_KEY = "_ignore_compile_vllm"
|
||||
|
||||
_T = TypeVar("_T", bound=nn.Module)
|
||||
|
||||
|
||||
def ignore_torch_compile(cls: type[_T]) -> type[_T]:
|
||||
"""
|
||||
A decorator to ignore support_torch_compile decorator
|
||||
on the class. This is useful when a parent class has
|
||||
a support_torch_compile decorator, but we don't want to
|
||||
compile the class `cls` that inherits the parent class.
|
||||
This only ignores compiling the forward of the class the
|
||||
decorator is applied to.
|
||||
|
||||
If the parent has ignore_torch_compile but the child has
|
||||
support_torch_compile, the child will still be compiled.
|
||||
|
||||
If the class has one or more submodules
|
||||
that have support_torch_compile decorator applied, compile will
|
||||
not be ignored for those submodules.
|
||||
"""
|
||||
setattr(cls, IGNORE_COMPILE_KEY, True)
|
||||
return cls
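# Illustrative (hypothetical classes): opting a parent out while a child opts
# back in; the child's forward is still compiled.
#
#     @ignore_torch_compile
#     class ParentModel(nn.Module): ...
#
#     @support_torch_compile
#     class ChildModel(ParentModel): ...   # compiled despite the parent opt-out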
|
||||
|
||||
|
||||
def _should_ignore_torch_compile(cls: type[_T]) -> bool:
|
||||
"""
|
||||
Check if the class should be ignored for torch.compile.
|
||||
"""
|
||||
return getattr(cls, IGNORE_COMPILE_KEY, False)
|
||||
|
||||
|
||||
@overload
|
||||
def support_torch_compile(
|
||||
*,
|
||||
enable_if: Callable[[VllmConfig], bool] | None = None,
|
||||
) -> Callable[[type[_T]], type[_T]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def support_torch_compile(
|
||||
*,
|
||||
dynamic_arg_dims: dict[str, int | list[int]] | None,
|
||||
) -> Callable[[type[_T]], type[_T]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def support_torch_compile(
|
||||
*,
|
||||
mark_unbacked_dims: dict[str, int | list[int]] | None,
|
||||
) -> Callable[[type[_T]], type[_T]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def support_torch_compile(
|
||||
*,
|
||||
dynamic_arg_dims: dict[str, int | list[int]] | None,
|
||||
mark_unbacked_dims: dict[str, int | list[int]] | None,
|
||||
) -> Callable[[type[_T]], type[_T]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def support_torch_compile(cls: type[_T]) -> type[_T]: ...
|
||||
|
||||
|
||||
def support_torch_compile(
|
||||
cls: type[_T] | None = None,
|
||||
*,
|
||||
dynamic_arg_dims: dict[str, int | list[int]] | None = None,
|
||||
mark_unbacked_dims: dict[str, int | list[int]] | None = None,
|
||||
enable_if: Callable[[VllmConfig], bool] | None = None,
|
||||
shape_invariants: Callable[..., None] = lambda *args, **kwargs: None,
|
||||
) -> Callable[[type[_T]], type[_T]] | type[_T]:
|
||||
"""
|
||||
A decorator to add support for compiling the forward method of a class.
|
||||
|
||||
Usage 1: use directly as a decorator without arguments:
|
||||
|
||||
```python
|
||||
@support_torch_compile
|
||||
class MyModel(nn.Module):
|
||||
def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...
|
||||
```
|
||||
|
||||
Usage 2: use as a decorator with arguments:
|
||||
|
||||
```python
|
||||
@support_torch_compile(dynamic_arg_dims={"x": 0, "y": 0})
|
||||
class MyModel(nn.Module):
|
||||
def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ...
|
||||
```
|
||||
|
||||
`dynamic_arg_dims` is a dictionary that maps argument names to the dynamic
|
||||
dimensions of the argument. The dynamic dimensions can be either a single
|
||||
integer or a list of integers.
|
||||
|
||||
if `dynamic_arg_dims` is `None`, it is inferred from the type annotation
|
||||
of the `forward` method, based on the following default rules:
|
||||
|
||||
- if the argument is annotated as `torch.Tensor` or
|
||||
`Optional[torch.Tensor]`, the first dimension will be
|
||||
marked as dynamic.
|
||||
- if the argument is annotated as `IntermediateTensors`, the first
|
||||
dimension of all the tensors in the intermediate tensors
|
||||
will be marked as dynamic.
|
||||
|
||||
During runtime, when we actually mark dimensions of tensors,
|
||||
it depends on the value of arguments:
|
||||
|
||||
- if it is a single integer (can be negative), the corresponding dimension
|
||||
of the argument will be marked as dynamic.
|
||||
- if it is `None`, ignored.
|
||||
- if it is `IntermediateTensors`, all the tensors in the intermediate
|
||||
tensors will be marked as dynamic.
|
||||
- otherwise, it will raise an error.
|
||||
|
||||
NOTE: if an argument is `None`, it should always be passed as `None` during
|
||||
the lifetime of the model, otherwise, it cannot be captured as a single
|
||||
computation graph.
|
||||
|
||||
`enable_if` is a function that takes a `VllmConfig` object as input and
|
||||
returns a boolean value indicating whether to compile the model or not.
|
||||
This is useful if you want to compile the model only when certain
|
||||
conditions are met.
|
||||
|
||||
`mark_unbacked_dims` is a dictionary that maps argument names to the dynamic
|
||||
dims that should be decorated with `mark_unbacked`. This is useful to
|
||||
enforce that Dynamo does not specialize on 0/1 values in the case of dummy
|
||||
inputs, such as for vision model compilation.
|
||||
|
||||
`shape_invariants` is a function that gets compiled right before forward.
|
||||
The function should have the torch._check calls that are needed to set
|
||||
the relationships between different input sizes. For example:
|
||||
torch._check(input_ids.size()[0] == inputs_embeds.size()[0])
|
||||
This enforces constraints on the symbolic shapes without hardcoding
|
||||
specific values. It is needed for some models to avoid data-dependent
|
||||
errors.
|
||||
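Usage 3: a minimal illustrative sketch combining the arguments above.
The predicate passed to `enable_if` is only an example of the kind of
condition one might use; it is an assumption, not a requirement:

```python
@support_torch_compile(
    dynamic_arg_dims={"input_ids": 0, "inputs_embeds": 0},
    mark_unbacked_dims={"inputs_embeds": 0},
    enable_if=lambda vllm_config: vllm_config.compilation_config.mode
    != CompilationMode.NONE,
)
class MyModel(nn.Module):
    def forward(self, input_ids: torch.Tensor, inputs_embeds: torch.Tensor): ...
```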
"""
|
||||
|
||||
def cls_decorator_helper(cls: type[_T]) -> type[_T]:
|
||||
# helper to pass `dynamic_arg_dims` to `_support_torch_compile`
|
||||
# to avoid too much indentation for `_support_torch_compile`
|
||||
if not hasattr(cls, "forward"):
|
||||
raise TypeError("decorated class should have a forward method.")
|
||||
sig = inspect.signature(cls.forward)
|
||||
inferred_dynamic_arg_dims = dynamic_arg_dims
|
||||
if inferred_dynamic_arg_dims is None:
|
||||
inferred_dynamic_arg_dims = {}
|
||||
for k, v in sig.parameters.items():
|
||||
if v.annotation in [
|
||||
torch.Tensor,
|
||||
torch.Tensor | None,
|
||||
IntermediateTensors,
|
||||
IntermediateTensors | None,
|
||||
]:
|
||||
inferred_dynamic_arg_dims[k] = 0
|
||||
|
||||
logger.debug(
|
||||
("Inferred dynamic dimensions for forward method of %s: %s"),
|
||||
cls,
|
||||
list(inferred_dynamic_arg_dims.keys()),
|
||||
)
|
||||
|
||||
if len(inferred_dynamic_arg_dims) == 0:
|
||||
raise ValueError(
|
||||
"No dynamic dimensions found in the forward method of "
|
||||
f"{cls}. Please provide dynamic_arg_dims explicitly."
|
||||
)
|
||||
|
||||
for k in inferred_dynamic_arg_dims:
|
||||
if k not in sig.parameters:
|
||||
raise ValueError(
|
||||
f"Argument {k} not found in the forward method of {cls}"
|
||||
)
|
||||
return _support_torch_compile(
|
||||
cls,
|
||||
inferred_dynamic_arg_dims,
|
||||
mark_unbacked_dims,
|
||||
enable_if,
|
||||
shape_invariants,
|
||||
)
|
||||
|
||||
if cls is not None:
|
||||
# use `support_torch_compile` as a decorator without arguments
|
||||
assert isinstance(cls, type)
|
||||
return cls_decorator_helper(cls)
|
||||
|
||||
return cls_decorator_helper
|
||||
|
||||
|
||||
def _model_hash_key(fn: Callable[..., Any]) -> str:
|
||||
import vllm
|
||||
|
||||
sha256_hash = hashlib.sha256()
|
||||
sha256_hash.update(vllm.__version__.encode())
|
||||
sha256_hash.update(fn.__qualname__.encode())
|
||||
sha256_hash.update(str(fn.__code__.co_firstlineno).encode())
|
||||
return sha256_hash.hexdigest()
|
||||
|
||||
|
||||
def _verify_source_unchanged(
|
||||
source_info: "SourceInfo", vllm_config: VllmConfig
|
||||
) -> None:
|
||||
from .caching import _compute_code_hash, _compute_code_hash_with_content
|
||||
|
||||
file_contents = {}
|
||||
for source in source_info.inlined_sources:
|
||||
module = sys.modules[source.module]
|
||||
file = inspect.getfile(module)
|
||||
vllm_config.compilation_config.traced_files.add(file)
|
||||
file_contents[file] = source.content
|
||||
expected_checksum = _compute_code_hash_with_content(file_contents)
|
||||
actual_checksum = _compute_code_hash(set(file_contents.keys()))
|
||||
if expected_checksum != actual_checksum:
|
||||
raise RuntimeError(
|
||||
"Source code has changed since the last compilation. Recompiling the model."
|
||||
)
|
||||
|
||||
|
||||
def _support_torch_compile(
|
||||
cls: type[_T],
|
||||
dynamic_arg_dims: dict[str, int | list[int]],
|
||||
mark_unbacked_dims: dict[str, int | list[int]] | None = None,
|
||||
enable_if: Callable[[VllmConfig], bool] | None = None,
|
||||
shape_invariants: Callable[..., None] = lambda *args, **kwargs: None,
|
||||
) -> type[_T]:
|
||||
"""
|
||||
A decorator to add support for compiling the forward method of a class.
|
||||
"""
|
||||
if TorchCompileWithNoGuardsWrapper in cls.__bases__:
|
||||
# support decorating multiple times
|
||||
return cls
|
||||
|
||||
# take care of method resolution order
|
||||
# make sure super().__init__ is called on the base class
|
||||
# other than TorchCompileWithNoGuardsWrapper
|
||||
cls.__bases__ = cls.__bases__ + (TorchCompileWithNoGuardsWrapper,)
|
||||
|
||||
old_init = cls.__init__
|
||||
|
||||
setattr(cls, IGNORE_COMPILE_KEY, False)
|
||||
|
||||
def __init__(
|
||||
self: _T,
|
||||
*,
|
||||
vllm_config: VllmConfig | None = None,
|
||||
prefix: str = "",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
if vllm_config is None:
|
||||
vllm_config = get_current_vllm_config()
|
||||
|
||||
# NOTE: to support multimodal models (such as encoders),
|
||||
# vllm_config may not be passed in, so we may need to patch
|
||||
# it into the kwargs.
|
||||
sig = inspect.signature(old_init)
|
||||
if "vllm_config" in sig.parameters:
|
||||
kwargs["vllm_config"] = vllm_config
|
||||
if "prefix" in sig.parameters:
|
||||
kwargs["prefix"] = prefix
|
||||
old_init(self, **kwargs)
|
||||
|
||||
self.vllm_config = vllm_config
|
||||
self.compilation_config = self.vllm_config.compilation_config
|
||||
enable_compile = enable_if is None or enable_if(vllm_config)
|
||||
# for CompilationMode.STOCK_TORCH_COMPILE, the upper-level model runner
|
||||
# will handle the compilation, so we don't need to do anything here.
|
||||
self.do_not_compile = (
|
||||
self.compilation_config.mode
|
||||
in [CompilationMode.NONE, CompilationMode.STOCK_TORCH_COMPILE]
|
||||
or _should_ignore_torch_compile(self.__class__)
|
||||
or not enable_compile
|
||||
)
|
||||
if self.do_not_compile:
|
||||
return
|
||||
|
||||
self._check_shape_invariants = shape_invariants
|
||||
self.was_aot_compile_fn_loaded_from_disk = False
|
||||
compilation_counter.num_models_seen += 1
|
||||
self.compiled = False
|
||||
|
||||
# Handled by monkeypatching `TorchCompileWithNoGuardsWrapper` into base class
|
||||
TorchCompileWithNoGuardsWrapper.__init__(self)
|
||||
|
||||
cls.__init__ = __init__
|
||||
|
||||
def _mark_dynamic_inputs(
|
||||
mod: type[_T], ds_type: DynamicShapesType, *args: Any, **kwargs: Any
|
||||
) -> None:
|
||||
def mark_dynamic(arg: torch.Tensor, dims: list[int]) -> None:
|
||||
if ds_type == DynamicShapesType.UNBACKED:
|
||||
if is_torch_equal_or_newer("2.10.0"):
|
||||
for dim in dims:
|
||||
torch._dynamo.decorators.mark_unbacked(
|
||||
arg, dim, hint_override=arg.size()[dim]
|
||||
)
|
||||
else:
|
||||
torch._dynamo.decorators.mark_unbacked(arg, dims)
|
||||
else:
|
||||
torch._dynamo.mark_dynamic(arg, dims)
|
||||
|
||||
sig = inspect.signature(mod.__class__.forward) # type: ignore[attr-defined]
|
||||
bound_args = sig.bind(mod, *args, **kwargs)
|
||||
bound_args.apply_defaults()
|
||||
for k, dims in dynamic_arg_dims.items():
|
||||
arg = bound_args.arguments.get(k)
|
||||
|
||||
if arg is not None:
|
||||
dims = [dims] if isinstance(dims, int) else dims
|
||||
if isinstance(arg, torch.Tensor):
|
||||
# In case dims is specified with negative indexing
|
||||
dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
|
||||
mark_dynamic(arg, dims)
|
||||
elif isinstance(arg, IntermediateTensors):
|
||||
for tensor in arg.tensors.values():
|
||||
# In case dims is specified with negative indexing
|
||||
dims = [tensor.ndim + dim if dim < 0 else dim for dim in dims]
|
||||
mark_dynamic(tensor, dims)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported dynamic dimensions"
|
||||
f" {dims} for argument {k} with type {type(arg)}."
|
||||
)
|
||||
if mark_unbacked_dims:
|
||||
for k, dims in mark_unbacked_dims.items():
|
||||
arg = bound_args.arguments.get(k)
|
||||
if arg is not None:
|
||||
dims = [dims] if isinstance(dims, int) else dims
|
||||
if isinstance(arg, torch.Tensor):
|
||||
# In case dims is specified with negative indexing
|
||||
dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
|
||||
if is_torch_equal_or_newer("2.10.0"):
|
||||
for dim in dims:
|
||||
torch._dynamo.decorators.mark_unbacked(
|
||||
arg, dim, hint_override=arg.size()[dim]
|
||||
)
|
||||
else:
|
||||
torch._dynamo.decorators.mark_unbacked(arg, dims)
|
||||
|
||||
def __call__(self: type[_T], *args: Any, **kwargs: Any) -> Any:
|
||||
# torch.compiler.is_compiling() means we are inside the compilation
|
||||
# e.g. TPU has the compilation logic in model runner, so we don't
|
||||
# need to compile the model inside.
|
||||
if self.do_not_compile or torch.compiler.is_compiling():
|
||||
return self.forward(*args, **kwargs)
|
||||
|
||||
# If skip_compiled is set, bypass compiled model call. This is used e.g. for
|
||||
# enc-dec models where tensor shapes/types vary across invocations, preventing
|
||||
# the capture of a single computational graph.
|
||||
if is_forward_context_available() and get_forward_context().skip_compiled:
|
||||
return self.forward(*args, **kwargs)
|
||||
|
||||
# if aot_compiled_fn is set, call it with partition wrapper context.
|
||||
# The partition wrapper must be active at runtime for CUDA graph
|
||||
# capture to work correctly with inductor graph partitioning.
|
||||
if getattr(self, "aot_compiled_fn", None) is not None:
|
||||
with maybe_use_cudagraph_partition_wrapper(self.vllm_config):
|
||||
return self.aot_compiled_fn(self, *args, **kwargs)
|
||||
|
||||
ds_type = self.compilation_config.dynamic_shapes_config.type
|
||||
cache_dir = None
|
||||
aot_compilation_path = None
|
||||
if envs.VLLM_USE_AOT_COMPILE:
|
||||
"""
|
||||
When using torch.compile in AOT mode, we store the cache artifacts
|
||||
under VLLM_CACHE_ROOT/torch_compile_cache/torch_aot_compile/{hash}
|
||||
The {hash} contains all of the factors except for the source files
|
||||
being traced through, because we don't actually know which source
|
||||
files to check at this point (before dynamo runs).
|
||||
On loading we will actually look at the source files being traced
|
||||
through. If any source file has changed (compared with the
|
||||
serialized backend artifacts), then we need to generate a new AOT
|
||||
compile artifact from scratch.
|
||||
"""
|
||||
from .caching import aot_compile_hash_factors
|
||||
|
||||
factors: list[str] = aot_compile_hash_factors(self.vllm_config)
|
||||
|
||||
factors.append(_model_hash_key(self.forward))
|
||||
hash_key = hashlib.sha256(str(factors).encode()).hexdigest()
|
||||
cache_dir = os.path.join(
|
||||
envs.VLLM_CACHE_ROOT,
|
||||
"torch_compile_cache",
|
||||
"torch_aot_compile",
|
||||
hash_key,
|
||||
)
|
||||
|
||||
rank = self.vllm_config.parallel_config.rank
|
||||
dp_rank = self.vllm_config.parallel_config.data_parallel_index
|
||||
cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")
|
||||
aot_compilation_path = os.path.join(cache_dir, "model")
|
||||
try:
|
||||
with (
|
||||
set_current_vllm_config(self.vllm_config),
|
||||
open(aot_compilation_path, "rb") as f,
|
||||
):
|
||||
start_monitoring_torch_compile(self.vllm_config)
|
||||
loaded_fn = torch.compiler.load_compiled_function(
|
||||
f, f_globals=self.forward.__globals__
|
||||
)
|
||||
_verify_source_unchanged(loaded_fn.source_info(), self.vllm_config)
|
||||
if not self.compilation_config.dynamic_shapes_config.evaluate_guards:
|
||||
loaded_fn.disable_guard_check()
|
||||
self.aot_compiled_fn = loaded_fn
|
||||
self.was_aot_compile_fn_loaded_from_disk = True
|
||||
except Exception as e:
|
||||
if os.path.exists(aot_compilation_path):
|
||||
if isinstance(e, EOFError):
|
||||
message = "Compile cache file corrupted."
|
||||
else:
|
||||
message = str(e)
|
||||
logger.warning(
|
||||
"Compiling model again due to a load failure from %s, "
|
||||
"reason: %s",
|
||||
aot_compilation_path,
|
||||
message,
|
||||
)
|
||||
if envs.VLLM_FORCE_AOT_LOAD:
|
||||
raise e
|
||||
if getattr(self, "aot_compiled_fn", None) is not None:
|
||||
logger.info(
|
||||
"Directly load AOT compilation from path %s", aot_compilation_path
|
||||
)
|
||||
# Apply partition wrapper context for proper CUDA graph capture
|
||||
with maybe_use_cudagraph_partition_wrapper(self.vllm_config):
|
||||
return self.aot_compiled_fn(self, *args, **kwargs)
|
||||
|
||||
if self.compiled:
|
||||
assert (
|
||||
not envs.VLLM_USE_AOT_COMPILE
|
||||
or self.vllm_config.compilation_config.backend == "eager"
|
||||
)
|
||||
return TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs) # type: ignore[arg-type]
|
||||
|
||||
# This is the path for the first compilation.
|
||||
# the first compilation needs to have dynamic shapes marked
|
||||
_mark_dynamic_inputs(
|
||||
self,
|
||||
ds_type,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# here, it is the starting point of the `torch.compile` process
|
||||
start_monitoring_torch_compile(self.vllm_config)
|
||||
original_code_object = self.original_code_object()
|
||||
logger.debug("Start compiling function %s", original_code_object)
|
||||
|
||||
# we do not want to delete the original code object entries since
|
||||
# we depend on them now to look up cached compiled functions.
|
||||
# torch._dynamo.eval_frame.remove_from_cache(original_code_object)
|
||||
|
||||
# collect all relevant files traced by Dynamo,
|
||||
# so that the compilation cache can trigger re-compilation
|
||||
# properly when any of these files change.
|
||||
|
||||
# 1. the file containing the top-level forward function
|
||||
self.compilation_config.traced_files.add(original_code_object.co_filename)
|
||||
|
||||
# 2. every time Dynamo sees a function call, it will inline
|
||||
# the function by calling InliningInstructionTranslator.inline_call_
|
||||
# we hijack this function to know all the functions called
|
||||
# during Dynamo tracing, and their corresponding files
|
||||
inline_call = InliningInstructionTranslator.inline_call_
|
||||
|
||||
def patched_inline_call(self_: Any) -> Any:
|
||||
code = self_.f_code
|
||||
self.compilation_config.traced_files.add(code.co_filename)
|
||||
return inline_call(self_)
|
||||
|
||||
# Disable the C++ compilation of symbolic shape guards. C++-ification
|
||||
# of symbolic shape guards can reduce guard overhead, but since
|
||||
# vLLM skips guards anyway, setting this flag to False can improve
|
||||
# compile time.
|
||||
dynamo_config_patches = {}
|
||||
try:
|
||||
_ = torch._dynamo.config.enable_cpp_symbolic_shape_guards
|
||||
dynamo_config_patches["enable_cpp_symbolic_shape_guards"] = False
|
||||
except AttributeError:
|
||||
# Note: this config is not available in torch 2.6, we can skip
|
||||
# if the config doesn't exist
|
||||
logger.debug("enable_cpp_symbolic_shape_guards config not available")
|
||||
|
||||
# Prepare backed_size_oblivious config patch if needed
|
||||
fx_config_patches = {}
|
||||
if ds_type == DynamicShapesType.BACKED_SIZE_OBLIVIOUS:
|
||||
fx_config_patches["backed_size_oblivious"] = True
|
||||
|
||||
# Prepare inductor config patches
|
||||
# assume_32bit_indexing is only available in torch 2.10.0+
|
||||
inductor_config_patches = {}
|
||||
if is_torch_equal_or_newer("2.10.0"):
|
||||
inductor_config_patches["assume_32bit_indexing"] = (
|
||||
self.compilation_config.dynamic_shapes_config.assume_32_bit_indexing
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
InliningInstructionTranslator, "inline_call_", patched_inline_call
|
||||
),
|
||||
torch._dynamo.config.patch(**dynamo_config_patches),
|
||||
maybe_use_cudagraph_partition_wrapper(self.vllm_config),
|
||||
torch.fx.experimental._config.patch(**fx_config_patches),
|
||||
torch._inductor.config.patch(**inductor_config_patches),
|
||||
):
|
||||
use_aot_compile = envs.VLLM_USE_AOT_COMPILE
|
||||
if self.vllm_config.compilation_config.backend == "eager":
|
||||
logger.warning("Detected eager backend, disabling AOT compile.")
|
||||
use_aot_compile = False
|
||||
if use_aot_compile:
|
||||
from vllm.compilation.backends import set_on_compilation_complete
|
||||
|
||||
# store the path for saving after warmup
|
||||
self._aot_compilation_path = aot_compilation_path
|
||||
self._aot_cache_dir = cache_dir
|
||||
# set callback in context so it's available when compilation completes
|
||||
with set_on_compilation_complete(self.save_aot_compiled_function):
|
||||
self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
|
||||
output = self.aot_compiled_fn(self, *args, **kwargs)
|
||||
else:
|
||||
output = TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs) # type: ignore[arg-type]
|
||||
|
||||
self.compiled = True
|
||||
return output
|
||||
|
||||
# triggers VllmSerializableFunction.serialize()
|
||||
def save_aot_compiled_function(self: type[_T]) -> None:
|
||||
if self.was_aot_compile_fn_loaded_from_disk:
|
||||
logger.debug("AOT compiled function was loaded from cache, skipping save")
|
||||
return
|
||||
|
||||
assert (
|
||||
self.aot_compiled_fn and self._aot_compilation_path and self._aot_cache_dir
|
||||
)
|
||||
|
||||
logger.info("saving AOT compiled function to %s", self._aot_compilation_path)
|
||||
try:
|
||||
os.makedirs(self._aot_cache_dir, exist_ok=True)
|
||||
# File saving should be atomic, so we will save to a temporary location
|
||||
# first. Should be upstreamed to PyTorch 2.12 as well.
|
||||
tmp_file = f"{self._aot_compilation_path}.{os.getpid()}.tmp"
|
||||
self.aot_compiled_fn.save_compiled_function(tmp_file)
|
||||
os.replace(tmp_file, self._aot_compilation_path)
|
||||
logger.info("saved AOT compiled function to %s", self._aot_compilation_path)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"unable to save AOT compiled function to %s: %s",
|
||||
self._aot_compilation_path,
|
||||
e,
|
||||
)
|
||||
|
||||
cls.__call__ = __call__
|
||||
cls.save_aot_compiled_function = save_aot_compiled_function
|
||||
return cls
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def maybe_use_cudagraph_partition_wrapper(
|
||||
vllm_config: VllmConfig,
|
||||
) -> Generator[None, None, None]:
|
||||
"""
|
||||
Context manager to set/unset customized cudagraph partition wrappers.
|
||||
|
||||
If we're using Inductor-based graph partitioning, we currently have the
|
||||
whole `fx.Graph` before Inductor lowering and the piecewise
|
||||
splitting happens after all graph passes and fusions. Here, we add
|
||||
a custom hook for Inductor to wrap each partition with our static
|
||||
graph wrapper class to maintain more control over static graph
|
||||
capture and replay.
|
||||
"""
|
||||
from vllm.config import CUDAGraphMode
|
||||
|
||||
compilation_config = vllm_config.compilation_config
|
||||
if (
|
||||
compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
|
||||
and compilation_config.use_inductor_graph_partition
|
||||
):
|
||||
from torch._inductor.utils import CUDAGraphWrapperMetadata
|
||||
|
||||
from vllm.compilation.cuda_graph import CUDAGraphOptions
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
static_graph_wrapper_class = resolve_obj_by_qualname(
|
||||
current_platform.get_static_graph_wrapper_cls()
|
||||
)
|
||||
|
||||
def customized_cudagraph_wrapper(
|
||||
f: Callable[..., Any], metadata: CUDAGraphWrapperMetadata
|
||||
) -> Any:
|
||||
partition_id = metadata.partition_index
|
||||
num_partitions = metadata.num_partitions
|
||||
return static_graph_wrapper_class(
|
||||
runnable=f,
|
||||
vllm_config=vllm_config,
|
||||
runtime_mode=CUDAGraphMode.PIECEWISE,
|
||||
cudagraph_options=CUDAGraphOptions(
|
||||
debug_log_enable=partition_id == 0,
|
||||
gc_disable=partition_id != 0,
|
||||
weak_ref_output=partition_id == num_partitions - 1,
|
||||
),
|
||||
)
|
||||
|
||||
torch._inductor.utils.set_customized_partition_wrappers(
|
||||
customized_cudagraph_wrapper
|
||||
)
|
||||
|
||||
yield
|
||||
|
||||
if (
|
||||
compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
|
||||
and compilation_config.use_inductor_graph_partition
|
||||
):
|
||||
torch._inductor.utils.set_customized_partition_wrappers(None)
|
||||
63
vllm/compilation/monitor.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import time
|
||||
|
||||
from vllm.config import CompilationConfig, CompilationMode, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
context_manager = None
|
||||
torch_compile_start_time: float = 0.0
|
||||
|
||||
|
||||
def start_monitoring_torch_compile(vllm_config: VllmConfig) -> None:
|
||||
global torch_compile_start_time
|
||||
torch_compile_start_time = time.perf_counter()
|
||||
|
||||
compilation_config: CompilationConfig = vllm_config.compilation_config
|
||||
path = vllm_config.compile_debug_dump_path()
|
||||
if compilation_config.mode == CompilationMode.VLLM_COMPILE and path:
|
||||
import depyf
|
||||
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
logger.debug("Dumping depyf output to %s", path)
|
||||
global context_manager
|
||||
context_manager = depyf.prepare_debug(path.as_posix())
|
||||
context_manager.__enter__()
|
||||
|
||||
|
||||
def end_monitoring_torch_compile(vllm_config: VllmConfig) -> None:
|
||||
compilation_config: CompilationConfig = vllm_config.compilation_config
|
||||
total_compile_time: float = time.perf_counter() - torch_compile_start_time
|
||||
if compilation_config.mode == CompilationMode.VLLM_COMPILE:
|
||||
logger.info_once(
|
||||
"torch.compile takes %.2f s in total",
|
||||
total_compile_time,
|
||||
scope="local",
|
||||
)
|
||||
global context_manager
|
||||
if context_manager is not None:
|
||||
context_manager.__exit__(None, None, None)
|
||||
context_manager = None
|
||||
|
||||
|
||||
cudagraph_capturing_enabled: bool = True
|
||||
|
||||
|
||||
def validate_cudagraph_capturing_enabled() -> None:
|
||||
# used to monitor whether a cudagraph capturing is legal at runtime.
|
||||
# should be called before any cudagraph capturing.
|
||||
# if an illegal cudagraph capturing happens, raise an error.
|
||||
global cudagraph_capturing_enabled
|
||||
if not cudagraph_capturing_enabled:
|
||||
raise RuntimeError(
|
||||
"CUDA graph capturing detected at an inappropriate "
|
||||
"time. This operation is currently disabled."
|
||||
)
|
||||
|
||||
|
||||
def set_cudagraph_capturing_enabled(enabled: bool) -> None:
|
||||
global cudagraph_capturing_enabled
|
||||
cudagraph_capturing_enabled = enabled
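# Illustrative call sequence (the call sites are an assumption, not shown here):
#     set_cudagraph_capturing_enabled(False)   # e.g. once warmup capture is done
#     ...
#     validate_cudagraph_capturing_enabled()   # raises if a capture is attempted later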
|
||||
75
vllm/compilation/partition_rules.py
Normal file
@@ -0,0 +1,75 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import contextlib
|
||||
from collections.abc import Generator
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def should_split(node: torch.fx.Node, splitting_ops: list[str]) -> bool:
|
||||
"""
|
||||
Check if a node should be split for dynamo graph partition.
|
||||
It operates on dynamo graph, so the node.target can be anything.
|
||||
We need to check and split only on OpOverload and OpOverloadPacket.
|
||||
"""
|
||||
|
||||
if node.op != "call_function":
|
||||
return False
|
||||
|
||||
target = node.target
|
||||
|
||||
if isinstance(target, torch._ops.OpOverloadPacket):
|
||||
# Example: "aten::add"
|
||||
return target._qualified_op_name in splitting_ops
|
||||
|
||||
if isinstance(target, torch._ops.OpOverload):
|
||||
# Example: "aten::add"
|
||||
packet_name = target.name()
|
||||
|
||||
# Example: "aten::add.default"
|
||||
op_overload_name = f"{packet_name}.{target._overloadname}"
|
||||
return op_overload_name in splitting_ops or packet_name in splitting_ops
|
||||
|
||||
return False
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def inductor_partition_rule_context(
|
||||
splitting_ops: list[str] | None,
|
||||
) -> Generator[None, None, None]:
|
||||
"""Context manager to temporarily register Inductor partition rules.
|
||||
|
||||
Registers custom partition rules for specified operators, forcing the
|
||||
Inductor scheduler to partition the graph at these operators. The rules
|
||||
are automatically restored to their previous state on exit.
|
||||
|
||||
Args:
|
||||
splitting_ops: List of operator names to partition on.
|
||||
"""
|
||||
if not splitting_ops:
|
||||
logger.debug("No partition ops provided; skipping rule registration.")
|
||||
yield
|
||||
return
|
||||
|
||||
# Save current state before registering
|
||||
|
||||
saved_splitting_ops: list[str] = list(
|
||||
torch._inductor.config.custom_should_partition_ops
|
||||
)
|
||||
torch._inductor.config.custom_should_partition_ops = splitting_ops
|
||||
|
||||
logger.debug(
|
||||
"Registered inductor partition rules for %d operators", len(splitting_ops)
|
||||
)
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
# Clear and restore previous state
|
||||
torch._inductor.config.custom_should_partition_ops = saved_splitting_ops
|
||||
logger.debug("Restored previous partition rules state.")
|
||||
0
vllm/compilation/passes/__init__.py
Normal file
0
vllm/compilation/passes/fusion/__init__.py
Normal file
215
vllm/compilation/passes/fusion/act_quant_fusion.py
Normal file
@@ -0,0 +1,215 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._inductor.pattern_matcher import (
|
||||
PatternMatcherPass,
|
||||
fwd_only,
|
||||
register_replacement,
|
||||
)
|
||||
from torch._ops import OpOverload
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kFp8StaticTensorSym,
|
||||
kNvfp4Dynamic,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .matcher_utils import MatcherQuantFP8, MatcherSiluAndMul
|
||||
from .rms_quant_fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
FP4_DTYPE = torch.uint8
|
||||
|
||||
SILU_MUL_OP = torch.ops._C.silu_and_mul.default
|
||||
|
||||
FUSED_OPS: dict[QuantKey, OpOverload] = {
|
||||
kFp8StaticTensorSym: torch.ops._C.silu_and_mul_quant.default, # noqa: E501
|
||||
}
|
||||
silu_and_mul_nvfp4_quant_supported = current_platform.is_cuda() and hasattr(
|
||||
torch.ops._C, "silu_and_mul_nvfp4_quant"
|
||||
)
|
||||
if silu_and_mul_nvfp4_quant_supported:
|
||||
FUSED_OPS[kNvfp4Dynamic] = torch.ops._C.silu_and_mul_nvfp4_quant.default # noqa: E501
|
||||
|
||||
|
||||
class ActivationQuantPattern(ABC):
|
||||
"""
|
||||
The base class for Activation+Quant fusions.
|
||||
Should not be used directly.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_key: QuantKey,
|
||||
) -> None:
|
||||
self.quant_key = quant_key
|
||||
self.quant_dtype = quant_key.dtype
|
||||
|
||||
assert self.quant_key in QUANT_OPS, (
|
||||
f"unsupported quantization scheme {self.quant_key}"
|
||||
)
|
||||
self.QUANT_OP = QUANT_OPS[self.quant_key]
|
||||
|
||||
assert self.quant_key in FUSED_OPS, (
|
||||
f"unsupported fusion scheme {self.quant_key}"
|
||||
)
|
||||
self.FUSED_OP = FUSED_OPS[self.quant_key]
|
||||
|
||||
self.silu_and_mul_matcher = MatcherSiluAndMul()
|
||||
|
||||
def empty_quant(self, *args: Any, **kwargs: Any) -> torch.Tensor:
|
||||
kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs}
|
||||
return torch.empty(*args, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
|
||||
"""
|
||||
Fusion for SiluMul+Fp8StaticQuant Pattern
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(kFp8StaticTensorSym)
|
||||
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
|
||||
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
scale = self.quant_matcher.inputs()[1]
|
||||
return [
|
||||
*self.silu_and_mul_matcher.inputs(), # input
|
||||
scale,
|
||||
]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
result_silu_mul = self.silu_and_mul_matcher(input)
|
||||
result_quant = self.quant_matcher(result_silu_mul, scale)
|
||||
return result_quant[0]
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
d = input.shape[-1] // 2
|
||||
output_shape = input.shape[:-1] + (d,)
|
||||
result = torch.empty(
|
||||
output_shape, device=input.device, dtype=self.quant_dtype
|
||||
)
|
||||
at = auto_functionalized(
|
||||
self.FUSED_OP, result=result, input=input, scale=scale
|
||||
)
|
||||
return at[1]
|
||||
|
||||
inps = self.get_inputs()
|
||||
pattern(*inps)
|
||||
|
||||
register_replacement(pattern, replacement, inps, fwd_only, pm_pass)
|
||||
|
||||
|
||||
class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
|
||||
"""
|
||||
Fusion for SiluMul+Nvfp4Quant Pattern
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(kNvfp4Dynamic)
|
||||
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
result = self.empty_quant(5, 32)
|
||||
output_scale = empty_i32(128, 4)
|
||||
input_ = empty_bf16(5, 64)
|
||||
scale = empty_fp32(1, 1)
|
||||
return [result, output_scale, input_, scale]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
result: torch.Tensor,
|
||||
output_scale: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
result_silu_mul = self.silu_and_mul_matcher(input)
|
||||
at = auto_functionalized(
|
||||
self.QUANT_OP,
|
||||
output=result,
|
||||
input=result_silu_mul,
|
||||
output_scale=output_scale,
|
||||
input_scale=scale,
|
||||
is_sf_swizzled_layout=True,
|
||||
)
|
||||
return at[1], at[2]
|
||||
|
||||
def replacement(
|
||||
result: torch.Tensor,
|
||||
output_scale: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
at = auto_functionalized(
|
||||
self.FUSED_OP,
|
||||
result=result,
|
||||
result_block_scale=output_scale,
|
||||
input=input,
|
||||
input_global_scale=scale,
|
||||
)
|
||||
return at[1], at[2]
|
||||
|
||||
register_replacement(pattern, replacement, self.get_inputs(), fwd_only, pm_pass)
|
||||
|
||||
|
||||
class ActivationQuantFusionPass(VllmPatternMatcherPass):
|
||||
"""
|
||||
This pass fuses a pre-defined set of custom ops into fused ops.
|
||||
It uses the torch pattern matcher to find the patterns and replace them.
|
||||
|
||||
Because patterns can only be registered once, the pass is a singleton.
|
||||
This will be addressed in a future version of PyTorch:
|
||||
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
|
||||
"""
|
||||
|
||||
@enable_fake_mode
|
||||
def __init__(self, config: VllmConfig) -> None:
|
||||
super().__init__(config)
|
||||
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
||||
pass_name="activation_quant_fusion_pass"
|
||||
)
|
||||
|
||||
pattern_silu_mul_fp8 = SiluMulFp8StaticQuantPattern()
|
||||
pattern_silu_mul_fp8.register(self.patterns)
|
||||
|
||||
if silu_and_mul_nvfp4_quant_supported:
|
||||
pattern_silu_mul_nvfp4 = SiluMulNvfp4QuantPattern()
|
||||
pattern_silu_mul_nvfp4.register(self.patterns)
|
||||
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: torch.fx.Graph) -> None:
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", self.matched_count)
|
||||
|
||||
def uuid(self) -> str:
|
||||
return VllmInductorPass.hash_source(
|
||||
self,
|
||||
ActivationQuantPattern,
|
||||
SiluMulFp8StaticQuantPattern,
|
||||
SiluMulNvfp4QuantPattern,
|
||||
)
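# Illustrative wiring (how the pass is registered elsewhere is an assumption):
#     fusion_pass = ActivationQuantFusionPass(vllm_config)
#     fusion_pass(graph)  # rewrites matched SiluMul+quant patterns in the FX graph in place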
|
||||
862
vllm/compilation/passes/fusion/allreduce_rms_fusion.py
Normal file
@@ -0,0 +1,862 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
from importlib.util import find_spec
|
||||
from types import ModuleType
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
import torch.fx as fx
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.utils import Range
|
||||
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import (
|
||||
direct_register_custom_op,
|
||||
)
|
||||
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
flashinfer_comm: ModuleType | None = None
|
||||
if find_spec("flashinfer"):
|
||||
try:
|
||||
import flashinfer.comm as _flashinfer_comm
|
||||
|
||||
if hasattr(_flashinfer_comm, "allreduce_fusion") and hasattr(
|
||||
_flashinfer_comm, "create_allreduce_fusion_workspace"
|
||||
):
|
||||
flashinfer_comm = _flashinfer_comm
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if hasattr(torch.ops._C, "scaled_fp4_quant"):
|
||||
STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default
|
||||
|
||||
# Max size of the input tensor per world size per device capability
|
||||
# to use flashinfer fused allreduce
|
||||
FI_ALLREDUCE_FUSION_MAX_SIZE_MB: dict[int, dict[int, float]] = {
|
||||
90: {
|
||||
2: 64, # 64MB
|
||||
4: 2, # 2MB
|
||||
8: 0.5, # 0.5MB
|
||||
},
|
||||
100: {
|
||||
2: 64, # 64MB
|
||||
4: 32, # 32MB
|
||||
8: 1, # 1MB
|
||||
},
|
||||
}
|
||||
|
||||
# Max size of the input tensor per world size per device capability
|
||||
# to use flashinfer one shot fused allreduce
|
||||
# OneShot max size is at most 64MB / world size (FlashInfer restriction)
|
||||
_FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB: dict[int, dict[int, float]] = {
|
||||
90: {
|
||||
2: 32, # 32MB
|
||||
4: 2, # 2MB
|
||||
8: 0.5, # 0.5MB
|
||||
},
|
||||
100: {
|
||||
2: 32, # 32MB
|
||||
4: 4, # 4MB
|
||||
8: 1, # 1MB
|
||||
},
|
||||
}
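# Reading the tables above (example): on a device with capability 90 and
# world_size=4, fused allreduce is considered for inputs up to 2 MB and the
# one-shot variant up to 2 MB; with capability 100 the limits are 32 MB / 4 MB.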
|
||||
|
||||
|
||||
if flashinfer_comm is not None:
|
||||
from vllm.distributed.device_communicators.flashinfer_all_reduce import (
|
||||
destroy_fi_ar_workspace,
|
||||
get_fi_ar_quant_workspace,
|
||||
get_fi_ar_workspace,
|
||||
initialize_fi_ar_quant_workspace,
|
||||
initialize_fi_ar_workspace,
|
||||
)
|
||||
|
||||
ar_fusion_patterns = flashinfer_comm.AllReduceFusionPattern
|
||||
|
||||
MiB = 1024 * 1024
|
||||
|
||||
def call_trtllm_fused_allreduce_norm(
|
||||
allreduce_in: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_gamma: torch.Tensor,
|
||||
rms_eps: float,
|
||||
world_size: int,
|
||||
launch_with_pdl: bool,
|
||||
fp32_acc: bool,
|
||||
max_token_num: int,
|
||||
pattern_code: int,
|
||||
norm_out: torch.Tensor | None = None,
|
||||
quant_out: torch.Tensor | None = None,
|
||||
scale_out: torch.Tensor | None = None,
|
||||
scale_factor: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
num_tokens, hidden_size = allreduce_in.shape
|
||||
element_size = allreduce_in.element_size()
|
||||
current_tensor_size = num_tokens * hidden_size * element_size
|
||||
max_tensor_size = max_token_num * hidden_size * element_size
|
||||
assert current_tensor_size <= max_tensor_size, (
|
||||
f"Current tensor size {current_tensor_size} is larger than "
|
||||
f"max token num {max_token_num} * hidden size {hidden_size} * "
|
||||
f"element size {element_size}"
|
||||
)
|
||||
curr_device = current_platform.get_device_capability()
|
||||
device_capability = curr_device.to_int() if curr_device is not None else None
|
||||
# Get one shot input size limit for the current world size
|
||||
# for the current device capability
|
||||
max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get(
|
||||
device_capability, # type: ignore[arg-type, unused-ignore]
|
||||
{},
|
||||
).get(world_size, None)
|
||||
# Use one shot if no max size is specified
|
||||
use_oneshot = (
|
||||
max_one_shot_size is None or current_tensor_size <= max_one_shot_size * MiB
|
||||
)
|
||||
|
||||
# Select workspace based on pattern: quant patterns use the
|
||||
# trtllm quant workspace, non-quant patterns use the primary workspace.
|
||||
if pattern_code in (
|
||||
ar_fusion_patterns.kARResidualRMSNormFP8Quant,
|
||||
ar_fusion_patterns.kARResidualRMSNormFP4Quant,
|
||||
):
|
||||
workspace = get_fi_ar_quant_workspace()
|
||||
else:
|
||||
workspace = get_fi_ar_workspace()
|
||||
assert workspace is not None, (
|
||||
"Flashinfer workspace must be initialized when using flashinfer"
|
||||
)
|
||||
assert flashinfer_comm is not None
|
||||
if norm_out is None:
|
||||
norm_out = allreduce_in
|
||||
residual_out = residual
|
||||
else:
|
||||
# return residual_out as the allreduce output with a zeroed residual_in,
|
||||
# since flashinfer does not support returning rms_norm
|
||||
# and allreduce_out together
|
||||
residual_out = allreduce_in
|
||||
|
||||
layout_code = None
|
||||
# layout_code only supported by trtllm backend
|
||||
if workspace.backend == "trtllm":
|
||||
# in vllm we only support swizzled layout
|
||||
layout_code = flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4
|
||||
|
||||
flashinfer_comm.allreduce_fusion(
|
||||
input=allreduce_in,
|
||||
workspace=workspace,
|
||||
pattern=pattern_code,
|
||||
launch_with_pdl=launch_with_pdl,
|
||||
output=None,
|
||||
residual_out=residual_out,
|
||||
norm_out=norm_out,
|
||||
quant_out=quant_out,
|
||||
scale_out=scale_out,
|
||||
residual_in=residual,
|
||||
rms_gamma=rms_gamma,
|
||||
rms_eps=rms_eps,
|
||||
scale_factor=scale_factor,
|
||||
layout_code=layout_code,
|
||||
use_oneshot=use_oneshot,
|
||||
fp32_acc=fp32_acc,
|
||||
)
|
||||
|
||||
def call_trtllm_fused_allreduce_norm_fake(
|
||||
allreduce_in: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
rms_gamma: torch.Tensor,
|
||||
rms_eps: float,
|
||||
world_size: int,
|
||||
launch_with_pdl: bool,
|
||||
fp32_acc: bool,
|
||||
max_token_num: int,
|
||||
pattern_code: int,
|
||||
norm_out: torch.Tensor | None = None,
|
||||
quant_out: torch.Tensor | None = None,
|
||||
scale_out: torch.Tensor | None = None,
|
||||
scale_factor: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="flashinfer_trtllm_fused_allreduce_norm",
|
||||
op_func=call_trtllm_fused_allreduce_norm,
|
||||
mutates_args=[
|
||||
"allreduce_in",
|
||||
"residual",
|
||||
"norm_out",
|
||||
"quant_out",
|
||||
"scale_out",
|
||||
],
|
||||
fake_impl=call_trtllm_fused_allreduce_norm_fake,
|
||||
)
|
||||
flashinfer_trtllm_fused_allreduce_norm = (
|
||||
torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default
|
||||
)
|
||||
|
||||
|
||||
class FlashInferFusedAllReduceParams:
|
||||
"""Parameters for FlashInfer fused allreduce operations."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
world_size: int,
|
||||
max_token_num: int = 1024,
|
||||
) -> None:
|
||||
self.world_size = world_size
|
||||
self.launch_with_pdl = True
|
||||
self.fp32_acc = True
|
||||
self.max_token_num = max_token_num
|
||||
|
||||
def get_trtllm_fused_allreduce_kwargs(self) -> dict[str, bool | int]:
|
||||
return {
|
||||
"world_size": self.world_size,
|
||||
"launch_with_pdl": self.launch_with_pdl,
|
||||
"fp32_acc": self.fp32_acc,
|
||||
"max_token_num": self.max_token_num,
|
||||
}
|
||||
|
||||
|
||||
# TODO(luka): unify
|
||||
class BasePattern:
|
||||
def __init__(self, dtype: torch.dtype, device: str | None) -> None:
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self.tp = get_tp_group()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
|
||||
class AllReduceRMSNormPattern(BasePattern):
|
||||
"""
|
||||
This pattern replaces the allreduce + rms norm (without residual)
|
||||
with fused flashinfer implementation.
|
||||
Applies to allreduce + rmsnorm before attn in the first Transformer block.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
epsilon: float,
|
||||
dtype: torch.dtype,
|
||||
device: str | None,
|
||||
allreduce_params: FlashInferFusedAllReduceParams,
|
||||
) -> None:
|
||||
super().__init__(dtype, device)
|
||||
self.epsilon = epsilon
|
||||
self.allreduce_params = allreduce_params
|
||||
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
|
||||
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
input, weight = self.rmsnorm_matcher.inputs()
|
||||
|
||||
# input goes through allreduce first, always 16-bit
|
||||
return [input.to(self.dtype), weight]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor, weight: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
allreduce_output = tensor_model_parallel_all_reduce(input)
|
||||
rms = self.rmsnorm_matcher(allreduce_output, weight)
|
||||
|
||||
return rms, allreduce_output
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor, weight: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
residual = torch.zeros_like(input)
|
||||
rms_result = torch.empty_like(input)
|
||||
assert flashinfer_comm is not None, "FlashInfer must be enabled"
|
||||
allreduce = auto_functionalized(
|
||||
flashinfer_trtllm_fused_allreduce_norm,
|
||||
allreduce_in=input,
|
||||
residual=residual,
|
||||
norm_out=rms_result,
|
||||
quant_out=None,
|
||||
scale_out=None,
|
||||
rms_gamma=weight,
|
||||
rms_eps=self.epsilon,
|
||||
pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
|
||||
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
|
||||
)
|
||||
# rms_result, allreduce_in
|
||||
return allreduce[3], allreduce[1]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class AllReduceFusedAddRMSNormPattern(BasePattern):
|
||||
"""
|
||||
This pattern replaces the allreduce + rms norm (with residual)
|
||||
with fused flashinfer implementation.
|
||||
Applies to o_proj + rmsnorm after attn and mlp + rmsnorm before attn.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
epsilon: float,
|
||||
dtype: torch.dtype,
|
||||
device: str | None,
|
||||
allreduce_params: FlashInferFusedAllReduceParams,
|
||||
) -> None:
|
||||
super().__init__(dtype, device)
|
||||
self.epsilon = epsilon
|
||||
self.allreduce_params = allreduce_params
|
||||
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
|
||||
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
input, residual, weight = self.rmsnorm_matcher.inputs()
|
||||
|
||||
# input goes through allreduce first, always 16-bit
|
||||
return [residual, input.to(self.dtype), weight]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
allreduce_output = tensor_model_parallel_all_reduce(input)
|
||||
rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual)
|
||||
return rms, residual
|
||||
|
||||
def replacement(
|
||||
residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert flashinfer_comm is not None, "FlashInfer must be enabled"
|
||||
allreduce = auto_functionalized(
|
||||
flashinfer_trtllm_fused_allreduce_norm,
|
||||
allreduce_in=input,
|
||||
residual=residual,
|
||||
norm_out=None,
|
||||
quant_out=None,
|
||||
scale_out=None,
|
||||
rms_gamma=weight,
|
||||
rms_eps=self.epsilon,
|
||||
pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
|
||||
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
|
||||
)
|
||||
# allreduce_in, residual
|
||||
return allreduce[1], allreduce[2]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
# Same pattern, but only return the output and not residual
|
||||
# (helpful for end of graph where residual is not used again)
|
||||
first_return_only = lambda fn: lambda a, b, c: fn(a, b, c)[0]
|
||||
|
||||
pm.register_replacement(
|
||||
first_return_only(pattern), # type: ignore[no-untyped-call]
|
||||
first_return_only(replacement), # type: ignore[no-untyped-call]
|
||||
self.get_inputs(),
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
)
|
||||
|
||||
|
||||
class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
|
||||
"""
|
||||
This pattern replaces the allreduce + rms norm (without residual)
|
||||
+ static fp8 quant with fused flashinfer implementation.
|
||||
Applies to allreduce + rmsnorm + quant before attn
|
||||
in the first Transformer block.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
epsilon: float,
|
||||
dtype: torch.dtype,
|
||||
device: str | None,
|
||||
allreduce_params: FlashInferFusedAllReduceParams,
|
||||
) -> None:
|
||||
super().__init__(dtype, device)
|
||||
self.epsilon = epsilon
|
||||
self.allreduce_params = allreduce_params
|
||||
self.quant_dtype = torch.float8_e4m3fn
|
||||
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
|
||||
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
|
||||
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
input, weight = self.rmsnorm_matcher.inputs()
|
||||
_, scale = self.quant_matcher.inputs()
|
||||
|
||||
# input goes through allreduce first, always 16-bit
|
||||
return [input.to(self.dtype), weight, scale]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
all_reduce = tensor_model_parallel_all_reduce(input)
|
||||
rms = self.rmsnorm_matcher(all_reduce, weight)
|
||||
quant, _ = self.quant_matcher(rms, scale)
|
||||
return quant, all_reduce
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
residual = torch.zeros_like(input)
|
||||
result_rms = torch.empty_like(input)
|
||||
result_quant = torch.empty_like(input, dtype=self.quant_dtype)
|
||||
assert flashinfer_comm is not None, "FlashInfer must be enabled"
|
||||
allreduce = auto_functionalized(
|
||||
flashinfer_trtllm_fused_allreduce_norm,
|
||||
allreduce_in=input,
|
||||
residual=residual,
|
||||
norm_out=result_rms,
|
||||
quant_out=result_quant,
|
||||
scale_out=None,
|
||||
rms_gamma=weight,
|
||||
rms_eps=self.epsilon,
|
||||
# We don't use norm_out afterwards
|
||||
pattern_code=(
|
||||
flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant
|
||||
),
|
||||
scale_factor=scale,
|
||||
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
|
||||
)
|
||||
|
||||
# quant_out, allreduce_output
|
||||
return allreduce[4], allreduce[1]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
|
||||
"""
|
||||
This pattern replaces the allreduce + rms norm (with residual)
|
||||
+ static fp8 quant with fused flashinfer implementation.
|
||||
Applies to o_proj + rmsnorm after attn + quant and
|
||||
mlp + rmsnorm + quant before attn.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
epsilon: float,
|
||||
dtype: torch.dtype,
|
||||
device: str | None,
|
||||
allreduce_params: FlashInferFusedAllReduceParams,
|
||||
) -> None:
|
||||
super().__init__(dtype, device)
|
||||
self.epsilon = epsilon
|
||||
self.allreduce_params = allreduce_params
|
||||
self.quant_dtype = torch.float8_e4m3fn
|
||||
|
||||
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
|
||||
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
|
||||
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
input, residual, weight = self.rmsnorm_matcher.inputs()
|
||||
_, scale = self.quant_matcher.inputs()
|
||||
|
||||
# input goes through allreduce first, always 16-bit
|
||||
return [residual, input.to(self.dtype), weight, scale]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
residual: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
allreduce_output = tensor_model_parallel_all_reduce(input)
|
||||
rms, res = self.rmsnorm_matcher(allreduce_output, weight, residual)
|
||||
quant, _ = self.quant_matcher(rms, scale)
|
||||
|
||||
return quant, res
|
||||
|
||||
def replacement(
|
||||
residual: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
result_quant = torch.empty_like(input, dtype=self.quant_dtype)
|
||||
assert flashinfer_comm is not None, "FlashInfer must be enabled"
|
||||
allreduce = auto_functionalized(
|
||||
flashinfer_trtllm_fused_allreduce_norm,
|
||||
allreduce_in=input,
|
||||
residual=residual,
|
||||
norm_out=None,
|
||||
quant_out=result_quant,
|
||||
scale_out=None,
|
||||
rms_gamma=weight,
|
||||
rms_eps=self.epsilon,
|
||||
# We don't use norm_out afterwards
|
||||
pattern_code=(
|
||||
flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant
|
||||
),
|
||||
scale_factor=scale,
|
||||
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
|
||||
)
|
||||
# quant_out, rms_norm_residual
|
||||
return allreduce[4], allreduce[2]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
|
||||
"""
|
||||
This pattern replaces the allreduce + rms norm (without residual)
|
||||
+ static nvfp4 quant with fused flashinfer implementation.
|
||||
Applies to allreduce + rmsnorm + quant before attn
|
||||
in the first Transformer block.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
epsilon: float,
|
||||
dtype: torch.dtype,
|
||||
device: str | None,
|
||||
allreduce_params: FlashInferFusedAllReduceParams,
|
||||
) -> None:
|
||||
super().__init__(dtype, device)
|
||||
self.epsilon = epsilon
|
||||
self.allreduce_params = allreduce_params
|
||||
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
|
||||
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
input = torch.empty([1, 16, 16], device=self.device, dtype=self.dtype)
|
||||
quant_result = torch.empty((16, 8), device=self.device, dtype=torch.uint8)
|
||||
input_global_scale = torch.empty(
|
||||
[1, 1], device=self.device, dtype=torch.float32
|
||||
)
|
||||
weight = torch.empty([16], device=self.device, dtype=self.dtype)
|
||||
output_scale = torch.empty([128, 4], device=self.device, dtype=torch.int32)
|
||||
|
||||
return [input, quant_result, weight, input_global_scale, output_scale]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
quant_result: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
input_global_scale: torch.Tensor,
|
||||
output_scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
all_reduce = tensor_model_parallel_all_reduce(input)
|
||||
rms = self.rmsnorm_matcher(all_reduce, weight)
|
||||
quant_out_tuple = auto_functionalized(
|
||||
STATIC_FP4_QUANT_OP,
|
||||
output=quant_result,
|
||||
input=rms,
|
||||
output_scale=output_scale,
|
||||
input_scale=input_global_scale,
|
||||
is_sf_swizzled_layout=True,
|
||||
)
|
||||
|
||||
# quant_out, allreduce_output, output_scale
|
||||
return quant_out_tuple[1], all_reduce, quant_out_tuple[2]
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
quant_result: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
input_global_scale: torch.Tensor,
|
||||
output_scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
residual = torch.zeros_like(input)
|
||||
result_rms = torch.empty_like(input)
|
||||
assert flashinfer_comm is not None, "FlashInfer must be enabled"
|
||||
allreduce = auto_functionalized(
|
||||
flashinfer_trtllm_fused_allreduce_norm,
|
||||
allreduce_in=input,
|
||||
residual=residual,
|
||||
norm_out=result_rms,
|
||||
quant_out=quant_result,
|
||||
scale_out=output_scale,
|
||||
rms_gamma=weight,
|
||||
rms_eps=self.epsilon,
|
||||
# We don't use norm_out afterwards
|
||||
pattern_code=(
|
||||
flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant
|
||||
),
|
||||
scale_factor=input_global_scale,
|
||||
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
|
||||
)
|
||||
|
||||
# quant_out, allreduce_output, output_scale
|
||||
return allreduce[4], allreduce[1], allreduce[5]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
|
||||
"""
|
||||
This pattern replaces the allreduce + rms norm (with residual)
|
||||
+ static nvfp4 quant with fused flashinfer implementation.
|
||||
Applies to o_proj + rmsnorm after attn + quant and
|
||||
mlp + rmsnorm + quant before attn.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
epsilon: float,
|
||||
dtype: torch.dtype,
|
||||
device: str | None,
|
||||
allreduce_params: FlashInferFusedAllReduceParams,
|
||||
) -> None:
|
||||
super().__init__(dtype, device)
|
||||
self.epsilon = epsilon
|
||||
self.allreduce_params = allreduce_params
|
||||
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
|
||||
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
input = torch.empty([16, 16], device=self.device, dtype=self.dtype)
|
||||
|
||||
residual = torch.empty([16, 16], device=self.device, dtype=self.dtype)
|
||||
weight = torch.empty([16, 16], device=self.device, dtype=self.dtype)
|
||||
quant_result = torch.empty((16, 8), device=self.device, dtype=torch.uint8)
|
||||
input_global_scale = torch.empty(
|
||||
[1, 1], device=self.device, dtype=torch.float32
|
||||
)
|
||||
output_scale = torch.empty([128, 4], device=self.device, dtype=torch.int32)
|
||||
|
||||
return [
|
||||
quant_result,
|
||||
residual,
|
||||
input,
|
||||
output_scale,
|
||||
weight,
|
||||
input_global_scale,
|
||||
]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
quant_result: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
output_scale: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
input_global_scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
allreduce_output = tensor_model_parallel_all_reduce(input)
|
||||
rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual)
|
||||
quant_out_tuple = auto_functionalized(
|
||||
STATIC_FP4_QUANT_OP,
|
||||
output=quant_result,
|
||||
input=rms,
|
||||
output_scale=output_scale,
|
||||
input_scale=input_global_scale,
|
||||
is_sf_swizzled_layout=True,
|
||||
)
|
||||
|
||||
# quant_out, allreduce_output, output_scale
|
||||
return quant_out_tuple[1], residual, quant_out_tuple[2]
|
||||
|
||||
def replacement(
|
||||
quant_result: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
output_scale: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
input_global_scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
assert flashinfer_comm is not None, "FlashInfer must be enabled"
|
||||
allreduce = auto_functionalized(
|
||||
flashinfer_trtllm_fused_allreduce_norm,
|
||||
allreduce_in=input,
|
||||
residual=residual,
|
||||
norm_out=None,
|
||||
quant_out=quant_result,
|
||||
scale_out=output_scale,
|
||||
rms_gamma=weight,
|
||||
rms_eps=self.epsilon,
|
||||
# We don't use norm_out afterwards
|
||||
pattern_code=(
|
||||
flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant
|
||||
),
|
||||
scale_factor=input_global_scale,
|
||||
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
|
||||
)
|
||||
# quant_out, rms_norm_residual, output_scale
|
||||
return allreduce[4], allreduce[2], allreduce[5]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class AllReduceFusionPass(VllmPatternMatcherPass):
|
||||
def __init__(self, config: VllmConfig) -> None:
|
||||
super().__init__(config)
|
||||
self.disabled = True
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
if self.tp_size <= 1:
|
||||
logger.warning_once("AllReduce fusion pass is disabled for tp_size <= 1.")
|
||||
return
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
||||
pass_name="all_reduce_fusion_pass"
|
||||
)
|
||||
if config.model_config is None:
|
||||
logger.warning_once(
|
||||
"AllReduce fusion pass is disabled for missing model_config."
|
||||
)
|
||||
return
|
||||
self.hidden_dim = config.model_config.get_hidden_size()
|
||||
self.group = get_tp_group().device_group
|
||||
rank = get_tensor_model_parallel_rank()
|
||||
if flashinfer_comm is None:
|
||||
logger.warning(
|
||||
"Flashinfer is not installed or comm module not found, "
|
||||
"skipping allreduce fusion pass"
|
||||
)
|
||||
return
|
||||
max_size = config.compilation_config.pass_config.flashinfer_max_size(
|
||||
self.tp_size
|
||||
)
|
||||
if max_size is None:
|
||||
# Flashinfer doesn't support current world size
|
||||
logger.warning(
|
||||
"Flashinfer allreduce fusion is not supported for world size %s"
|
||||
" or max size is not provided",
|
||||
self.tp_size,
|
||||
)
|
||||
return
|
||||
element_size = torch.tensor([], dtype=self.model_dtype).element_size()
|
||||
self.max_token_num = max_size // (self.hidden_dim * element_size)
|
||||
# take the min to save workspace size and we'll never use more
|
||||
# than max_num_batched_tokens anyways
|
||||
self.max_token_num = min(
|
||||
self.max_token_num, config.scheduler_config.max_num_batched_tokens
|
||||
)
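# Worked example (hypothetical numbers): with max_size = 64 MiB,
# hidden_dim = 4096 and a 2-byte dtype such as bfloat16, the budget is
# 64 * 2**20 // (4096 * 2) = 8192 tokens, which is then clamped to
# max_num_batched_tokens when that value is smaller.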
|
||||
logger.debug_once(
|
||||
f"Flashinfer max size: {max_size // (1024 * 1024)} MB,"
|
||||
"Maximal number of tokens used by "
|
||||
f"Flashinfer Allreduce Fusion: {self.max_token_num}",
|
||||
scope="global",
|
||||
)
|
||||
|
||||
for workspace_init_fn in [
|
||||
initialize_fi_ar_workspace,
|
||||
initialize_fi_ar_quant_workspace,
|
||||
]:
|
||||
try:
|
||||
workspace_init_fn(
|
||||
world_size=self.tp_size,
|
||||
rank=rank,
|
||||
max_token_num=self.max_token_num,
|
||||
hidden_dim=self.hidden_dim,
|
||||
dtype=self.model_dtype,
|
||||
group=self.group,
|
||||
)
|
||||
except Exception as e:
|
||||
if "multicast" in str(e).lower():
|
||||
logger.warning(
|
||||
"AllReduce fusion pass is disabled: flashinfer workspace "
|
||||
"creation failed: %s. This is expected on GPUs without "
|
||||
"NVSwitch (e.g., NVLink bridge-only or PCIe topologies). "
|
||||
"Falling back to non-fused allreduce.",
|
||||
str(e),
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Failed to initialize FlashInfer All Reduce workspace: %s. "
|
||||
"AllReduce fusion pass will be disabled.",
|
||||
e,
|
||||
)
|
||||
return
|
||||
|
||||
self.allreduce_params = FlashInferFusedAllReduceParams(
|
||||
world_size=self.tp_size,
|
||||
max_token_num=self.max_token_num,
|
||||
)
|
||||
|
||||
self.register_patterns()
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
@enable_fake_mode
|
||||
def register_patterns(self) -> None:
|
||||
supports_quantization = get_fi_ar_quant_workspace() is not None
|
||||
for epsilon in [1e-5, 1e-6]:
|
||||
if supports_quantization:
|
||||
AllReduceFusedRMSNormStaticQuantFP8Pattern(
|
||||
epsilon,
|
||||
self.model_dtype,
|
||||
self.device,
|
||||
self.allreduce_params,
|
||||
).register(self.patterns)
|
||||
AllReduceFusedAddRMSNormStaticQuantFP8Pattern(
|
||||
epsilon,
|
||||
self.model_dtype,
|
||||
self.device,
|
||||
self.allreduce_params,
|
||||
).register(self.patterns)
|
||||
if current_platform.has_device_capability(100):
|
||||
AllReduceFusedRMSNormStaticQuantNVFP4Pattern(
|
||||
epsilon,
|
||||
self.model_dtype,
|
||||
self.device,
|
||||
self.allreduce_params,
|
||||
).register(self.patterns)
|
||||
AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(
|
||||
epsilon,
|
||||
self.model_dtype,
|
||||
self.device,
|
||||
self.allreduce_params,
|
||||
).register(self.patterns)
|
||||
AllReduceRMSNormPattern(
|
||||
epsilon,
|
||||
self.model_dtype,
|
||||
self.device,
|
||||
self.allreduce_params,
|
||||
).register(self.patterns)
|
||||
AllReduceFusedAddRMSNormPattern(
|
||||
epsilon,
|
||||
self.model_dtype,
|
||||
self.device,
|
||||
self.allreduce_params,
|
||||
).register(self.patterns)
|
||||
|
||||
# WARNING: This is a hack to clear the pattern matcher cache
|
||||
# and allow multiple values of epsilon.
|
||||
torch._inductor.pattern_matcher._seen_patterns.clear()
|
||||
|
||||
self.disabled = False
|
||||
|
||||
def is_applicable_for_range(self, compile_range: Range) -> bool:
|
||||
if self.disabled:
|
||||
logger.warning_once("AllReduce fusion pass is disabled.")
|
||||
return False
|
||||
return bool(compile_range.end <= self.max_token_num)
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph) -> None:
|
||||
if self.disabled:
|
||||
logger.debug("AllReduceFusionPass disabled")
|
||||
return
|
||||
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", self.matched_count)
|
||||
|
||||
def __del__(self) -> None:
|
||||
if getattr(self, "disabled", True):
|
||||
return
|
||||
with contextlib.suppress(Exception):
|
||||
destroy_fi_ar_workspace()
|
||||
374
vllm/compilation/passes/fusion/attn_quant_fusion.py
Normal file
@@ -0,0 +1,374 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from typing import Any, ParamSpec
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
from torch import fx
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kNvfp4Dynamic,
|
||||
kStaticTensorScale,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
from ..fx_utils import is_func
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .matcher_utils import MatcherQuantFP8
|
||||
from .rms_quant_fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
|
||||
|
||||
logger = init_logger(__name__)
|
||||
P = ParamSpec("P")
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
FP4_DTYPE = torch.uint8
|
||||
|
||||
ATTN_OP = torch.ops.vllm.unified_attention_with_output.default
|
||||
RESHAPE_OP = torch.ops.aten.reshape.default
|
||||
|
||||
|
||||
class AttentionQuantPattern(ABC):
|
||||
"""
|
||||
The base class for Attn+Quant fusions.
|
||||
Should not be used directly.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer: Attention,
|
||||
quant_key: QuantKey,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
self.layer = layer
|
||||
self.layer_name = layer.layer_name
|
||||
self.num_heads = layer.num_heads
|
||||
self.head_size = layer.head_size
|
||||
self.quant_key = quant_key
|
||||
self.quant_dtype = quant_key.dtype
|
||||
self.dtype = dtype
|
||||
|
||||
assert self.quant_key in QUANT_OPS, (
|
||||
f"unsupported quantization scheme {self.quant_key}"
|
||||
)
|
||||
self.QUANT_OP = QUANT_OPS[self.quant_key]
|
||||
|
||||
def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor:
|
||||
kwargs = {"dtype": self.dtype, "device": "cuda", **kwargs}
|
||||
return torch.empty(*args, **kwargs)
|
||||
|
||||
def empty_quant(self, *args: Any, **kwargs: Any) -> torch.Tensor:
|
||||
kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs}
|
||||
return torch.empty(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def wrap_trace_fn(
|
||||
trace_fn: Callable[P, fx.GraphModule],
|
||||
*process_fx_fns: Callable[[fx.GraphModule], None],
|
||||
) -> Callable[P, fx.GraphModule]:
|
||||
def wrapped(*args: P.args, **kwargs: P.kwargs) -> fx.GraphModule:
|
||||
gm = trace_fn(*args, **kwargs)
|
||||
for process_fx in process_fx_fns:
|
||||
process_fx(gm)
|
||||
|
||||
return gm
|
||||
|
||||
return wrapped
|
||||
|
||||
@staticmethod
|
||||
def fx_view_to_reshape(gm: torch.fx.GraphModule) -> None:
|
||||
from torch._inductor.fx_passes.post_grad import view_to_reshape
|
||||
|
||||
view_to_reshape(gm)
|
||||
|
||||
@staticmethod
|
||||
def remove_noop_permutes(gm: torch.fx.GraphModule) -> None:
|
||||
for node in gm.graph.nodes:
|
||||
if not is_func(node, torch.ops.aten.permute.default):
|
||||
continue
|
||||
|
||||
dims = node.args[1]
|
||||
if any(dim != i for i, dim in enumerate(dims)):
|
||||
continue
|
||||
|
||||
# this is now an identity op, remove
|
||||
node.replace_all_uses_with(node.args[0])
|
||||
gm.graph.erase_node(node)
|
||||
|
||||
def register_if_supported(self, pm_pass: PatternMatcherPass) -> None:
|
||||
if self.layer.impl.fused_output_quant_supported(self.quant_key):
|
||||
self._register(pm_pass)
|
||||
|
||||
@abstractmethod
|
||||
def _register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
|
||||
"""
|
||||
Fusion for Attention+Fp8StaticQuant.
|
||||
|
||||
Only triggers when the attention implementation returns True in
|
||||
`fused_output_quant_supported()`. If the pattern is found, the
|
||||
Fp8StaticQuant op will be removed from the graph, and its scale
|
||||
will be passed into Attention op as the `output_scale` argument.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer: Attention,
|
||||
dtype: torch.dtype,
|
||||
symmetric: bool = True,
|
||||
) -> None:
|
||||
quant_key = QuantKey(
|
||||
dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric
|
||||
)
|
||||
super().__init__(layer, quant_key, dtype)
|
||||
self.quant_matcher = MatcherQuantFP8(quant_key)
|
||||
|
||||
def _register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
output_attn: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
kv_cache_dummy_dep: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
at1 = auto_functionalized(
|
||||
ATTN_OP,
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
output=output_attn,
|
||||
layer_name=self.layer_name,
|
||||
output_scale=None,
|
||||
output_block_scale=None,
|
||||
kv_cache_dummy_dep=kv_cache_dummy_dep,
|
||||
)
|
||||
attn_out_view = RESHAPE_OP(
|
||||
at1[1], [q.shape[0], self.num_heads * self.head_size]
|
||||
)
|
||||
|
||||
return self.quant_matcher(attn_out_view, scale)[0]
|
||||
|
||||
def replacement(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
output_attn: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
kv_cache_dummy_dep: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# attn output in quant_dtype
|
||||
output_attn = torch.ops.aten.full.default(
|
||||
[q.shape[0], self.num_heads, self.head_size],
|
||||
0.0,
|
||||
dtype=self.quant_dtype,
|
||||
device=q.device,
|
||||
)
|
||||
at1 = auto_functionalized(
|
||||
ATTN_OP,
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
output=output_attn,
|
||||
layer_name=self.layer_name,
|
||||
output_scale=scale,
|
||||
output_block_scale=None,
|
||||
kv_cache_dummy_dep=kv_cache_dummy_dep,
|
||||
)
|
||||
return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])
|
||||
|
||||
inputs = [
|
||||
self.empty(5, self.num_heads, self.head_size), # q
|
||||
self.empty(5, self.num_heads, self.head_size), # k
|
||||
self.empty(5, self.num_heads, self.head_size), # v
|
||||
self.empty(5, self.num_heads, self.head_size), # attn_output
|
||||
empty_fp32(1, 1), # scale
|
||||
self.empty(0), # kv_cache_dummy_dep
|
||||
]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern,
|
||||
replacement,
|
||||
inputs,
|
||||
AttentionQuantPattern.wrap_trace_fn(
|
||||
pm.fwd_only,
|
||||
AttentionQuantPattern.fx_view_to_reshape,
|
||||
AttentionQuantPattern.remove_noop_permutes,
|
||||
),
|
||||
pm_pass,
|
||||
)
|
||||
|
||||
|
||||
class AttentionNvfp4QuantPattern(AttentionQuantPattern):
|
||||
"""
|
||||
Fusion for Attention+Nvfp4Quant.
|
||||
|
||||
Only triggers when the attention implementation returns True in
|
||||
`fused_output_quant_supported()`. If the pattern is found, the
|
||||
Nvfp4Quant op will be removed from the graph, and its scale
|
||||
will be passed into Attention op as the `output_scale` argument.
|
||||
"""
|
||||
|
||||
def __init__(self, layer: Attention, dtype: torch.dtype) -> None:
|
||||
super().__init__(layer, kNvfp4Dynamic, dtype)
|
||||
|
||||
def _register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
output_attn: torch.Tensor,
|
||||
output_quant: torch.Tensor,
|
||||
output_scale: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
kv_cache_dummy_dep: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
at1 = auto_functionalized(
|
||||
ATTN_OP,
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
output=output_attn,
|
||||
layer_name=self.layer_name,
|
||||
output_scale=None,
|
||||
output_block_scale=None,
|
||||
kv_cache_dummy_dep=kv_cache_dummy_dep,
|
||||
)
|
||||
attn_out_view = RESHAPE_OP(
|
||||
at1[1], [q.shape[0], self.num_heads * self.head_size]
|
||||
)
|
||||
at2 = auto_functionalized(
|
||||
self.QUANT_OP,
|
||||
output=output_quant,
|
||||
input=attn_out_view,
|
||||
output_scale=output_scale,
|
||||
input_scale=input_scale,
|
||||
is_sf_swizzled_layout=True,
|
||||
)
|
||||
output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
|
||||
return at2[1], output_scale_view
|
||||
|
||||
def replacement(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
output_attn: torch.Tensor,
|
||||
output_quant: torch.Tensor,
|
||||
output_scale: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
kv_cache_dummy_dep: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# attention output in quant_dtype
|
||||
output_attn = torch.ops.aten.full.default(
|
||||
[q.shape[0], self.num_heads, self.head_size // 2],
|
||||
0.0,
|
||||
dtype=self.quant_dtype,
|
||||
device=q.device,
|
||||
)
|
||||
# attention output block scale
|
||||
output_scale_view = torch.ops.aten.view.dtype(output_scale, FP8_DTYPE)
|
||||
at2 = auto_functionalized(
|
||||
ATTN_OP,
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
output=output_attn,
|
||||
layer_name=self.layer_name,
|
||||
output_scale=input_scale,
|
||||
output_block_scale=output_scale_view,
|
||||
kv_cache_dummy_dep=kv_cache_dummy_dep,
|
||||
)
|
||||
output = RESHAPE_OP(at2[1], [-1, self.num_heads * self.head_size // 2])
|
||||
return output, at2[2]
|
||||
|
||||
inputs = [
|
||||
empty_bf16(5, self.num_heads, self.head_size), # q
|
||||
empty_bf16(5, self.num_heads, self.head_size), # k
|
||||
empty_bf16(5, self.num_heads, self.head_size), # v
|
||||
empty_bf16(5, self.num_heads, self.head_size), # output_attn
|
||||
self.empty_quant(5, self.num_heads * self.head_size // 2), # output_quant
|
||||
empty_i32(
|
||||
128, round_up(self.num_heads * self.head_size // 16, 4)
|
||||
), # output_scale
|
||||
empty_fp32(1, 1), # input_scale
|
||||
self.empty(0), # kv_cache_dummy_dep
|
||||
]
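# Note on sizing (illustrative): NVFP4 packs two 4-bit values per uint8,
# hence the `// 2` in output_quant, and stores one scale per 16 elements in
# the swizzled [128, round_up(..., 4)] int32 layout used above.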
|
||||
|
||||
pm.register_replacement(
|
||||
pattern,
|
||||
replacement,
|
||||
inputs,
|
||||
AttentionQuantPattern.wrap_trace_fn(
|
||||
pm.fwd_only,
|
||||
AttentionQuantPattern.fx_view_to_reshape,
|
||||
AttentionQuantPattern.remove_noop_permutes,
|
||||
),
|
||||
pm_pass,
|
||||
)
|
||||
|
||||
|
||||
class AttnFusionPass(VllmPatternMatcherPass):
|
||||
"""
|
||||
This pass fuses post-attention quantization onto attention if supported.
|
||||
|
||||
It uses the pattern matcher and matches each layer manually, as strings
|
||||
cannot be wildcarded. This also lets us check support on attention layers
|
||||
upon registration instead of during pattern matching.
|
||||
|
||||
Currently, static fp8 quant and nvfp4 quant are supported, but patterns could easily be
|
||||
added for other quant schemes and dtypes. The bigger hurdle for wider
|
||||
support is attention kernels, which need to support fusing output quant.
|
||||
"""
|
||||
|
||||
@enable_fake_mode
|
||||
def __init__(self, config: VllmConfig) -> None:
|
||||
super().__init__(config)
|
||||
|
||||
self.patterns = PatternMatcherPass(pass_name="attn_fusion_pass")
|
||||
|
||||
attn_layers = get_layers_from_vllm_config(config, Attention)
|
||||
for layer_name, layer in attn_layers.items():
|
||||
pattern_fp8 = AttentionFp8StaticQuantPattern(
|
||||
layer, config.model_config.dtype
|
||||
)
|
||||
pattern_fp8.register_if_supported(self.patterns)
|
||||
|
||||
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
|
||||
pattern_nvfp4 = AttentionNvfp4QuantPattern(
|
||||
layer, config.model_config.dtype
|
||||
)
|
||||
pattern_nvfp4.register_if_supported(self.patterns)
|
||||
|
||||
if len(attn_layers) == 0:
|
||||
logger.warning(
|
||||
"Attention + quant fusion is enabled, but no attention layers "
|
||||
"were found in CompilationConfig.static_forward_context "
|
||||
"so no fusion patterns were registered."
|
||||
)
|
||||
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: torch.fx.graph.Graph) -> None:
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Fused quant onto %s attention nodes", self.matched_count)
|
||||
|
||||
def uuid(self) -> str:
|
||||
return VllmInductorPass.hash_source(
|
||||
self,
|
||||
AttentionQuantPattern,
|
||||
AttentionFp8StaticQuantPattern,
|
||||
AttentionNvfp4QuantPattern,
|
||||
)
|
||||
423
vllm/compilation/passes/fusion/collective_fusion.py
Normal file
@@ -0,0 +1,423 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
import torch.fx as fx
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.utils import Range
|
||||
from vllm.distributed import get_tp_group
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class BasePattern:
|
||||
def __init__(self, dtype: torch.dtype, device: str | None) -> None:
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self.tp = get_tp_group()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
|
||||
class GEMMReduceScatterPattern(BasePattern):
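# Rewrites mm(x, w) followed by vllm.reduce_scatter into
# torch.ops.symm_mem.fused_matmul_reduce_scatter, letting the GEMM overlap
# with communication (async tensor parallelism).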
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
mul = torch.empty([16, 4], device=self.device, dtype=self.dtype)
|
||||
mm_weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
return [mul, mm_weight]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor:
|
||||
mm = torch.ops.aten.mm.default(mul, mm_weight)
|
||||
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
|
||||
mm,
|
||||
dim=0,
|
||||
world_size=self.tp_size,
|
||||
group_name=self.tp.unique_name,
|
||||
)
|
||||
return reduce_scatter
|
||||
|
||||
def replacement(mul: torch.Tensor, mm_weight: torch.Tensor) -> torch.Tensor:
|
||||
gemm_rs = torch.ops.symm_mem.fused_matmul_reduce_scatter(
|
||||
mul,
|
||||
mm_weight,
|
||||
"avg",
|
||||
scatter_dim=0,
|
||||
group_name=self.tp.device_group.group_name,
|
||||
)
|
||||
|
||||
return gemm_rs
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class AllGatherGEMMPattern(BasePattern):
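# Rewrites vllm.all_gather followed by mm into
# torch.ops.symm_mem.fused_all_gather_matmul so the gather overlaps with the
# matmul, mirroring GEMMReduceScatterPattern above.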
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
x = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
|
||||
return [x, weight]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
all_gather = torch.ops.vllm.all_gather.default(
|
||||
x,
|
||||
dim=0,
|
||||
world_size=self.tp_size,
|
||||
group_name=self.tp.unique_name,
|
||||
)
|
||||
|
||||
return torch.ops.aten.mm.default(all_gather, weight)
|
||||
|
||||
def replacement(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
|
||||
ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul(
|
||||
x,
|
||||
[weight],
|
||||
gather_dim=0,
|
||||
group_name=self.tp.device_group.group_name,
|
||||
)
|
||||
return mm_outputs
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class ScaledMMReduceScatterPattern(BasePattern):
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
mm_weight = (
|
||||
torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
.contiguous()
|
||||
.transpose(0, 1)
|
||||
)
|
||||
scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32)
|
||||
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
|
||||
return [input, mm_weight, scale_a, scale_b]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
mat2: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
scaled_mm = torch.ops.aten._scaled_mm.default(
|
||||
input,
|
||||
mat2=mat2,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
bias=None,
|
||||
scale_result=None,
|
||||
out_dtype=self.dtype,
|
||||
)
|
||||
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
|
||||
scaled_mm,
|
||||
dim=0,
|
||||
world_size=self.tp_size,
|
||||
group_name=self.tp.unique_name,
|
||||
)
|
||||
return reduce_scatter
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
mat2: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# Calculate output shape: input @ mat2 with scatter_dim reduced
|
||||
output_shape = [*input.shape[:-1], mat2.shape[1]]
|
||||
scatter_dim = 0
|
||||
gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
|
||||
input,
|
||||
mat2,
|
||||
scale_a,
|
||||
scale_b,
|
||||
"avg",
|
||||
scatter_dim, # orig_scatter_dim
|
||||
scatter_dim, # scatter_dim_after_maybe_reshape
|
||||
self.tp.device_group.group_name,
|
||||
output_shape,
|
||||
None, # bias
|
||||
None, # result_scale
|
||||
self.dtype, # out_dtype
|
||||
False, # use_fast_accum
|
||||
)
|
||||
|
||||
return gemm_rs
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class AllGatherScaledMMPattern(BasePattern):
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
weight = (
|
||||
torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
.contiguous()
|
||||
.transpose(0, 1)
|
||||
)
|
||||
|
||||
s1 = x.shape[0] * self.tp_size
|
||||
|
||||
scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
|
||||
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
|
||||
|
||||
return [x, weight, scale_a, scale_b]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
all_gather = torch.ops.vllm.all_gather.default(
|
||||
x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name
|
||||
)
|
||||
|
||||
return torch.ops.aten._scaled_mm.default(
|
||||
all_gather,
|
||||
mat2=weight,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
bias=None,
|
||||
scale_result=None,
|
||||
out_dtype=self.dtype,
|
||||
)
|
||||
|
||||
def replacement(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul( # noqa
|
||||
x,
|
||||
[weight],
|
||||
scale_a,
|
||||
[scale_b],
|
||||
gather_dim=0,
|
||||
biases=[None],
|
||||
result_scales=[None],
|
||||
out_dtypes=[self.dtype],
|
||||
use_fast_accum=[False],
|
||||
group_name=self.tp.device_group.group_name,
|
||||
)
|
||||
return mm_outputs
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class CutlassScaledMMReduceScatterPattern(BasePattern):
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
input = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
mm_weight = (
|
||||
torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
.contiguous()
|
||||
.transpose(0, 1)
|
||||
)
|
||||
scale_a = torch.empty([16, 1], device=self.device, dtype=torch.float32)
|
||||
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
|
||||
|
||||
cutlass_mm_output = torch.empty([16, 16], device=self.device, dtype=self.dtype)
|
||||
return [input, mm_weight, scale_a, scale_b, cutlass_mm_output]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
cutlass_mm_output: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
|
||||
torch.ops._C.cutlass_scaled_mm.default,
|
||||
out=cutlass_mm_output,
|
||||
a=input,
|
||||
b=weight,
|
||||
a_scales=scale_a,
|
||||
b_scales=scale_b,
|
||||
bias=None,
|
||||
)
|
||||
|
||||
reduce_scatter = torch.ops.vllm.reduce_scatter.default(
|
||||
cutlass_scaled_mm[1],
|
||||
dim=0,
|
||||
world_size=self.tp_size,
|
||||
group_name=self.tp.unique_name,
|
||||
)
|
||||
return reduce_scatter
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
mat2: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
cutlass_mm_output: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# Calculate output shape: input @ mat2 with scatter_dim reduced
|
||||
output_shape = [*input.shape[:-1], mat2.shape[1]]
|
||||
scatter_dim = 0
|
||||
gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
|
||||
input,
|
||||
mat2,
|
||||
scale_a,
|
||||
scale_b,
|
||||
"avg",
|
||||
scatter_dim, # orig_scatter_dim
|
||||
scatter_dim, # scatter_dim_after_maybe_reshape
|
||||
self.tp.device_group.group_name,
|
||||
output_shape,
|
||||
None, # bias
|
||||
None, # result_scale
|
||||
self.dtype, # out_dtype
|
||||
False, # use_fast_accum
|
||||
)
|
||||
|
||||
return gemm_rs
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class AllGatherCutlassScaledMMPattern(BasePattern):
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
x = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
weight = (
|
||||
torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE)
|
||||
.contiguous()
|
||||
.transpose(0, 1)
|
||||
)
|
||||
|
||||
s1 = x.shape[0] * self.tp_size
|
||||
|
||||
scale_a = torch.empty([s1, 1], device=self.device, dtype=torch.float32)
|
||||
scale_b = torch.empty([1, 16], device=self.device, dtype=torch.float32)
|
||||
|
||||
s2 = weight.shape[1]
|
||||
output = torch.empty([s1, s2], device=self.device, dtype=self.dtype)
|
||||
|
||||
return [x, weight, scale_a, scale_b, output]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
all_gather = torch.ops.vllm.all_gather.default(
|
||||
x, dim=0, world_size=self.tp_size, group_name=self.tp.unique_name
|
||||
)
|
||||
|
||||
cutlass_scaled_mm = torch.ops.higher_order.auto_functionalized(
|
||||
torch.ops._C.cutlass_scaled_mm.default,
|
||||
out=output,
|
||||
a=all_gather,
|
||||
b=weight,
|
||||
a_scales=scale_a,
|
||||
b_scales=scale_b,
|
||||
bias=None,
|
||||
)
|
||||
return cutlass_scaled_mm[1]
|
||||
|
||||
def replacement(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_scaled_matmul( # noqa
|
||||
x,
|
||||
[weight],
|
||||
scale_a,
|
||||
[scale_b],
|
||||
gather_dim=0,
|
||||
biases=[None],
|
||||
result_scales=[None],
|
||||
out_dtypes=[self.dtype],
|
||||
use_fast_accum=[False],
|
||||
group_name=self.tp.device_group.group_name,
|
||||
)
|
||||
return mm_outputs
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class AsyncTPPass(VllmPatternMatcherPass):
|
||||
@enable_fake_mode
|
||||
def __init__(self, config: VllmConfig) -> None:
|
||||
super().__init__(config)
|
||||
|
||||
# Enable symmetric memory for the TP process group
|
||||
enable_symm_mem_for_group(get_tp_group().device_group.group_name)
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
||||
pass_name="async_tp_pass"
|
||||
)
|
||||
GEMMReduceScatterPattern(self.model_dtype, self.device).register(self.patterns)
|
||||
|
||||
AllGatherGEMMPattern(self.model_dtype, self.device).register(self.patterns)
|
||||
|
||||
# These fusions are enabled only for bfloat16 models because
|
||||
# `scaled_mm` or `cutlass_scaled_mm` with per-token (row-wise) scaling
|
||||
# only supports bfloat16 as the output dtype.
|
||||
if self.model_dtype == torch.bfloat16:
|
||||
ScaledMMReduceScatterPattern(self.model_dtype, self.device).register(
|
||||
self.patterns
|
||||
)
|
||||
AllGatherScaledMMPattern(self.model_dtype, self.device).register(
|
||||
self.patterns
|
||||
)
|
||||
|
||||
CutlassScaledMMReduceScatterPattern(self.model_dtype, self.device).register(
|
||||
self.patterns
|
||||
)
|
||||
AllGatherCutlassScaledMMPattern(self.model_dtype, self.device).register(
|
||||
self.patterns
|
||||
)
|
||||
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
def is_applicable_for_range(self, compile_range: Range) -> bool:
|
||||
# This pass is applied on top of the sequence parallelism pass.
|
||||
# It inherits the same applicability condition as `SequenceParallelismPass`.
|
||||
# See `SequenceParallelismPass.is_applicable` for more details.
|
||||
if (
|
||||
not self.compilation_config.splitting_ops
|
||||
or self.compilation_config.use_inductor_graph_partition
|
||||
):
|
||||
return True
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
return bool(compile_range.is_single_size() and compile_range.end % tp_size == 0)
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph) -> None:
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", self.matched_count)
|
||||
472
vllm/compilation/passes/fusion/matcher_utils.py
Normal file
@@ -0,0 +1,472 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch._higher_order_ops import auto_functionalized
|
||||
from torch._ops import OpOverload
|
||||
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
QuantKey,
|
||||
_normalize_quant_group_shape,
|
||||
kFp8Dynamic64Sym,
|
||||
kFp8Dynamic128Sym,
|
||||
kFp8DynamicTensorSym,
|
||||
kFp8DynamicTokenSym,
|
||||
kFp8StaticTensorSym,
|
||||
kNvfp4Dynamic,
|
||||
)
|
||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
RMS_OP = torch.ops._C.rms_norm.default
|
||||
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
|
||||
ROTARY_OP = torch.ops._C.rotary_embedding.default
|
||||
FLASHINFER_ROTARY_OP = torch.ops.vllm.flashinfer_rotary_embedding.default
|
||||
|
||||
QUANT_OPS: dict[QuantKey, OpOverload] = {
|
||||
kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
|
||||
kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
|
||||
kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
|
||||
}
|
||||
|
||||
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
|
||||
QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default # noqa: E501
|
||||
|
||||
if current_platform.is_cuda():
|
||||
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
|
||||
QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
|
||||
|
||||
SILU_MUL_OP = torch.ops._C.silu_and_mul.default
|
||||
|
||||
|
||||
class MatcherCustomOp(ABC):
|
||||
def __init__(self, enabled: bool) -> None:
|
||||
config = get_current_vllm_config()
|
||||
self.model_dtype = config.model_config.dtype if config.model_config else None
|
||||
self.device = config.device_config.device if config.device_config else None
|
||||
|
||||
self.enabled = enabled
|
||||
self.forward = self.forward_custom if enabled else self.forward_native
|
||||
|
||||
@abstractmethod
|
||||
def forward_custom(self, *args: Any, **kwargs: Any) -> Any:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def forward_native(self, *args: Any, **kwargs: Any) -> Any:
|
||||
pass
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return self.forward(*args, **kwargs)
|
||||
|
||||
def empty(self, *args: Any, **kwargs: Any) -> torch.Tensor:
|
||||
return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kwargs)
|
||||
|
||||
def empty_int64(self, *args: Any, **kwargs: Any) -> torch.Tensor:
|
||||
return torch.empty(*args, dtype=torch.int64, device=self.device, **kwargs)
|
||||
|
||||
def empty_f32(self, *args: Any, **kwargs: Any) -> torch.Tensor:
|
||||
return torch.empty(*args, dtype=torch.float32, device=self.device, **kwargs)
|
||||
|
||||
def inputs(self) -> list[torch.Tensor]:
|
||||
"""Utility for inputs to the pattern"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MatcherRotaryEmbedding(MatcherCustomOp):
|
||||
def __init__(
|
||||
self,
|
||||
is_neox: bool,
|
||||
head_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
use_flashinfer: bool = False,
|
||||
match_rocm_aiter: bool | None = None,
|
||||
enabled: bool | None = None,
|
||||
) -> None:
|
||||
if enabled is None:
|
||||
enabled = RotaryEmbedding.enabled()
|
||||
if match_rocm_aiter is None:
|
||||
match_rocm_aiter = rocm_aiter_ops.is_triton_rotary_embed_enabled()
|
||||
|
||||
super().__init__(enabled)
|
||||
self.is_neox = is_neox
|
||||
self.head_size = head_size
|
||||
self.num_heads = num_heads
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.q_size = self.num_heads * self.head_size
|
||||
self.kv_size = self.num_kv_heads * self.head_size
|
||||
self.rotary_dim = head_size
|
||||
if use_flashinfer:
|
||||
self.rotary_op = FLASHINFER_ROTARY_OP
|
||||
elif match_rocm_aiter:
|
||||
self.rotary_op = rocm_aiter_ops.get_triton_rotary_embedding_op()
|
||||
else:
|
||||
self.rotary_op = ROTARY_OP
|
||||
|
||||
def inputs(self) -> list[torch.Tensor]:
|
||||
positions = self.empty_int64(5)
|
||||
query = self.empty(5, self.q_size)
|
||||
key = self.empty(5, self.kv_size)
|
||||
cos_sin_cache = self.empty(4096, self.rotary_dim)
|
||||
return [positions, query, key, cos_sin_cache]
|
||||
|
||||
def forward_custom(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor | None,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
result = auto_functionalized(
|
||||
self.rotary_op,
|
||||
positions=positions,
|
||||
query=query,
|
||||
key=key,
|
||||
head_size=self.head_size,
|
||||
cos_sin_cache=cos_sin_cache,
|
||||
is_neox=self.is_neox,
|
||||
)
|
||||
query_out = result[1]
|
||||
key_out = result[2] if len(result) > 2 else None
|
||||
return query_out, key_out
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor | None,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
result: tuple[torch.Tensor, torch.Tensor | None] = (
|
||||
RotaryEmbedding.forward_static(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
self.head_size,
|
||||
self.rotary_dim,
|
||||
cos_sin_cache,
|
||||
self.is_neox,
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class MatcherRMSNorm(MatcherCustomOp):
|
||||
def __init__(
|
||||
self,
|
||||
epsilon: float,
|
||||
enabled: bool | None = None,
|
||||
match_rocm_aiter: bool = False,
|
||||
) -> None:
|
||||
if enabled is None:
|
||||
enabled = RMSNorm.enabled()
|
||||
|
||||
super().__init__(enabled)
|
||||
self.epsilon = epsilon
|
||||
self._rmsnorm_op = RMS_OP
|
||||
self.match_rocm_aiter = match_rocm_aiter
|
||||
|
||||
if match_rocm_aiter:
|
||||
self._rmsnorm_op = rocm_aiter_ops.get_rmsnorm_op()
|
||||
|
||||
def inputs(self) -> list[torch.Tensor]:
|
||||
input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
|
||||
weight = self.empty(16)
|
||||
return [input, weight]
|
||||
|
||||
def forward_rocm_aiter(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return self._rmsnorm_op(
|
||||
x=input,
|
||||
weight=weight,
|
||||
variance_epsilon=self.epsilon,
|
||||
)
|
||||
|
||||
def forward_custom(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
if self.match_rocm_aiter:
|
||||
return self.forward_rocm_aiter(input, weight)
|
||||
|
||||
result = torch.empty_like(input)
|
||||
_, result = auto_functionalized(
|
||||
self._rmsnorm_op,
|
||||
result=result,
|
||||
input=input,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return RMSNorm.forward_static(
|
||||
input, self.epsilon, input.size(-1), self.model_dtype, weight
|
||||
)
|
||||
|
||||
|
||||
class MatcherFusedAddRMSNorm(MatcherCustomOp):
|
||||
def __init__(
|
||||
self,
|
||||
epsilon: float,
|
||||
enabled: bool | None = None,
|
||||
match_rocm_aiter: bool = False,
|
||||
) -> None:
|
||||
if enabled is None:
|
||||
enabled = RMSNorm.enabled()
|
||||
|
||||
super().__init__(enabled)
|
||||
self.epsilon = epsilon
|
||||
self.match_rocm_aiter = match_rocm_aiter
|
||||
|
||||
self._rmsnorm_op = RMS_ADD_OP
|
||||
|
||||
if match_rocm_aiter:
|
||||
self._rmsnorm_op = rocm_aiter_ops.get_rmsnorm_fused_add_op()
|
||||
|
||||
def inputs(self) -> list[torch.Tensor]:
|
||||
input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
|
||||
weight = self.empty(16)
|
||||
residual = self.empty(5, 16)
|
||||
return [input, weight, residual]
|
||||
|
||||
def forward_rocm_aiter(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return self._rmsnorm_op( # type: ignore[no-any-return]
|
||||
x=input, residual=residual, weight=weight, variance_epsilon=self.epsilon
|
||||
)
|
||||
|
||||
def forward_custom(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if self.match_rocm_aiter:
|
||||
return self.forward_rocm_aiter(input, weight, residual)
|
||||
|
||||
_, result, residual = auto_functionalized(
|
||||
self._rmsnorm_op,
|
||||
input=input,
|
||||
residual=residual,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
|
||||
return result, residual
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
result: tuple[torch.Tensor, torch.Tensor] = RMSNorm.forward_static(
|
||||
input, self.epsilon, input.size(-1), self.model_dtype, weight, residual
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class MatcherQuantFP8(MatcherCustomOp):
|
||||
def __init__(
|
||||
self,
|
||||
quant_key: QuantKey,
|
||||
enabled: bool | None = None,
|
||||
has_col_major_scales: bool = False,
|
||||
is_e8m0: bool = False,
|
||||
match_rocm_aiter: bool = False,
|
||||
is_tma_aligned: bool = False,
|
||||
) -> None:
|
||||
if enabled is None:
|
||||
enabled = QuantFP8.enabled()
|
||||
|
||||
super().__init__(enabled)
|
||||
self.quant_key = quant_key
|
||||
self.has_col_major_scales = has_col_major_scales
|
||||
self.is_e8m0 = is_e8m0
|
||||
self.match_rocm_aiter = match_rocm_aiter
|
||||
self.is_tma_aligned = is_tma_aligned
|
||||
|
||||
if match_rocm_aiter:
|
||||
assert not quant_key.scale.group_shape.is_per_tensor(), (
|
||||
"ROCm aiter fusion pass does not support per tensor quantization"
|
||||
)
|
||||
if quant_key.scale.group_shape.is_per_token():
|
||||
self.QUANT_OP = rocm_aiter_ops.get_per_token_quant_op()
|
||||
else:
|
||||
assert quant_key.scale.group_shape.col == 128, (
|
||||
"ROCm aiter fusion pass currently supports "
|
||||
"quantization operation with group_size 128"
|
||||
)
|
||||
if current_platform.is_fp8_fnuz():
|
||||
self.QUANT_OP = rocm_aiter_ops.get_group_quant_op()
|
||||
else:
|
||||
self.QUANT_OP = (
|
||||
torch.ops.vllm.triton_per_token_group_quant_fp8.default
|
||||
)
|
||||
|
||||
else:
|
||||
assert quant_key in QUANT_OPS, (
|
||||
f"unsupported quantization scheme {quant_key}"
|
||||
)
|
||||
self.QUANT_OP = QUANT_OPS[quant_key]
|
||||
|
||||
assert quant_key.dtype == current_platform.fp8_dtype(), (
|
||||
"Only QuantFP8 supported by"
|
||||
)
|
||||
assert quant_key.scale2 is None
|
||||
|
||||
self.quant_fp8 = QuantFP8(
|
||||
quant_key.scale.static,
|
||||
quant_key.scale.group_shape,
|
||||
column_major_scales=has_col_major_scales,
|
||||
use_ue8m0=is_e8m0,
|
||||
tma_aligned_scales=self.is_tma_aligned,
|
||||
compile_native=False,
|
||||
)
|
||||
|
||||
def forward_rocm_aiter(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
scale: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
quant_key_group_shape = self.quant_key.scale.group_shape
|
||||
if quant_key_group_shape == GroupShape.PER_TOKEN:
|
||||
return self.QUANT_OP( # type: ignore[no-any-return]
|
||||
x=input,
|
||||
quant_dtype=self.quant_key.dtype,
|
||||
scale=scale,
|
||||
)
|
||||
else:
|
||||
return self.QUANT_OP(input, quant_key_group_shape.col) # type: ignore[no-any-return]
|
||||
|
||||
def forward_custom(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
scale: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if self.match_rocm_aiter:
|
||||
return self.forward_rocm_aiter(input, scale)
|
||||
|
||||
result = torch.empty(
|
||||
input.shape, device=input.device, dtype=self.quant_key.dtype
|
||||
)
|
||||
|
||||
if self.quant_key.scale.group_shape.is_per_group():
|
||||
# for tma_aligned, the scale must be passed to forward_custom
|
||||
# tma_aligned fusion then matches by custom op arguments
|
||||
if not self.is_tma_aligned:
|
||||
assert scale is None
|
||||
scale = self.make_scale(input, transposed=self.has_col_major_scales)
|
||||
|
||||
finfo = torch.finfo(self.quant_key.dtype)
|
||||
fp8_min = finfo.min
|
||||
fp8_max = finfo.max
|
||||
|
||||
_, result, scale = auto_functionalized(
|
||||
self.QUANT_OP,
|
||||
input=input,
|
||||
output_q=result,
|
||||
output_s=scale,
|
||||
group_size=self.quant_key.scale.group_shape[1],
|
||||
eps=1e-10,
|
||||
fp8_min=fp8_min,
|
||||
fp8_max=fp8_max,
|
||||
scale_ue8m0=self.is_e8m0,
|
||||
dummy_is_scale_transposed=self.has_col_major_scales,
|
||||
dummy_is_tma_aligned=self.is_tma_aligned,
|
||||
)
|
||||
return result, scale
|
||||
|
||||
if self.quant_key.scale.static:
|
||||
assert scale is not None
|
||||
_, result = auto_functionalized(
|
||||
self.QUANT_OP, result=result, input=input, scale=scale
|
||||
)
|
||||
return result, scale
|
||||
else:
|
||||
assert scale is None
|
||||
scale = self.make_scale(input)
|
||||
_, result, scale = auto_functionalized(
|
||||
self.QUANT_OP, result=result, input=input, scale=scale, scale_ub=None
|
||||
)
|
||||
return result, scale
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
scale: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return self.quant_fp8(input, scale) # type: ignore[no-any-return]
|
||||
|
||||
def make_scale(self, input: torch.Tensor, transposed: bool = False) -> torch.Tensor:
|
||||
normalized_group_shape = _normalize_quant_group_shape(
|
||||
input, self.quant_key.scale.group_shape
|
||||
)
|
||||
scale_shape = (
|
||||
input.shape[0] // normalized_group_shape[0],
|
||||
input.shape[1] // normalized_group_shape[1],
|
||||
)
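# Example (illustrative): for a [5, 16] input with a per-token (1, 16) group
# shape this yields scale_shape == (5, 1); with transposed=True the buffer
# below is allocated as (1, 5) and permuted so its strides are column-major.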
|
||||
if transposed:
|
||||
scale_shape = tuple(reversed(scale_shape))
|
||||
return torch.empty(
|
||||
scale_shape, device=input.device, dtype=torch.float32
|
||||
).permute(-1, -2)
|
||||
|
||||
return torch.empty(scale_shape, device=input.device, dtype=torch.float32)
|
||||
|
||||
def inputs(self) -> list[torch.Tensor]:
|
||||
input = self.empty(5, 16)
|
||||
if self.quant_key.scale.static:
|
||||
return [input, self.empty_f32(1, 1)]
|
||||
|
||||
return [input]
|
||||
|
||||
|
||||
class MatcherSiluAndMul(MatcherCustomOp):
|
||||
def __init__(self, enabled: bool | None = None) -> None:
|
||||
if enabled is None:
|
||||
enabled = SiluAndMul.enabled()
|
||||
super().__init__(enabled)
|
||||
|
||||
def inputs(self) -> list[torch.Tensor]:
|
||||
input = self.empty(5, 4)
|
||||
return [input]
|
||||
|
||||
def forward_custom(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
d = x.shape[-1] // 2
|
||||
output_shape = x.shape[:-1] + (d,)
|
||||
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||
result = auto_functionalized(SILU_MUL_OP, result=out, input=x)
|
||||
return result[1]
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return SiluAndMul.forward_native(x)
|
||||
244
vllm/compilation/passes/fusion/qk_norm_rope_fusion.py
Normal file
@@ -0,0 +1,244 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import ParamSpec
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
from torch import fx
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .matcher_utils import MatcherRMSNorm, MatcherRotaryEmbedding
|
||||
from .rms_quant_fusion import empty_bf16, empty_fp32, empty_i64
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
FUSED_QK_ROPE_OP = torch.ops._C.fused_qk_norm_rope.default
|
||||
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
class QkNormRopePattern:
|
||||
"""
|
||||
Match the unfused sequence in attention blocks and replace with the fused op.
|
||||
|
||||
Unfused (conceptually):
|
||||
q, k, v = split(qkv, [qsz, kvsz, kvsz], -1)
|
||||
qh = reshape(q, [-1, num_heads, head_dim])
|
||||
kh = reshape(k, [-1, num_kv_heads, head_dim])
|
||||
qn = rms_norm(qh, q_weight, eps)
|
||||
kn = rms_norm(kh, k_weight, eps)
|
||||
qf = reshape(qn, [-1, num_heads * head_dim])
|
||||
kf = reshape(kn, [-1, num_kv_heads * head_dim])
|
||||
qf, kf = rotary_embedding(positions, qf, kf, head_dim, cos_sin_cache, is_neox)
|
||||
return qf, kf, v
|
||||
|
||||
Fused replacement:
|
||||
fused_qk_norm_rope(qkv, num_heads, num_kv_heads, num_kv_heads, head_dim,
|
||||
eps, q_weight, k_weight, cos_sin_cache, is_neox,
|
||||
positions.view(-1))
|
||||
return split(qkv, [qsz, kvsz, kvsz], -1)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
head_dim: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
eps: float,
|
||||
is_neox: bool,
|
||||
rope_flashinfer: bool = False,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.head_dim = head_dim
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.eps = eps
|
||||
self.rmsnorm_matcher = MatcherRMSNorm(eps)
|
||||
self.is_neox = is_neox
|
||||
self.rope_flashinfer = rope_flashinfer
|
||||
self.rope_matcher = MatcherRotaryEmbedding(
|
||||
is_neox=is_neox,
|
||||
head_size=self.head_dim,
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
use_flashinfer=self.rope_flashinfer,
|
||||
)
|
||||
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
# Sample inputs to help pattern tracing
|
||||
T = 5
|
||||
qkv = empty_bf16(T, self.q_size + 2 * self.kv_size)
|
||||
positions = empty_i64(T)
|
||||
q_weight = empty_bf16(1, self.head_dim)
|
||||
k_weight = empty_bf16(1, self.head_dim)
|
||||
if self.rope_flashinfer:
|
||||
cos_sin_cache = empty_fp32(4096, self.head_dim)
|
||||
else:
|
||||
cos_sin_cache = empty_bf16(4096, self.head_dim)
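# The dtype split above mirrors the ops being traced: the FlashInfer rotary
# path uses a float32 cos/sin cache, while the default custom op uses the
# model dtype (bf16 in these sample inputs).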
|
||||
return [
|
||||
qkv,
|
||||
positions,
|
||||
q_weight,
|
||||
k_weight,
|
||||
cos_sin_cache,
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def wrap_trace_fn(
|
||||
trace_fn: Callable[P, fx.GraphModule],
|
||||
*process_fx_fns: Callable[[fx.GraphModule], None],
|
||||
) -> Callable[P, fx.GraphModule]:
|
||||
def wrapped(*args: P.args, **kwargs: P.kwargs) -> fx.GraphModule:
|
||||
gm = trace_fn(*args, **kwargs)
|
||||
for process_fx in process_fx_fns:
|
||||
process_fx(gm)
|
||||
|
||||
return gm
|
||||
|
||||
return wrapped
|
||||
|
||||
@staticmethod
|
||||
def fx_view_to_reshape(gm: torch.fx.GraphModule) -> None:
|
||||
from torch._inductor.fx_passes.post_grad import view_to_reshape
|
||||
|
||||
view_to_reshape(gm)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
qkv: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
q_weight: torch.Tensor,
|
||||
k_weight: torch.Tensor,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
# split qkv -> q,k,v
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
|
||||
# Q path: view -> RMS -> view back to q.shape
|
||||
q_by_head = q.view(
|
||||
*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim
|
||||
)
|
||||
q_normed_by_head = self.rmsnorm_matcher(q_by_head, q_weight)
|
||||
q_flat = q_normed_by_head.view(q.shape)
|
||||
|
||||
# K path: view -> RMS -> view back to k.shape
|
||||
k_by_head = k.view(
|
||||
*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim
|
||||
)
|
||||
k_normed_by_head = self.rmsnorm_matcher(k_by_head, k_weight)
|
||||
k_flat = k_normed_by_head.view(k.shape)
|
||||
|
||||
# RoPE: apply to flattened q/k
|
||||
q_rope, k_rope = self.rope_matcher(positions, q_flat, k_flat, cos_sin_cache)
|
||||
return q_rope, k_rope, v
|
||||
|
||||
def replacement(
|
||||
qkv: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
q_weight: torch.Tensor,
|
||||
k_weight: torch.Tensor,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
# Run fused qk_norm_rope op
|
||||
result = auto_functionalized(
|
||||
FUSED_QK_ROPE_OP,
|
||||
qkv=qkv,
|
||||
num_heads_q=self.num_heads,
|
||||
num_heads_k=self.num_kv_heads,
|
||||
num_heads_v=self.num_kv_heads,
|
||||
head_dim=self.head_dim,
|
||||
eps=self.eps,
|
||||
q_weight=q_weight,
|
||||
k_weight=k_weight,
|
||||
cos_sin_cache=cos_sin_cache,
|
||||
is_neox=self.is_neox,
|
||||
position_ids=positions.view(-1),
|
||||
)
|
||||
result_qkv = result[1]
|
||||
|
||||
# Split back to q,k,v and return
|
||||
return result_qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) # type: ignore[no-any-return]
|
||||
|
||||
# NOTE: use fx_view_to_reshape to unify view/reshape to simplify
|
||||
# pattern and increase matching opportunities
|
||||
pm.register_replacement(
|
||||
pattern,
|
||||
replacement,
|
||||
self.get_inputs(),
|
||||
QkNormRopePattern.wrap_trace_fn(
|
||||
pm.fwd_only,
|
||||
QkNormRopePattern.fx_view_to_reshape,
|
||||
),
|
||||
pm_pass,
|
||||
)
|
||||
|
||||
|
||||
class QKNormRoPEFusionPass(VllmPatternMatcherPass):
|
||||
"""Fuse Q/K RMSNorm + RoPE into fused_qk_norm_rope when the custom op exists."""
|
||||
|
||||
@enable_fake_mode
|
||||
def __init__(self, config: VllmConfig) -> None:
|
||||
super().__init__(config)
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
||||
pass_name="qk_norm_rope_fusion_pass"
|
||||
)
|
||||
|
||||
dtype = config.model_config.dtype
|
||||
if dtype not in (torch.bfloat16, torch.float16):
|
||||
logger.warning_once(
|
||||
"QK Norm+RoPE fusion not enabled: unsupported dtype %s", dtype
|
||||
)
|
||||
return
|
||||
|
||||
# use one attn layer to get meta (such as head_dim) for QkNormRopePattern
|
||||
attn_layers: dict[str, Attention] = get_layers_from_vllm_config(
|
||||
config, Attention
|
||||
)
|
||||
if len(attn_layers) == 0:
|
||||
logger.warning_once(
|
||||
"QK Norm+RoPE fusion enabled, but no Attention layers were discovered."
|
||||
)
|
||||
return
|
||||
layer = next(iter(attn_layers.values()))
|
||||
|
||||
for epsilon in [1e-5, 1e-6]:
|
||||
for neox in [True, False]:
|
||||
if RotaryEmbedding.enabled():
|
||||
for rope_flashinfer in [False, True]:
|
||||
QkNormRopePattern(
|
||||
head_dim=layer.head_size,
|
||||
num_heads=layer.num_heads,
|
||||
num_kv_heads=layer.num_kv_heads,
|
||||
eps=epsilon,
|
||||
is_neox=neox,
|
||||
rope_flashinfer=rope_flashinfer,
|
||||
).register(self.patterns)
|
||||
else:
|
||||
QkNormRopePattern(
|
||||
head_dim=layer.head_size,
|
||||
num_heads=layer.num_heads,
|
||||
num_kv_heads=layer.num_kv_heads,
|
||||
eps=epsilon,
|
||||
is_neox=neox,
|
||||
).register(self.patterns)
|
||||
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph) -> None:
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Fused QK Norm+RoPE on %s sites", self.matched_count)
|
||||
|
||||
def uuid(self) -> str:
|
||||
return VllmInductorPass.hash_source(self, QkNormRopePattern)
|
||||
643
vllm/compilation/passes/fusion/rms_quant_fusion.py
Normal file
@@ -0,0 +1,643 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
from torch import fx
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
from torch._ops import OpOverload
|
||||
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
QuantKey,
|
||||
ScaleDesc,
|
||||
kFp8Dynamic64Sym,
|
||||
kFp8Dynamic128Sym,
|
||||
kFp8DynamicTensorSym,
|
||||
kFp8DynamicTokenSym,
|
||||
kFp8StaticTensorSym,
|
||||
kNvfp4Dynamic,
|
||||
kStaticTensorScale,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .matcher_utils import (
|
||||
MatcherFusedAddRMSNorm,
|
||||
MatcherQuantFP8,
|
||||
MatcherRMSNorm,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
FP4_DTYPE = torch.uint8
|
||||
|
||||
|
||||
def empty_bf16(*args: Any, **kwargs: Any) -> torch.Tensor:
|
||||
return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")
|
||||
|
||||
|
||||
def empty_fp32(*args: Any, **kwargs: Any) -> torch.Tensor:
|
||||
return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")
|
||||
|
||||
|
||||
def empty_i32(*args: Any, **kwargs: Any) -> torch.Tensor:
|
||||
return torch.empty(*args, **kwargs, dtype=torch.int32, device="cuda")
|
||||
|
||||
|
||||
def empty_i64(*args: Any, **kwargs: Any) -> torch.Tensor:
|
||||
return torch.empty(*args, **kwargs, dtype=torch.int64, device="cuda")
|
||||
|
||||
|
||||
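# The empty_* helpers above build the sample tensors the torch pattern matcher
# traces pattern/replacement functions with; only shapes and dtypes matter for
# tracing, the values themselves are never read.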
RMS_OP = torch.ops._C.rms_norm.default
|
||||
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
|
||||
|
||||
QUANT_OPS: dict[QuantKey, OpOverload] = {
|
||||
kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
|
||||
kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
|
||||
kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
|
||||
}
|
||||
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
|
||||
QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.default
|
||||
if current_platform.is_cuda():
|
||||
QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
|
||||
QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501
|
||||
|
||||
|
||||
class FusedRMSQuantKey(NamedTuple):
|
||||
"""
|
||||
Named tuple for identifying the type of RMSNorm + quant fusion.
|
||||
quant: type of quantization
|
||||
fused_add: does the op also perform the residual add
|
||||
"""
|
||||
|
||||
quant: QuantKey
|
||||
fused_add: bool
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f"FusedQuantKey({self.quant}, with"
|
||||
f"{'' if self.fused_add else 'out'} residual)"
|
||||
)
|
||||
|
||||
|
||||
FUSED_OPS: dict[FusedRMSQuantKey, OpOverload] = {
|
||||
FusedRMSQuantKey(
|
||||
kFp8StaticTensorSym, False
|
||||
): torch.ops._C.rms_norm_static_fp8_quant.default, # noqa: E501
|
||||
FusedRMSQuantKey(
|
||||
kFp8StaticTensorSym, True
|
||||
): torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, # noqa: E501
|
||||
FusedRMSQuantKey(
|
||||
kFp8DynamicTokenSym, False
|
||||
): torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501
|
||||
FusedRMSQuantKey(
|
||||
kFp8DynamicTokenSym, True
|
||||
): torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa: E501
|
||||
FusedRMSQuantKey(
|
||||
kFp8Dynamic128Sym, False
|
||||
): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501
|
||||
FusedRMSQuantKey(
|
||||
kFp8Dynamic128Sym, True
|
||||
): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501
|
||||
FusedRMSQuantKey(
|
||||
kFp8Dynamic64Sym, False
|
||||
): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501
|
||||
FusedRMSQuantKey(
|
||||
kFp8Dynamic64Sym, True
|
||||
): torch.ops._C.rms_norm_per_block_quant.default, # noqa: E501
|
||||
}
|
||||
|
||||
|
||||
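# Illustrative lookup (assumes the _C ops above are compiled in): a
# FusedRMSQuantKey selects the fused kernel for a given quant scheme, e.g.
#
#     key = FusedRMSQuantKey(quant=kFp8StaticTensorSym, fused_add=True)
#     fused_op = FUSED_OPS[key]
#     # -> torch.ops._C.fused_add_rms_norm_static_fp8_quant.default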
class RMSNormQuantPattern:
|
||||
def __init__(
|
||||
self,
|
||||
epsilon: float,
|
||||
key: FusedRMSQuantKey,
|
||||
has_col_major_scales: bool = False,
|
||||
is_e8m0: bool = False,
|
||||
is_tma_aligned: bool = False,
|
||||
) -> None:
|
||||
self.epsilon = epsilon
|
||||
self.quant_dtype = key.quant.dtype
|
||||
config = get_current_vllm_config()
|
||||
self.model_dtype = config.model_config.dtype if config.model_config else None
|
||||
|
||||
assert key in FUSED_OPS, f"unsupported fused rmsnorm+quant op for {key}"
|
||||
self.FUSED_OP = FUSED_OPS[key]
|
||||
|
||||
self.rmsnorm_matcher = (
|
||||
MatcherRMSNorm(epsilon)
|
||||
if not key.fused_add
|
||||
else MatcherFusedAddRMSNorm(epsilon)
|
||||
)
|
||||
self.quant_matcher = MatcherQuantFP8(
|
||||
key.quant,
|
||||
has_col_major_scales=has_col_major_scales,
|
||||
is_e8m0=is_e8m0,
|
||||
is_tma_aligned=is_tma_aligned,
|
||||
)
|
||||
|
||||
|
||||
class RMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
def __init__(
|
||||
self, epsilon: float, quant_dtype: torch.dtype, symmetric: bool = True
|
||||
) -> None:
|
||||
fused_key = FusedRMSQuantKey(
|
||||
fused_add=False,
|
||||
quant=QuantKey(
|
||||
dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
|
||||
),
|
||||
)
|
||||
super().__init__(epsilon, fused_key)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
# Cannot use methods, as the self argument affects tracing
|
||||
def pattern(
|
||||
input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
result_rms = self.rmsnorm_matcher(input, weight)
|
||||
return self.quant_matcher(result_rms, scale)[0]
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
# In case we're matching native rms-norm, conversions might be
|
||||
# optimized out. We convert here just to be safe.
|
||||
input = input.to(dtype=self.model_dtype)
|
||||
|
||||
result = torch.empty(
|
||||
input.shape, device=input.device, dtype=self.quant_dtype
|
||||
)
|
||||
at = auto_functionalized(
|
||||
self.FUSED_OP,
|
||||
result=result,
|
||||
input=input,
|
||||
weight=weight,
|
||||
scale=scale,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
|
||||
# result
|
||||
return at[1]
|
||||
|
||||
inputs = [
|
||||
# input, weight
|
||||
*self.rmsnorm_matcher.inputs(),
|
||||
self.quant_matcher.inputs()[1], # scale
|
||||
]
|
||||
pattern(*inputs)
|
||||
|
||||
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
|
||||
|
||||
|
||||
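# Illustrative reference (assumed dtype torch.float8_e4m3fn, not vLLM's kernel):
# the unfused computation matched above is RMSNorm followed by a static-scale
# FP8 quant, roughly:
def rmsnorm_then_static_fp8_ref(x, weight, scale, eps=1e-6, fp8=torch.float8_e4m3fn):
    # RMSNorm in fp32 for stability, then clamp/scale to FP8 with a static scale
    normed = (x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)).to(x.dtype) * weight
    info = torch.finfo(fp8)
    return torch.clamp(normed / scale, info.min, info.max).to(fp8)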
class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
|
||||
def __init__(
|
||||
self, epsilon: float, quant_dtype: torch.dtype, symmetric: bool = True
|
||||
) -> None:
|
||||
key = FusedRMSQuantKey(
|
||||
fused_add=True,
|
||||
quant=QuantKey(
|
||||
dtype=quant_dtype, scale=kStaticTensorScale, symmetric=symmetric
|
||||
),
|
||||
)
|
||||
super().__init__(epsilon, key)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
|
||||
result, _ = self.quant_matcher(result_rms, scale)
|
||||
|
||||
return result, residual
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# In case we're matching native rms-norm, conversions might be
|
||||
# optimized out. We convert here just to be safe.
|
||||
input = input.to(dtype=self.model_dtype)
|
||||
|
||||
result = torch.empty_like(input, dtype=self.quant_dtype)
|
||||
at = auto_functionalized(
|
||||
self.FUSED_OP,
|
||||
result=result,
|
||||
input=input,
|
||||
residual=residual,
|
||||
weight=weight,
|
||||
scale=scale,
|
||||
epsilon=self.epsilon,
|
||||
)
|
||||
|
||||
# result, residual
|
||||
return at[1], at[2]
|
||||
|
||||
inputs = [
|
||||
# input, weight, residual
|
||||
*self.rmsnorm_matcher.inputs(),
|
||||
self.quant_matcher.inputs()[1], # scale
|
||||
]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern,
|
||||
replacement,
|
||||
inputs,
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
)
|
||||
|
||||
|
||||
class FusedAddRMSNormGroupQuantPattern(RMSNormQuantPattern):
|
||||
def __init__(
|
||||
self,
|
||||
epsilon: float,
|
||||
quant_dtype: torch.dtype,
|
||||
group_shape: GroupShape,
|
||||
symmetric: bool = True,
|
||||
is_e8m0: bool = False,
|
||||
has_col_major_scales: bool = True,
|
||||
is_tma_aligned: bool = True,
|
||||
) -> None:
|
||||
scale = ScaleDesc(torch.float32, False, group_shape)
|
||||
key = FusedRMSQuantKey(
|
||||
fused_add=True,
|
||||
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
|
||||
)
|
||||
self.group_shape = group_shape
|
||||
self.is_e8m0 = is_e8m0
|
||||
self.has_col_major_scales = has_col_major_scales
|
||||
self.is_tma_aligned = is_tma_aligned
|
||||
super().__init__(
|
||||
epsilon,
|
||||
key,
|
||||
has_col_major_scales=has_col_major_scales,
|
||||
is_e8m0=is_e8m0,
|
||||
is_tma_aligned=is_tma_aligned,
|
||||
)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
|
||||
result = torch.empty(
|
||||
result_rms.shape,
|
||||
device=result_rms.device,
|
||||
dtype=self.quant_matcher.quant_key.dtype,
|
||||
)
|
||||
assert scale is not None
|
||||
finfo = torch.finfo(self.quant_matcher.quant_key.dtype)
|
||||
fp8_min = finfo.min
|
||||
fp8_max = finfo.max
|
||||
|
||||
_, result, scale = auto_functionalized(
|
||||
self.quant_matcher.QUANT_OP,
|
||||
input=result_rms,
|
||||
output_q=result,
|
||||
output_s=scale,
|
||||
group_size=self.quant_matcher.quant_key.scale.group_shape[1],
|
||||
eps=1e-10,
|
||||
fp8_min=fp8_min,
|
||||
fp8_max=fp8_max,
|
||||
scale_ue8m0=self.quant_matcher.is_e8m0,
|
||||
dummy_is_scale_transposed=self.has_col_major_scales,
|
||||
dummy_is_tma_aligned=self.is_tma_aligned,
|
||||
)
|
||||
|
||||
return result, residual, scale
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
# In case we're matching native rms-norm, conversions might be
|
||||
# optimized out. We convert here just to be safe.
|
||||
input = input.to(dtype=self.model_dtype)
|
||||
|
||||
result = torch.empty_like(input, dtype=self.quant_dtype)
|
||||
|
||||
at = auto_functionalized(
|
||||
self.FUSED_OP,
|
||||
result=result,
|
||||
input=input,
|
||||
weight=weight,
|
||||
scale=scale,
|
||||
epsilon=self.epsilon,
|
||||
scale_ub=None,
|
||||
residual=residual,
|
||||
group_size=self.group_shape[1],
|
||||
is_scale_transposed=self.has_col_major_scales,
|
||||
)
|
||||
|
||||
# result, residual, scale
|
||||
return at[1], at[3], at[2]
|
||||
|
||||
scale = self.quant_matcher.empty_f32(1, 1)
|
||||
|
||||
pm.register_replacement(
|
||||
pattern,
|
||||
replacement,
|
||||
self.rmsnorm_matcher.inputs() + [scale],
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
)
|
||||
|
||||
|
||||
class RMSNormGroupQuantPattern(RMSNormQuantPattern):
|
||||
def __init__(
|
||||
self,
|
||||
epsilon: float,
|
||||
quant_dtype: torch.dtype,
|
||||
group_shape: GroupShape,
|
||||
symmetric: bool = True,
|
||||
is_e8m0: bool = False,
|
||||
has_col_major_scales: bool = True,
|
||||
is_tma_aligned: bool = True,
|
||||
) -> None:
|
||||
scale = ScaleDesc(torch.float32, False, group_shape)
|
||||
key = FusedRMSQuantKey(
|
||||
fused_add=False,
|
||||
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
|
||||
)
|
||||
self.group_shape = group_shape
|
||||
self.has_col_major_scales = has_col_major_scales
|
||||
self.is_tma_aligned = is_tma_aligned
|
||||
super().__init__(
|
||||
epsilon,
|
||||
key,
|
||||
has_col_major_scales=self.has_col_major_scales,
|
||||
is_e8m0=is_e8m0,
|
||||
is_tma_aligned=is_tma_aligned,
|
||||
)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
result_rms = self.rmsnorm_matcher(input, weight)
|
||||
result = torch.empty(
|
||||
result_rms.shape,
|
||||
device=result_rms.device,
|
||||
dtype=self.quant_matcher.quant_key.dtype,
|
||||
)
|
||||
assert scale is not None
|
||||
finfo = torch.finfo(self.quant_matcher.quant_key.dtype)
|
||||
fp8_min = finfo.min
|
||||
fp8_max = finfo.max
|
||||
|
||||
_, result, scale = auto_functionalized(
|
||||
self.quant_matcher.QUANT_OP,
|
||||
input=result_rms,
|
||||
output_q=result,
|
||||
output_s=scale,
|
||||
group_size=self.quant_matcher.quant_key.scale.group_shape[1],
|
||||
eps=1e-10,
|
||||
fp8_min=fp8_min,
|
||||
fp8_max=fp8_max,
|
||||
scale_ue8m0=self.quant_matcher.is_e8m0,
|
||||
dummy_is_scale_transposed=self.has_col_major_scales,
|
||||
dummy_is_tma_aligned=self.is_tma_aligned,
|
||||
)
|
||||
|
||||
return result, scale
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# In case we're matching native rms-norm, conversions might be
|
||||
# optimized out. We convert here just to be safe.
|
||||
input = input.to(dtype=self.model_dtype)
|
||||
|
||||
result = torch.empty_like(input, dtype=self.quant_dtype)
|
||||
at = auto_functionalized(
|
||||
self.FUSED_OP,
|
||||
result=result,
|
||||
input=input,
|
||||
weight=weight,
|
||||
scale=scale,
|
||||
epsilon=self.epsilon,
|
||||
scale_ub=None,
|
||||
residual=None,
|
||||
group_size=self.group_shape[1],
|
||||
is_scale_transposed=self.has_col_major_scales,
|
||||
)
|
||||
|
||||
# result, scale
|
||||
return at[1], at[2]
|
||||
|
||||
scale = self.quant_matcher.empty_f32(1, 1)
|
||||
|
||||
pm.register_replacement(
|
||||
pattern,
|
||||
replacement,
|
||||
self.rmsnorm_matcher.inputs() + [scale],
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
)
|
||||
|
||||
|
||||
class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
def __init__(
|
||||
self,
|
||||
epsilon: float,
|
||||
quant_dtype: torch.dtype,
|
||||
group_shape: GroupShape = GroupShape.PER_TOKEN,
|
||||
symmetric: bool = True,
|
||||
) -> None:
|
||||
scale = ScaleDesc(torch.float32, False, group_shape)
|
||||
key = FusedRMSQuantKey(
|
||||
fused_add=False,
|
||||
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
|
||||
)
|
||||
super().__init__(epsilon, key)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor, weight: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
result_rms = self.rmsnorm_matcher(input, weight)
|
||||
# result, scale
|
||||
return self.quant_matcher(result_rms) # type: ignore[no-any-return]
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor, weight: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# In case we're matching native rms-norm, conversions might be
|
||||
# optimized out. We convert here just to be safe.
|
||||
input = input.to(dtype=self.model_dtype)
|
||||
|
||||
result = torch.empty_like(input, dtype=self.quant_dtype)
|
||||
scale = self.quant_matcher.make_scale(input)
|
||||
at = auto_functionalized(
|
||||
self.FUSED_OP,
|
||||
result=result,
|
||||
input=input,
|
||||
weight=weight,
|
||||
scale=scale,
|
||||
epsilon=self.epsilon,
|
||||
scale_ub=None,
|
||||
residual=None,
|
||||
)
|
||||
|
||||
# result, scale
|
||||
return at[1], at[2]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern,
|
||||
replacement,
|
||||
self.rmsnorm_matcher.inputs(),
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
)
|
||||
|
||||
|
||||
class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
|
||||
def __init__(
|
||||
self,
|
||||
epsilon: float,
|
||||
quant_dtype: torch.dtype,
|
||||
group_shape: GroupShape = GroupShape.PER_TOKEN,
|
||||
symmetric: bool = True,
|
||||
) -> None:
|
||||
scale = ScaleDesc(torch.float32, False, group_shape)
|
||||
key = FusedRMSQuantKey(
|
||||
fused_add=True,
|
||||
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
|
||||
)
|
||||
super().__init__(epsilon, key)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
result_rms, residual = self.rmsnorm_matcher(input, weight, residual)
|
||||
result, scale = self.quant_matcher(result_rms)
|
||||
|
||||
return result, residual, scale
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
# In case we're matching native rms-norm, conversions might be
|
||||
# optimized out. We convert here just to be safe.
|
||||
input = input.to(dtype=self.model_dtype)
|
||||
|
||||
result = torch.empty_like(input, dtype=self.quant_dtype)
|
||||
scale = self.quant_matcher.make_scale(input)
|
||||
at = auto_functionalized(
|
||||
self.FUSED_OP,
|
||||
result=result,
|
||||
input=input,
|
||||
weight=weight,
|
||||
scale=scale,
|
||||
epsilon=self.epsilon,
|
||||
scale_ub=None,
|
||||
residual=residual,
|
||||
)
|
||||
|
||||
# result, residual, scale
|
||||
return at[1], at[3], at[2]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern,
|
||||
replacement,
|
||||
self.rmsnorm_matcher.inputs(),
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
)
|
||||
|
||||
|
||||
class RMSNormQuantFusionPass(VllmPatternMatcherPass):
|
||||
"""
|
||||
This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op.
|
||||
It also supports fused_add_rms_norm.
|
||||
"""
|
||||
|
||||
@enable_fake_mode
|
||||
def __init__(self, config: VllmConfig) -> None:
|
||||
super().__init__(config)
|
||||
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
||||
pass_name="rmsnorm_quant_fusion_pass"
|
||||
)
|
||||
|
||||
# Make sure fused-add patterns are registered before the simple rms_norm ones,
# as the latter is a subset of the former in torch ops
|
||||
for epsilon in [1e-5, 1e-6]:
|
||||
# Fuse fused_add_rms_norm + static fp8 quant
|
||||
FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
|
||||
self.patterns
|
||||
)
|
||||
|
||||
# Fuse rms_norm + static fp8 quant
|
||||
RMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
|
||||
|
||||
# Fuse fused_add_rms_norm + dynamic per-token fp8 quant
|
||||
FusedAddRMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(
|
||||
self.patterns
|
||||
)
|
||||
|
||||
# Fuse rms_norm + dynamic per-token fp8 quant
|
||||
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE).register(self.patterns)
|
||||
|
||||
# Only register group quant patterns on CUDA where the C++ op exists
|
||||
if current_platform.is_cuda():
|
||||
for group_shape in [GroupShape(1, 128), GroupShape(1, 64)]:
|
||||
for has_col_major_scales in [True, False]:
|
||||
for is_e8m0 in [True, False]:
|
||||
for is_tma_aligned in [False, True]:
|
||||
# Fuse fused_add_rms_norm + fp8 group quant
|
||||
FusedAddRMSNormGroupQuantPattern(
|
||||
epsilon,
|
||||
FP8_DTYPE,
|
||||
group_shape=group_shape,
|
||||
is_e8m0=is_e8m0,
|
||||
has_col_major_scales=has_col_major_scales,
|
||||
is_tma_aligned=is_tma_aligned,
|
||||
).register(self.patterns)
|
||||
|
||||
# Fuse rms_norm + fp8 group quant
|
||||
RMSNormGroupQuantPattern(
|
||||
epsilon,
|
||||
FP8_DTYPE,
|
||||
group_shape=group_shape,
|
||||
is_e8m0=is_e8m0,
|
||||
has_col_major_scales=has_col_major_scales,
|
||||
is_tma_aligned=is_tma_aligned,
|
||||
).register(self.patterns)
|
||||
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph) -> None:
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", self.matched_count)
|
||||
|
||||
def uuid(self) -> str:
|
||||
return self.hash_source(
|
||||
self,
|
||||
RMSNormGroupQuantPattern,
|
||||
RMSNormQuantPattern,
|
||||
RMSNormStaticQuantPattern,
|
||||
RMSNormDynamicQuantPattern,
|
||||
FusedAddRMSNormStaticQuantPattern,
|
||||
FusedAddRMSNormDynamicQuantPattern,
|
||||
FusedAddRMSNormGroupQuantPattern,
|
||||
)
|
||||
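# Usage sketch (roughly how the pass manager drives a pass like the one above;
# names are the ones defined in this file, the call sequence is assumed):
#
#     fusion_pass = RMSNormQuantFusionPass(vllm_config)  # registers all patterns
#     fusion_pass(graph)                                  # graph: torch.fx.Graph
#     fusion_pass.matched_count                           # number of fused sites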
504
vllm/compilation/passes/fusion/rocm_aiter_fusion.py
Normal file
@@ -0,0 +1,504 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
from torch import fx
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
from torch._ops import OpOverload
|
||||
|
||||
import vllm.model_executor.layers.quantization.utils.fp8_utils # noqa: F401
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
QuantKey,
|
||||
ScaleDesc,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .act_quant_fusion import ActivationQuantPattern
|
||||
from .matcher_utils import (
|
||||
MatcherFusedAddRMSNorm,
|
||||
MatcherQuantFP8,
|
||||
MatcherRMSNorm,
|
||||
MatcherSiluAndMul,
|
||||
)
|
||||
from .rms_quant_fusion import (
|
||||
FusedRMSQuantKey,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
|
||||
class AiterRMSNormQuantPattern:
|
||||
def __init__(
|
||||
self, epsilon: float, key: FusedRMSQuantKey, match_aiter_quant: bool = True
|
||||
):
|
||||
self.epsilon = epsilon
|
||||
self.quant_dtype = key.quant.dtype
|
||||
|
||||
self.rmsnorm_matcher = (
|
||||
MatcherRMSNorm(epsilon, match_rocm_aiter=True)
|
||||
if not key.fused_add
|
||||
else MatcherFusedAddRMSNorm(epsilon, match_rocm_aiter=True)
|
||||
)
|
||||
self.quant_matcher = MatcherQuantFP8(
|
||||
key.quant,
|
||||
match_rocm_aiter=match_aiter_quant,
|
||||
)
|
||||
|
||||
|
||||
class AiterRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
|
||||
"""AITER RMSNorm + Dynamic Quantization pattern."""
|
||||
|
||||
FUSED_OP = rocm_aiter_ops.get_rmsnorm_fused_dynamic_quant_op()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
epsilon: float,
|
||||
quant_dtype: torch.dtype,
|
||||
match_aiter_quant: bool = True,
|
||||
group_shape: GroupShape = GroupShape.PER_TOKEN,
|
||||
symmetric: bool = True,
|
||||
) -> None:
|
||||
scale = ScaleDesc(torch.float32, False, group_shape)
|
||||
key = FusedRMSQuantKey(
|
||||
fused_add=False,
|
||||
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
|
||||
)
|
||||
|
||||
super().__init__(epsilon, key, match_aiter_quant)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
result_rms = self.rmsnorm_matcher(input, weight)
|
||||
result, scale = self.quant_matcher(result_rms)
|
||||
return result, scale
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
result = self.FUSED_OP(
|
||||
x=input,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon,
|
||||
quant_dtype=self.quant_dtype,
|
||||
)
|
||||
|
||||
return result[0], result[1]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern,
|
||||
replacement,
|
||||
self.rmsnorm_matcher.inputs(),
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
)
|
||||
|
||||
|
||||
class AiterFusedAddRMSNormDynamicQuantPattern(AiterRMSNormQuantPattern):
|
||||
"""AITER RMSNorm Fused Add + Dynamic Quantization pattern."""
|
||||
|
||||
FUSED_OP = rocm_aiter_ops.get_rmsnorm_fused_add_dynamic_quant_op()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
epsilon: float,
|
||||
quant_dtype: torch.dtype,
|
||||
match_aiter_quant: bool = True,
|
||||
group_shape: GroupShape = GroupShape.PER_TOKEN,
|
||||
symmetric: bool = True,
|
||||
) -> None:
|
||||
scale = ScaleDesc(torch.float32, False, group_shape)
|
||||
key = FusedRMSQuantKey(
|
||||
fused_add=True,
|
||||
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
|
||||
)
|
||||
|
||||
super().__init__(epsilon, key, match_aiter_quant)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual)
|
||||
result, scale = self.quant_matcher(result_rms)
|
||||
|
||||
return result, residual_out, scale
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
result = self.FUSED_OP(
|
||||
x=input,
|
||||
residual=residual,
|
||||
weight=weight,
|
||||
epsilon=self.epsilon,
|
||||
quant_dtype=self.quant_dtype,
|
||||
)
|
||||
|
||||
return result[0], result[1], result[2]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern,
|
||||
replacement,
|
||||
self.rmsnorm_matcher.inputs(),
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
)
|
||||
|
||||
|
||||
class AiterRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
|
||||
"""
|
||||
This pattern fuses aiter rms_norm & group fp8 quant custom
|
||||
ops into an aiter rms_norm_group_fp8_quant op.
|
||||
"""
|
||||
|
||||
FUSED_OP = rocm_aiter_ops.get_rmsnorm_group_fused_quant_op()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
epsilon: float,
|
||||
quant_dtype: torch.dtype,
|
||||
group_shape: GroupShape,
|
||||
match_aiter_quant: bool = True,
|
||||
symmetric: bool = True,
|
||||
) -> None:
|
||||
scale = ScaleDesc(torch.float32, False, group_shape)
|
||||
key = FusedRMSQuantKey(
|
||||
fused_add=False,
|
||||
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
|
||||
)
|
||||
|
||||
super().__init__(epsilon, key, match_aiter_quant)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
result_rms = self.rmsnorm_matcher(input, weight)
|
||||
result, scale = self.quant_matcher(result_rms)
|
||||
return result, scale
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
at = self.FUSED_OP(
|
||||
x=input,
|
||||
weight=weight,
|
||||
variance_epsilon=self.epsilon,
|
||||
group_size=128,
|
||||
)
|
||||
|
||||
return at[0], at[1]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.rmsnorm_matcher.inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class AiterFusedAddRMSFp8GroupQuantPattern(AiterRMSNormQuantPattern):
|
||||
"""
|
||||
This pattern fuses aiter rms_norm_with_add & group fp8 quant custom ops
|
||||
into an aiter rms_norm_with_add_group_fp8_quant op.
|
||||
"""
|
||||
|
||||
FUSED_OP = rocm_aiter_ops.get_rmsnorm_group_add_fused_quant_op()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
epsilon: float,
|
||||
quant_dtype: torch.dtype,
|
||||
group_shape: GroupShape,
|
||||
match_aiter_quant: bool = True,
|
||||
symmetric: bool = True,
|
||||
) -> None:
|
||||
scale = ScaleDesc(torch.float32, False, group_shape)
|
||||
key = FusedRMSQuantKey(
|
||||
fused_add=True,
|
||||
quant=QuantKey(dtype=quant_dtype, scale=scale, symmetric=symmetric),
|
||||
)
|
||||
|
||||
super().__init__(epsilon, key, match_aiter_quant)
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual)
|
||||
result, scale = self.quant_matcher(result_rms)
|
||||
|
||||
return result, residual_out, scale
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
at = self.FUSED_OP(
|
||||
x=input,
|
||||
residual=residual,
|
||||
weight=weight,
|
||||
variance_epsilon=self.epsilon,
|
||||
group_size=128,
|
||||
)
|
||||
|
||||
# result, scale, residual
|
||||
return at[0], at[1], at[2]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.rmsnorm_matcher.inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class RocmAiterRMSNormQuantFusionPass(VllmPatternMatcherPass):
|
||||
"""
|
||||
This pass fuses aiter rms_norm & vllm/aiter quant custom ops
|
||||
into a fused rms_norm_quant op.
|
||||
It also supports fused_add_rms_norm.
|
||||
"""
|
||||
|
||||
@enable_fake_mode
|
||||
def __init__(self, config: VllmConfig) -> None:
|
||||
super().__init__(config)
|
||||
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
||||
pass_name="rocm_aiter_rms_norm_quant_fusion_pass"
|
||||
)
|
||||
|
||||
# Make sure fused-add patterns are registered before the simple rms_norm ones,
# as the latter is a subset of the former in torch ops
|
||||
for epsilon in [1e-5, 1e-6]:
|
||||
# Fuse aiter rms_norm + aiter dynamic group fp8 quant
|
||||
AiterRMSFp8GroupQuantPattern(
|
||||
epsilon, FP8_DTYPE, GroupShape(1, 128)
|
||||
).register(self.patterns)
|
||||
|
||||
# Fuse aiter fused_add_rms_norm + aiter dynamic group fp8 quant
|
||||
AiterFusedAddRMSFp8GroupQuantPattern(
|
||||
epsilon, FP8_DTYPE, GroupShape(1, 128)
|
||||
).register(self.patterns)
|
||||
|
||||
for match_aiter_quant in [True, False]:
|
||||
# Fuse aiter rms_norm + (aiter / vllm built-in)
|
||||
# dynamic per-token fp8 quant
|
||||
AiterRMSNormDynamicQuantPattern(
|
||||
epsilon, FP8_DTYPE, match_aiter_quant=match_aiter_quant
|
||||
).register(self.patterns)
|
||||
|
||||
# Fuse aiter fused_add_rms_norm + (aiter / vllm built-in)
|
||||
# dynamic per-token fp8 quant
|
||||
AiterFusedAddRMSNormDynamicQuantPattern(
|
||||
epsilon, FP8_DTYPE, match_aiter_quant=match_aiter_quant
|
||||
).register(self.patterns)
|
||||
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph) -> None:
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", self.matched_count)
|
||||
|
||||
def uuid(self) -> str:
|
||||
fusion_patterns = [
|
||||
AiterRMSNormDynamicQuantPattern,
|
||||
AiterFusedAddRMSNormDynamicQuantPattern,
|
||||
AiterRMSFp8GroupQuantPattern,
|
||||
AiterFusedAddRMSFp8GroupQuantPattern,
|
||||
]
|
||||
return self.hash_source(self, *fusion_patterns)
|
||||
|
||||
|
||||
class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
|
||||
"""
|
||||
This pattern fuses aiter silu_and_mul & group fp8 quant custom
|
||||
ops into an aiter silu_and_mul_group_fp8_quant op.
|
||||
"""
|
||||
|
||||
FUSED_SILU_MUL_QUANT_OP = rocm_aiter_ops.get_act_mul_fused_fp8_group_quant_op()
|
||||
|
||||
def __init__(self, quant_op: OpOverload) -> None:
|
||||
self.silu_and_mul_matcher = MatcherSiluAndMul()
|
||||
self.quant_op = quant_op
|
||||
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
return [
|
||||
self.silu_and_mul_matcher.inputs()[0],
|
||||
]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
at1 = self.silu_and_mul_matcher(input)
|
||||
at2 = self.quant_op(at1, 128)
|
||||
return at2[0], at2[1]
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
at = self.FUSED_SILU_MUL_QUANT_OP(x=input, group_size=128)
|
||||
return at[0], at[1]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
|
||||
"""
|
||||
This pass fuses a pre-defined set of custom ops into fused ops.
|
||||
It uses the torch pattern matcher to find the patterns and replace them.
|
||||
|
||||
Because patterns can only be registered once, the pass is a singleton.
|
||||
This will be addressed in a future version of PyTorch:
|
||||
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
|
||||
"""
|
||||
|
||||
AITER_GROUP_FP8_QUANT_OP = rocm_aiter_ops.get_group_quant_op()
|
||||
TRITON_GROUP_FP8_QUANT_OP = torch.ops.vllm.triton_per_token_group_quant_fp8.default
|
||||
|
||||
QUANT_OPS = [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]
|
||||
|
||||
@enable_fake_mode
|
||||
def __init__(self, config: VllmConfig) -> None:
|
||||
super().__init__(config)
|
||||
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
||||
pass_name="rocm_aiter_silu_mul_fp8_group_quant_fusion_pass"
|
||||
)
|
||||
|
||||
for quant_op in self.QUANT_OPS:
|
||||
AiterSiluMulFp8GroupQuantPattern(quant_op).register(self.patterns)
|
||||
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: torch.fx.Graph) -> None:
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", self.matched_count)
|
||||
|
||||
def uuid(self) -> str:
|
||||
fusion_patterns = [
|
||||
ActivationQuantPattern,
|
||||
AiterSiluMulFp8GroupQuantPattern,
|
||||
]
|
||||
return VllmInductorPass.hash_source(self, *fusion_patterns)
|
||||
|
||||
|
||||
class AddAiterRMSNormPadPattern:
|
||||
"""
|
||||
This pattern replaces an aiter_rmsnorm_with_add & a pad op
|
||||
with a custom triton_add_rmsnorm_pad op from AITER.
|
||||
"""
|
||||
|
||||
AITER_TRITON_ADD_RMSNORM_PAD_OP = rocm_aiter_ops.get_triton_add_rmsnorm_pad_op()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
epsilon: float,
|
||||
hidden_size: int,
|
||||
x_pad_to_multiple: int,
|
||||
):
|
||||
self.epsilon = epsilon
|
||||
self.hidden_size = hidden_size
|
||||
self.x_pad_to_multiple = x_pad_to_multiple
|
||||
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon, match_rocm_aiter=True)
|
||||
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
input, weight, residual = self.rmsnorm_matcher.inputs()
|
||||
router_weight = torch.empty([8, 16], dtype=weight.dtype, device=weight.device)
|
||||
router_bias = torch.empty([8], dtype=weight.dtype, device=weight.device)
|
||||
return [input, weight, residual, router_weight, router_bias]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
router_weight: torch.Tensor,
|
||||
router_bias: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
pad_size = self.x_pad_to_multiple - (
|
||||
self.hidden_size % self.x_pad_to_multiple
|
||||
)
|
||||
result_rms, residual_out = self.rmsnorm_matcher(input, weight, residual)
|
||||
router_logits = torch.ops.vllm.rocm_unquantized_gemm(
|
||||
result_rms, router_weight, router_bias
|
||||
)
|
||||
result = torch.nn.functional.pad(
|
||||
result_rms, (0, pad_size), mode="constant", value=0.0
|
||||
)
|
||||
return result, residual_out, router_logits
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
router_weight: torch.Tensor,
|
||||
router_bias: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
at = self.AITER_TRITON_ADD_RMSNORM_PAD_OP(
|
||||
x=input,
|
||||
weight=weight,
|
||||
variance_epsilon=self.epsilon,
|
||||
residual=residual,
|
||||
x_pad_to_multiple=self.x_pad_to_multiple,
|
||||
)
|
||||
result_padded = at[0]
|
||||
router_logits = torch.ops.vllm.rocm_unquantized_gemm(
|
||||
result_padded[:, : self.hidden_size], router_weight, router_bias
|
||||
)
|
||||
residual_out = at[1]
|
||||
return result_padded, residual_out, router_logits
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class RocmAiterTritonAddRMSNormPadFusionPass(VllmPatternMatcherPass):
|
||||
"""
|
||||
This pass replaces an AITER CK RMSNorm + residual add and a pad op
|
||||
with a triton_add_rmsnorm_pad op from AITER.
|
||||
"""
|
||||
|
||||
def __init__(self, config: VllmConfig):
|
||||
super().__init__(config)
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
||||
pass_name="rocm_aiter_triton_add_rmsnorm_pad_fusion_pass"
|
||||
)
|
||||
|
||||
# gpt-oss has hidden size 2880
|
||||
# padded to a multiple of 128 on gfx942 and 256 on gfx950 respectively
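# e.g. 2880 % 128 == 64 -> pad 64 to 2944; 2880 % 256 == 64 -> pad 192 to 3072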
|
||||
hidden_size = 2880
|
||||
for epsilon in [1e-5, 1e-6]:
|
||||
for x_pad_to_multiple in [128, 256]:
|
||||
AddAiterRMSNormPadPattern(
|
||||
epsilon, hidden_size, x_pad_to_multiple
|
||||
).register(self.patterns)
|
||||
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: torch.fx.Graph) -> None:
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", self.matched_count)
|
||||
|
||||
def uuid(self) -> str:
|
||||
return VllmInductorPass.hash_source(self, AddAiterRMSNormPadPattern)
|
||||
230
vllm/compilation/passes/fusion/rope_kvcache_fusion.py
Normal file
@@ -0,0 +1,230 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
from torch import fx
|
||||
from torch._higher_order_ops import auto_functionalized
|
||||
from torch._inductor.fx_passes.post_grad import view_to_reshape
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.config.utils import Range
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention.attention import (
|
||||
Attention,
|
||||
get_attention_context,
|
||||
)
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .matcher_utils import (
|
||||
MatcherRotaryEmbedding,
|
||||
)
|
||||
from .rms_quant_fusion import (
|
||||
empty_bf16,
|
||||
empty_i64,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def fused_rope_and_unified_kv_cache_update_impl(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
is_neox: bool,
|
||||
layer_name: str = "",
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This impl fetches the KV cache and slot mapping from the forward context,
|
||||
then calls the layer impl's `AttentionImpl.do_rope_and_kv_cache_update` method.
|
||||
It also returns a dummy tensor, similar to `Attention.unified_kv_cache_update`,
|
||||
that is passed to unified_attention to signal a side effect and
|
||||
the data dependency between them to ensure torch.compile preserves ordering.
|
||||
"""
|
||||
_, attn_layer, kv_cache, layer_slot_mapping = get_attention_context(layer_name)
|
||||
if layer_slot_mapping is not None:
|
||||
attn_layer.impl.do_rope_and_kv_cache_update(
|
||||
attn_layer,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
positions,
|
||||
cos_sin_cache,
|
||||
is_neox,
|
||||
kv_cache,
|
||||
layer_slot_mapping,
|
||||
)
|
||||
|
||||
return torch.empty(0, device=kv_cache.device, dtype=kv_cache.dtype)
|
||||
|
||||
|
||||
def fused_rope_and_unified_kv_cache_update_fake(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
is_neox: bool,
|
||||
layer_name: str = "",
|
||||
) -> torch.Tensor:
|
||||
return torch.empty(0, device=query.device, dtype=query.dtype)
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="fused_rope_and_unified_kv_cache_update",
|
||||
op_func=fused_rope_and_unified_kv_cache_update_impl,
|
||||
mutates_args=["query", "key"],
|
||||
fake_impl=fused_rope_and_unified_kv_cache_update_fake,
|
||||
)
|
||||
|
||||
|
||||
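# Shape sketch (illustrative sizes): with num_heads=32, num_kv_heads=8 and
# head_size=head_size_v=128, the pattern below splits a fused projection of
# shape [T, 32*128 + 8*128 + 8*128] and views it per head before RoPE and the
# KV-cache write:
#     q: [T, 4096] -> [T, 32, 128]
#     k: [T, 1024] -> [T,  8, 128]
#     v: [T, 1024] -> [T,  8, 128]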
class RopeReshapeKVCachePattern:
|
||||
"""
|
||||
This pattern matches the following unfused inplace ops:
|
||||
q, k = rotary_embedding(positions, q, k, head_size, cos_sin_cache, is_neox)
|
||||
kv_cache_dummy = unified_kv_cache_update(k, v, layer_name)
|
||||
|
||||
and replaces it with the fused inplace op:
|
||||
kv_cache_dummy = fused_rope_and_unified_kv_cache_update(
|
||||
q, k, v, positions, cos_sin_cache, is_neox, layer_name
|
||||
)
|
||||
"""
|
||||
|
||||
FUSED_OP = torch.ops.vllm.fused_rope_and_unified_kv_cache_update.default
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer: Attention,
|
||||
is_neox: bool,
|
||||
) -> None:
|
||||
self.layer_name = layer.layer_name
|
||||
self.num_heads = layer.num_heads
|
||||
self.num_kv_heads = layer.num_kv_heads
|
||||
self.head_size = layer.head_size
|
||||
self.head_size_v = layer.head_size_v
|
||||
self.is_neox = is_neox
|
||||
|
||||
self.q_size = self.num_heads * self.head_size
|
||||
self.k_size = self.num_kv_heads * self.head_size
|
||||
self.v_size = self.num_kv_heads * self.head_size_v
|
||||
|
||||
self.rope_matcher = MatcherRotaryEmbedding(
|
||||
is_neox=self.is_neox,
|
||||
head_size=self.head_size,
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
)
|
||||
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
# Sample inputs to help pattern tracing
|
||||
T = 5
|
||||
L = 4096
|
||||
qkv = empty_bf16(T, self.q_size + self.k_size + self.v_size)
|
||||
positions = empty_i64(T)
|
||||
cos_sin_cache = empty_bf16(L, self.head_size)
|
||||
return [
|
||||
qkv,
|
||||
positions,
|
||||
cos_sin_cache,
|
||||
]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
qkv: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
|
||||
q, k = self.rope_matcher(positions, q, k, cos_sin_cache)
|
||||
q = q.view(-1, self.num_heads, self.head_size)
|
||||
k = k.view(-1, self.num_kv_heads, self.head_size)
|
||||
v = v.view(-1, self.num_kv_heads, self.head_size_v)
|
||||
dummy = torch.ops.vllm.unified_kv_cache_update(k, v, self.layer_name)
|
||||
return dummy, q, k, v
|
||||
|
||||
def replacement(
|
||||
qkv: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
|
||||
q = q.view(-1, self.num_heads, self.head_size)
|
||||
k = k.view(-1, self.num_kv_heads, self.head_size)
|
||||
v = v.view(-1, self.num_kv_heads, self.head_size_v)
|
||||
results = auto_functionalized(
|
||||
self.FUSED_OP,
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
positions=positions,
|
||||
cos_sin_cache=cos_sin_cache,
|
||||
is_neox=self.is_neox,
|
||||
layer_name=self.layer_name,
|
||||
)
|
||||
return results[0], results[1], results[2], v
|
||||
|
||||
# NOTE: use view_to_reshape to unify view/reshape ops, which simplifies
# the pattern and increases matching opportunities
|
||||
def fwd_and_view_to_reshape(*args, **kwargs) -> fx.GraphModule:
|
||||
gm = pm.fwd_only(*args, **kwargs)
|
||||
view_to_reshape(gm)
|
||||
return gm
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), fwd_and_view_to_reshape, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class RopeKVCacheFusionPass(VllmPatternMatcherPass):
|
||||
"""
|
||||
This pass fuses the rotary embedding and KV cache update operations
|
||||
into a single fused kernel if available.
|
||||
|
||||
It uses the pattern matcher and matches each layer manually, as strings
|
||||
cannot be wildcarded. This also lets us check support on attention layers
|
||||
upon registration instead of during pattern matching.
|
||||
|
||||
This fusion eliminates the need for separate kernel launches and
|
||||
intermediate memory operations between the RoPE and cache update steps.
|
||||
"""
|
||||
|
||||
@enable_fake_mode
|
||||
def __init__(self, config: VllmConfig) -> None:
|
||||
super().__init__(config)
|
||||
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
||||
pass_name="rope_kv_cache_fusion_pass"
|
||||
)
|
||||
|
||||
cc = config.compilation_config
|
||||
self.max_token_num = cc.pass_config.rope_kvcache_fusion_max_token_num
|
||||
|
||||
attn_layers = get_layers_from_vllm_config(config, Attention)
|
||||
for _, layer in attn_layers.items():
|
||||
if layer.impl.fused_rope_kvcache_supported():
|
||||
for is_neox in [True, False]:
|
||||
RopeReshapeKVCachePattern(
|
||||
layer=layer,
|
||||
is_neox=is_neox,
|
||||
).register(self.patterns)
|
||||
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph) -> None:
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", self.matched_count)
|
||||
|
||||
def is_applicable_for_range(self, compile_range: Range) -> bool:
|
||||
# This pass works best for the small-batch decode setting.
|
||||
# For large batches, e.g. prefill, it is better to use two separate kernels
# since they are compute bound and the fused kernels require further tuning.
|
||||
return compile_range.end <= self.max_token_num
|
||||
|
||||
def uuid(self) -> str:
|
||||
return VllmInductorPass.hash_source(self, RopeReshapeKVCachePattern)
|
||||
452
vllm/compilation/passes/fusion/sequence_parallelism.py
Normal file
@@ -0,0 +1,452 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import functools
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch._inductor.pattern_matcher as pm
|
||||
import torch.fx as fx
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.utils import Range
|
||||
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..inductor_pass import enable_fake_mode
|
||||
from ..utility.noop_elimination import NoOpEliminationPass
|
||||
from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8, MatcherRMSNorm
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Min hidden size per device capability for sequence parallelism
|
||||
# Only apply sequence parallelism for models with hidden_size >= threshold
|
||||
SP_MIN_HIDDEN_SIZE: dict[int, int] = {
|
||||
90: 8192, # H100: only for models with hidden_size >= 8192
|
||||
}
|
||||
|
||||
# Min size per GPU per device capability for sequence parallelism
|
||||
# Total min size = min_per_gpu_size * tp_size
|
||||
# This ensures the threshold scales appropriately with tensor parallelism
|
||||
SP_MIN_PER_GPU_SIZE_MB: dict[int, float] = {
|
||||
90: 8, # 8MB per GPU for H100
|
||||
}
|
||||
|
||||
|
||||
def get_sequence_parallelism_threshold(
|
||||
hidden_size: int,
|
||||
tp_size: int,
|
||||
element_size: int,
|
||||
) -> int | None:
|
||||
"""
|
||||
Calculate the minimum token threshold for applying sequence parallelism.
|
||||
|
||||
Returns None if sequence parallelism should not be applied based on model size.
|
||||
|
||||
Branching logic based on device capability:
|
||||
- Check if hidden_size >= SP_MIN_HIDDEN_SIZE[device_capability]
|
||||
- If not, returns None (SP disabled for small models on this device)
|
||||
- If yes, calculates threshold based on per-GPU size
|
||||
|
||||
Formula: min_token_num = (min_per_gpu_size_mb * tp_size * MiB) //
|
||||
(hidden_size * element_size)
|
||||
"""
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if not current_platform.is_cuda():
|
||||
return None
|
||||
|
||||
capability = current_platform.get_device_capability()
|
||||
if capability is None:
|
||||
return None
|
||||
device_capability = capability.to_int()
|
||||
|
||||
# Check if device has configured thresholds
|
||||
min_hidden_size = SP_MIN_HIDDEN_SIZE.get(device_capability)
|
||||
min_per_gpu_size_mb = SP_MIN_PER_GPU_SIZE_MB.get(device_capability)
|
||||
|
||||
if min_hidden_size is None or min_per_gpu_size_mb is None:
|
||||
return None
|
||||
|
||||
# Only apply sequence parallelism for models meeting the size threshold
|
||||
if hidden_size < min_hidden_size:
|
||||
return None
|
||||
|
||||
MiB = 1024 * 1024
|
||||
min_size = min_per_gpu_size_mb * MiB * tp_size
|
||||
return int(min_size // (hidden_size * element_size))
|
||||
|
||||
|
||||
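# Worked example (assuming an H100, capability 90, and bf16 activations, i.e.
# element_size=2): for hidden_size=8192 and tp_size=8 the per-GPU minimum is
# 8 MiB, so min_size = 8 MiB * 8 = 64 MiB and
#     min_token_num = 64 * 1024 * 1024 // (8192 * 2) = 4096
# i.e. the sequence-parallel patterns are only worth applying above ~4096 tokens.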
def get_first_out_wrapper(
|
||||
fn: Callable[..., Sequence[torch.Tensor]],
|
||||
) -> Callable[..., torch.Tensor]:
|
||||
@functools.wraps(fn)
|
||||
def wrapper(*args: Any) -> torch.Tensor:
|
||||
return fn(*args)[0]
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
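# Why the wrapper exists: some graphs only consume the first output of a
# matched pattern, so a single-output variant of each (pattern, replacement)
# pair is registered as well (see MiddleAllReduceRMSNormPattern.register
# below). Toy illustration:
#
#     def pair(x): return x + 1, x - 1
#     get_first_out_wrapper(pair)(torch.zeros(2))  # -> tensor([1., 1.])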
class _SequenceParallelPatternHelper:
|
||||
"""Helper for sequence parallelism patterns."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
epsilon: float,
|
||||
dtype: torch.dtype,
|
||||
device: str | None,
|
||||
) -> None:
|
||||
self.epsilon = epsilon
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self.tp_group = get_tp_group()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
def _all_reduce(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return tensor_model_parallel_all_reduce(x)
|
||||
|
||||
def _reduce_scatter(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.ops.vllm.reduce_scatter.default(
|
||||
x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name
|
||||
)
|
||||
|
||||
def _all_gather(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return torch.ops.vllm.all_gather.default(
|
||||
x, dim=0, world_size=self.tp_size, group_name=self.tp_group.unique_name
|
||||
)
|
||||
|
||||
|
||||
class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
|
||||
super().__init__(epsilon, dtype, device)
|
||||
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
|
||||
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype)
|
||||
arg3_1 = torch.empty([4], device=self.device, dtype=self.dtype)
|
||||
|
||||
return [input, arg3_1]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
arg3_1: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
all_reduce = self._all_reduce(input)
|
||||
rmsnorm = self.rmsnorm_matcher(all_reduce, arg3_1)
|
||||
|
||||
return rmsnorm, all_reduce
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
arg3_1: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
reduce_scatter = self._reduce_scatter(input)
|
||||
|
||||
rmsnorm = self.rmsnorm_matcher(reduce_scatter, arg3_1)
|
||||
all_gather = self._all_gather(rmsnorm)
|
||||
return all_gather, reduce_scatter
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
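# Note on the rewrite above: for a tensor sharded along dim 0,
# all_reduce(x) == all_gather(reduce_scatter(x, dim=0), dim=0), so the RMSNorm
# can run on the scattered (sequence-parallel) shard and be gathered afterwards.
# The patterns below apply the same rewrite to the fused-add and quantized
# variants.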
class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
|
||||
def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
|
||||
super().__init__(epsilon, dtype, device)
|
||||
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
|
||||
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
|
||||
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
|
||||
return [
|
||||
residual,
|
||||
mm_1,
|
||||
rms_norm_weights,
|
||||
]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
residual: torch.Tensor,
|
||||
mm_1: torch.Tensor,
|
||||
rms_norm_weights: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
all_reduce = self._all_reduce(mm_1)
|
||||
rmsnorm = self.rmsnorm_matcher(all_reduce, rms_norm_weights, residual)
|
||||
return rmsnorm[0], rmsnorm[1]
|
||||
|
||||
def replacement(
|
||||
residual: torch.Tensor,
|
||||
mm_1: torch.Tensor,
|
||||
rms_norm_weights: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# pattern matcher replaces from top-to-bottom,
|
||||
# so residual is still the full size here.
# add a temporary slice which will become a noop
# once the seqpar pattern with the previous rmsnorm is replaced
|
||||
reduce_scatter = self._reduce_scatter(mm_1)
|
||||
residual = residual[0 : reduce_scatter.size(0), ...]
|
||||
rmsnorm = self.rmsnorm_matcher(reduce_scatter, rms_norm_weights, residual)
|
||||
all_gather = self._all_gather(rmsnorm[0])
|
||||
# shape of residual changes but that's fine,
|
||||
# next node is already slicing it, now becomes a noop
|
||||
return all_gather, rmsnorm[1]
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
pm.register_replacement(
|
||||
get_first_out_wrapper(pattern),
|
||||
get_first_out_wrapper(replacement),
|
||||
self.get_inputs(),
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
)
|
||||
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
|
||||
class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
||||
def __init__(
|
||||
self,
|
||||
epsilon: float,
|
||||
dtype: torch.dtype,
|
||||
device: str | None,
|
||||
) -> None:
|
||||
super().__init__(epsilon, dtype, device)
|
||||
self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
|
||||
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
|
||||
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
input = torch.zeros([1, 8, 4], device=self.device, dtype=self.dtype)
|
||||
weight = torch.empty([4], device=self.device, dtype=self.dtype)
|
||||
scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
|
||||
return [input, weight, scale]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
all_reduce = self._all_reduce(input)
|
||||
rms = self.rmsnorm_matcher(all_reduce, weight)
|
||||
quant, _ = self.quant_matcher(rms, scale)
|
||||
return quant, all_reduce
|
||||
|
||||
def replacement(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
reduce_scatter = self._reduce_scatter(input)
|
||||
rms = self.rmsnorm_matcher(reduce_scatter, weight)
|
||||
quant, _ = self.quant_matcher(rms, scale)
|
||||
all_gather = self._all_gather(quant)
|
||||
|
||||
return all_gather, reduce_scatter
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
|
||||
class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
|
||||
def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None:
|
||||
super().__init__(epsilon, dtype, device)
|
||||
self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon)
|
||||
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
|
||||
|
||||
def get_inputs(self) -> list[torch.Tensor]:
|
||||
mm_1 = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
residual = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
rms_norm_weights = torch.empty([4, 4], device=self.device, dtype=self.dtype)
|
||||
scale = torch.empty([1, 1], device=self.device, dtype=torch.float32)
|
||||
|
||||
return [residual, mm_1, rms_norm_weights, scale]
|
||||
|
||||
def register(self, pm_pass: PatternMatcherPass) -> None:
|
||||
def pattern(
|
||||
residual: torch.Tensor,
|
||||
mm_1: torch.Tensor,
|
||||
rms_norm_weights: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
all_reduce = self._all_reduce(mm_1)
|
||||
rms, residual_out = self.rmsnorm_matcher(
|
||||
all_reduce, rms_norm_weights, residual
|
||||
)
|
||||
quant, _ = self.quant_matcher(rms, scale)
|
||||
return quant, residual_out
|
||||
|
||||
def replacement(
|
||||
residual: torch.Tensor,
|
||||
mm_1: torch.Tensor,
|
||||
rms_norm_weights: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# pattern matcher replaces from top-to-bottom,
|
||||
# so residual is still the full size here.
|
||||
# add a temporary slice which will become a noop
|
||||
# once the seqpar pattern with the previous rmsnorm is replaced
|
||||
reduce_scatter = self._reduce_scatter(mm_1)
|
||||
residual = residual[0 : reduce_scatter.size(0), ...]
|
||||
rms, residual_out = self.rmsnorm_matcher(
|
||||
reduce_scatter, rms_norm_weights, residual
|
||||
)
|
||||
quant, _ = self.quant_matcher(rms, scale)
|
||||
all_gather = self._all_gather(quant)
|
||||
# shape of residual changes but that's fine,
|
||||
# next node is already slicing it, now becomes a noop
|
||||
return all_gather, residual_out
|
||||
|
||||
pm.register_replacement(
|
||||
pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
|
||||
)
|
||||
|
||||
pm.register_replacement(
|
||||
get_first_out_wrapper(pattern),
|
||||
get_first_out_wrapper(replacement),
|
||||
self.get_inputs(),
|
||||
pm.fwd_only,
|
||||
pm_pass,
|
||||
)
|
||||
|
||||
|
||||
class SequenceParallelismPass(VllmPatternMatcherPass):
|
||||
"""
|
||||
This pass enables sequence parallelism for models.
|
||||
It identifies patterns where an AllReduce operation is followed by
|
||||
an RMSNorm (or RMSNorm and then Quantization) operation.
|
||||
These patterns are replaced with a ReduceScatter operation, followed by
|
||||
a local RMSNorm/Quantization, and then an AllGather operation.
|
||||
|
||||
The general transformation is:
|
||||
Input -> AllReduce -> RMSNorm -> Output
|
||||
becomes
|
||||
Input -> ReduceScatter -> RMSNorm -> AllGather -> Output
|
||||
|
||||
While this pass itself does not directly yield performance improvements,
|
||||
it lays the groundwork for subsequent fusion passes, such as
|
||||
GEMM + ReduceScatter and AllGather + GEMM fusions. These fusions can
|
||||
significantly reduce communication overhead and improve overall model
|
||||
performance.
|
||||
|
||||
|
||||
This pass splits up the residual tensor across TP ranks and hence divides its size.
|
||||
Because the pattern matcher starts at the end of the graph, the replacement
|
||||
contains a slice that temporarily conforms the input residual to the correct size.
|
||||
After all patterns have been matched, we use a NoOpEliminationPass to clean up
|
||||
what have now become no-op slices.
|
||||
|
||||
Note that an older version of the pass did not need this as it operated only on
the custom rms_norm and fused_add_rms_norm ops, which did not complain about
mismatched shapes during replacement. Both approaches share the assumption that
correctness is only maintained if all rms_norm operations are split across ranks.
|
||||
|
||||
Correctness-wise, this approach is strictly better than before - before,
the graph was incorrect semantically and shape-wise during the pass.
|
||||
With this approach there's only semantic incorrectness during the pass.
|
||||
Both approaches restore a correct graph once all patterns are matched.
|
||||
"""
|
||||
|
||||
@enable_fake_mode
|
||||
def __init__(self, config: VllmConfig) -> None:
|
||||
super().__init__(config)
|
||||
|
||||
# Read min_token_num from config (calculated during config init)
|
||||
self.min_token_num = None
|
||||
if config.model_config is not None:
|
||||
pass_config = config.compilation_config.pass_config
|
||||
self.min_token_num = pass_config.sp_min_token_num
|
||||
|
||||
if self.min_token_num is not None:
|
||||
# Take the min to avoid exceeding max_num_batched_tokens
|
||||
max_batched = config.scheduler_config.max_num_batched_tokens
|
||||
if max_batched is not None:
|
||||
self.min_token_num = min(self.min_token_num, max_batched)
|
||||
logger.debug_once(
|
||||
f"Sequence parallelism min token threshold: {self.min_token_num}",
|
||||
scope="global",
|
||||
)
|
||||
|
||||
# Used to clean up redundant views created temporarily
|
||||
# to circumvent residual shape change issues
|
||||
self.noop_cleanup = NoOpEliminationPass(config)
|
||||
self.noop_cleanup.pass_name = f"{self.pass_name}.{self.noop_cleanup.pass_name}"
|
||||
|
||||
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
||||
pass_name="sequence_parallelism_pass"
|
||||
)
|
||||
|
||||
for epsilon in [1e-5, 1e-6]:
|
||||
# RMSNorm + Static FP8 quantization patterns
|
||||
FirstAllReduceRMSNormStaticFP8Pattern(
|
||||
epsilon, self.model_dtype, self.device
|
||||
).register(self.patterns)
|
||||
MiddleAllReduceRMSNormStaticFP8Pattern(
|
||||
epsilon, self.model_dtype, self.device
|
||||
).register(self.patterns)
|
||||
|
||||
# Normal RMSNorm patterns
|
||||
FirstAllReduceRMSNormPattern(
|
||||
epsilon, self.model_dtype, self.device
|
||||
).register(self.patterns)
|
||||
|
||||
MiddleAllReduceRMSNormPattern(
|
||||
epsilon, self.model_dtype, self.device
|
||||
).register(self.patterns)
|
||||
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
def is_applicable_for_range(self, compile_range: Range) -> bool:
|
||||
"""
|
||||
Determines if sequence parallelism should be applied for the given
|
||||
compile range.
|
||||
|
||||
SP is only beneficial for larger batch sizes where the communication
|
||||
overhead is amortized. For small batches, the overhead of splitting
|
||||
and gathering tensors across TP ranks outweighs the benefits.
|
||||
|
||||
Returns False (SP disabled) when:
|
||||
- Using piecewise compilation with non-concrete or TP-indivisible sizes
|
||||
- min_token_num is None (SP disabled for this device/config)
|
||||
- The compile range starts below the minimum token threshold
|
||||
"""
|
||||
# For piecewise compilation (not using inductor graph partition),
|
||||
# we need concrete sizes that are divisible by TP for correct splitting
|
||||
if (
|
||||
not self.compilation_config.use_inductor_graph_partition
|
||||
and self.compilation_config.splitting_ops
|
||||
):
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
if not compile_range.is_single_size() or compile_range.end % tp_size != 0:
|
||||
return False
|
||||
|
||||
# min_token_num is None when SP is disabled for this device/config
|
||||
# (e.g., non-CUDA platform, unsupported GPU, or small hidden_size)
|
||||
if self.min_token_num is None:
|
||||
return False
|
||||
|
||||
# Only apply SP when batch size meets the minimum threshold
|
||||
return compile_range.start >= self.min_token_num
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph) -> None:
|
||||
self.matched_count = self.patterns.apply(graph)
|
||||
logger.debug("Replaced %s patterns", self.matched_count)
|
||||
# Clean up reshape nodes
|
||||
self.noop_cleanup(graph)
|
||||
77
vllm/compilation/passes/fx_utils.py
Normal file
@@ -0,0 +1,77 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import operator
|
||||
from collections.abc import Iterable, Iterator
|
||||
|
||||
from torch import fx
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
from torch._ops import OpOverload, OpOverloadPacket
|
||||
from torch.fx.node import Target
|
||||
|
||||
|
||||
def is_func(node: fx.Node, target: Target) -> bool:
|
||||
return bool(node.op == "call_function" and node.target == target)
|
||||
|
||||
|
||||
def is_auto_func(node: fx.Node, op: OpOverload) -> bool:
|
||||
return is_func(node, auto_functionalized) and node.args[0] == op
|
||||
|
||||
|
||||
# Returns the first auto_functionalized node with the given op (if it exists)
|
||||
def find_auto_fn_maybe(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node | None:
|
||||
for node in nodes:
|
||||
if is_func(node, auto_functionalized) and node.args[0] == op: # noqa
|
||||
return node
|
||||
return None
|
||||
|
||||
|
||||
# Returns the first auto_functionalized node with the given op
|
||||
def find_auto_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node:
|
||||
node = find_auto_fn_maybe(nodes, op)
|
||||
assert node is not None, f"Could not find {op} in nodes {nodes}"
|
||||
return node
|
||||
|
||||
|
||||
# Returns the getitem node that extracts the idx-th element from node
|
||||
# (if it exists)
|
||||
def find_getitem_maybe(node: fx.Node, idx: int) -> fx.Node | None:
|
||||
for user in node.users:
|
||||
if is_func(user, operator.getitem) and user.args[1] == idx:
|
||||
return user
|
||||
return None
|
||||
|
||||
|
||||
# Returns the getitem node that extracts the idx-th element from node
|
||||
def find_getitem(node: fx.Node, idx: int) -> fx.Node:
|
||||
ret = find_getitem_maybe(node, idx)
|
||||
assert ret is not None, f"Could not find getitem {idx} in node {node}"
|
||||
return ret
|
||||
|
||||
|
||||
# An auto-functionalization-aware utility for finding nodes with a specific op
|
||||
# Also handles op overload packets and finds all overloads
|
||||
def find_op_nodes(
|
||||
op: OpOverload | OpOverloadPacket, graph: fx.Graph
|
||||
) -> Iterator[fx.Node]:
|
||||
if isinstance(op, OpOverloadPacket):
|
||||
for overload in op.overloads():
|
||||
overload_op = getattr(op, overload)
|
||||
yield from find_op_nodes(overload_op, graph)
|
||||
return
|
||||
|
||||
assert isinstance(op, OpOverload)
|
||||
|
||||
yield from graph.find_nodes(op="call_function", target=op)
|
||||
|
||||
for n in graph.find_nodes(op="call_function", target=auto_functionalized):
|
||||
if n.args[0] == op:
|
||||
yield n
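

# Example usage (illustrative): iterate over all rms_norm calls in a graph,
# whether they appear directly or wrapped in auto_functionalized:
#   for node in find_op_nodes(torch.ops._C.rms_norm, graph):
#       ...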
|
||||
|
||||
|
||||
# Asserts that the node only has one user and returns it
|
||||
# Even if a node has only 1 user, it might share storage with another node,
|
||||
# which might need to be taken into account.
|
||||
def get_only_user(node: fx.Node) -> fx.Node:
|
||||
assert len(node.users) == 1
|
||||
return next(iter(node.users))
|
||||
134
vllm/compilation/passes/inductor_pass.py
Normal file
@@ -0,0 +1,134 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import hashlib
|
||||
import inspect
|
||||
import json
|
||||
import types
|
||||
from collections.abc import Callable, Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar
|
||||
|
||||
import torch
|
||||
from torch import fx
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config.utils import Range
|
||||
|
||||
from torch._inductor.custom_graph_pass import CustomGraphPass
|
||||
|
||||
_pass_context = None
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class PassContext:
|
||||
def __init__(self, compile_range: Range):
|
||||
self.compile_range: Range = compile_range
|
||||
|
||||
|
||||
def get_pass_context() -> PassContext:
|
||||
"""Get the current pass context."""
|
||||
assert _pass_context is not None
|
||||
return _pass_context
|
||||
|
||||
|
||||
@contextmanager
|
||||
def pass_context(compile_range: Range) -> Generator[None, None, None]:
|
||||
"""A context manager that stores the current pass context,
|
||||
usually it is a list of sizes to specialize.
|
||||
"""
|
||||
global _pass_context
|
||||
prev_context = _pass_context
|
||||
_pass_context = PassContext(compile_range)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_pass_context = prev_context
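

# Note (illustrative): contexts nest safely because the previous value is
# restored in the finally block:
#   with pass_context(outer_range):
#       with pass_context(inner_range):
#           assert get_pass_context().compile_range == inner_range
#       assert get_pass_context().compile_range == outer_range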
|
||||
|
||||
|
||||
class InductorPass(CustomGraphPass): # type: ignore[misc]
|
||||
"""
|
||||
A custom graph pass that uses a hash of its source as the UUID.
|
||||
This is defined as a convenience and should work in most cases.
|
||||
"""
|
||||
|
||||
def uuid(self) -> str:
|
||||
"""
|
||||
Provide a unique identifier for the pass, used in Inductor code cache.
|
||||
This should depend on the pass implementation, so that changes to the
|
||||
pass result in recompilation.
|
||||
By default, the object source is hashed.
|
||||
"""
|
||||
return InductorPass.hash_source(self)
|
||||
|
||||
@staticmethod
|
||||
def hash_source(*srcs: str | Any) -> str:
|
||||
"""
|
||||
Utility method to hash the sources of functions or objects.
|
||||
:param srcs: strings or objects to add to the hash.
|
||||
Objects and functions have their source inspected.
|
||||
:return: A sha256 hex digest of the combined sources.
|
||||
"""
|
||||
hasher = hashlib.sha256()
|
||||
for src in srcs:
|
||||
if isinstance(src, str):
|
||||
src_str = src
|
||||
elif isinstance(src, (types.FunctionType, type)):
|
||||
src_str = inspect.getsource(src)
|
||||
else:
|
||||
# object instance
|
||||
src_str = inspect.getsource(src.__class__)
|
||||
hasher.update(src_str.encode("utf-8"))
|
||||
return hasher.hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def hash_dict(dict_: dict[Any, Any]) -> str:
|
||||
"""
|
||||
Utility method to hash a dictionary; it can alternatively be used for the uuid.
|
||||
:return: A sha256 hash of the json rep of the dictionary.
|
||||
"""
|
||||
encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
|
||||
return hashlib.sha256(encoded).hexdigest()
|
||||
|
||||
def is_applicable_for_range(self, compile_range: Range) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class CallableInductorPass(InductorPass):
|
||||
"""
|
||||
This class is a wrapper for a callable that automatically provides an
|
||||
implementation of the UUID.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, callable: Callable[[fx.Graph], None], uuid: Any | None = None
|
||||
) -> None:
|
||||
self.callable = callable
|
||||
self._uuid = self.hash_source(callable) if uuid is None else uuid
|
||||
|
||||
def __call__(self, graph: torch.fx.Graph) -> None:
|
||||
self.callable(graph)
|
||||
|
||||
def uuid(self) -> Any:
|
||||
return self._uuid
|
||||
|
||||
|
||||
def enable_fake_mode(fn: Callable[P, R]) -> Callable[P, R]:
|
||||
"""
|
||||
Applies a FakeTensorMode context. This is useful when you don't want to
|
||||
create or run things with real tensors.
|
||||
"""
|
||||
|
||||
@functools.wraps(fn)
|
||||
def fn_new(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
with torch._guards.tracing(None), unset_fake_temporarily(), FakeTensorMode():
|
||||
result = fn(*args, **kwargs)
|
||||
|
||||
return result
|
||||
|
||||
return fn_new
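

# Example (from this module's usage): SequenceParallelismPass.__init__ is
# decorated with @enable_fake_mode, so the torch.empty(...) calls made while
# registering patterns allocate fake tensors rather than real device memory.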
|
||||
178
vllm/compilation/passes/pass_manager.py
Normal file
@@ -0,0 +1,178 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
from collections.abc import Callable
|
||||
from typing import Any, ParamSpec, TypeVar
|
||||
|
||||
from torch import fx as fx
|
||||
|
||||
from vllm import envs
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.system_utils import set_env_var
|
||||
|
||||
from .vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
if rocm_aiter_ops.is_enabled():
|
||||
from .fusion.rocm_aiter_fusion import (
|
||||
RocmAiterRMSNormQuantFusionPass,
|
||||
RocmAiterSiluMulFp8GroupQuantFusionPass,
|
||||
RocmAiterTritonAddRMSNormPadFusionPass,
|
||||
)
|
||||
|
||||
if current_platform.is_cuda_alike():
|
||||
from .fusion.act_quant_fusion import ActivationQuantFusionPass
|
||||
from .fusion.attn_quant_fusion import AttnFusionPass
|
||||
from .fusion.qk_norm_rope_fusion import QKNormRoPEFusionPass
|
||||
from .fusion.rms_quant_fusion import RMSNormQuantFusionPass
|
||||
from .fusion.rope_kvcache_fusion import RopeKVCacheFusionPass
|
||||
from .fusion.sequence_parallelism import SequenceParallelismPass
|
||||
from .utility.scatter_split_replace import ScatterSplitReplacementPass
|
||||
from .utility.split_coalescing import SplitCoalescingPass
|
||||
|
||||
if current_platform.is_cuda():
|
||||
from .fusion.allreduce_rms_fusion import AllReduceFusionPass
|
||||
from .fusion.collective_fusion import AsyncTPPass
|
||||
|
||||
from .inductor_pass import (
|
||||
CustomGraphPass,
|
||||
InductorPass,
|
||||
get_pass_context,
|
||||
)
|
||||
from .utility.fix_functionalization import FixFunctionalizationPass
|
||||
from .utility.noop_elimination import NoOpEliminationPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def with_pattern_match_debug(fn: Callable[P, R]) -> Callable[P, R]:
|
||||
"""
|
||||
Function decorator that turns on inductor pattern match debug
|
||||
for the duration of the call.
|
||||
Used to avoid logging builtin Inductor pattern matching.
|
||||
"""
|
||||
|
||||
@functools.wraps(fn)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if (debug_val := envs.VLLM_PATTERN_MATCH_DEBUG) is not None:
|
||||
# optionally check rank here
|
||||
with set_env_var("TORCHINDUCTOR_PATTERN_MATCH_DEBUG", debug_val):
|
||||
return fn(*args, **kwargs)
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class PostGradPassManager(CustomGraphPass): # type: ignore[misc]
|
||||
"""
|
||||
The pass manager for post-grad passes.
|
||||
It handles configuration, adding custom passes, and running passes.
|
||||
It supports uuid for the Inductor code cache. That includes torch<2.6
|
||||
support using pickling (in .inductor_pass.CustomGraphPass).
|
||||
|
||||
The order of the post-grad post-passes is:
|
||||
1. passes (constructor parameter)
|
||||
2. default passes (NoOpEliminationPass, the fusion passes)
|
||||
3. config["post_grad_custom_post_pass"] (if it exists)
|
||||
4. fix_functionalization
|
||||
This way, all passes operate on a functionalized graph.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.passes: list[InductorPass] = []
|
||||
|
||||
@with_pattern_match_debug
|
||||
def __call__(self, graph: fx.Graph) -> None:
|
||||
VllmInductorPass.dump_prefix = 0 # reset dump index
|
||||
|
||||
compile_range = get_pass_context().compile_range
|
||||
for pass_ in self.passes:
|
||||
if pass_.is_applicable_for_range(compile_range):
|
||||
pass_(graph)
|
||||
VllmInductorPass.dump_prefix += 1
|
||||
else:
|
||||
logger.debug("Skipping %s with compile range %s", pass_, compile_range)
|
||||
|
||||
# post-cleanup goes before fix_functionalization
|
||||
# because it requires a functional graph
|
||||
self.post_cleanup(graph)
|
||||
VllmInductorPass.dump_prefix += 1
|
||||
|
||||
# always run fix_functionalization last
|
||||
self.fix_functionalization(graph)
|
||||
VllmInductorPass.dump_prefix = None # Cleanup index
|
||||
|
||||
def configure(self, config: VllmConfig) -> None:
|
||||
self.pass_config = config.compilation_config.pass_config
|
||||
|
||||
# Set the current vllm config to allow tracing CustomOp instances
|
||||
with set_current_vllm_config(config, check_compile=False):
|
||||
if self.pass_config.eliminate_noops:
|
||||
self.passes += [NoOpEliminationPass(config)]
|
||||
|
||||
if self.pass_config.enable_sp:
|
||||
self.passes += [SequenceParallelismPass(config)]
|
||||
if self.pass_config.fuse_gemm_comms:
|
||||
self.passes += [AsyncTPPass(config)]
|
||||
|
||||
if self.pass_config.fuse_allreduce_rms:
|
||||
self.passes += [AllReduceFusionPass(config)]
|
||||
|
||||
if self.pass_config.fuse_norm_quant:
|
||||
self.passes += [RMSNormQuantFusionPass(config)]
|
||||
if rocm_aiter_ops.is_enabled():
|
||||
self.passes += [
|
||||
RocmAiterRMSNormQuantFusionPass(config),
|
||||
]
|
||||
if self.pass_config.fuse_act_quant:
|
||||
self.passes += [ActivationQuantFusionPass(config)]
|
||||
if rocm_aiter_ops.is_enabled():
|
||||
self.passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)]
|
||||
|
||||
if self.pass_config.fuse_act_padding and rocm_aiter_ops.is_enabled():
|
||||
self.passes += [RocmAiterTritonAddRMSNormPadFusionPass(config)]
|
||||
|
||||
if self.pass_config.fuse_rope_kvcache:
|
||||
self.passes += [SplitCoalescingPass(config)]
|
||||
self.passes += [ScatterSplitReplacementPass(config)]
|
||||
self.passes += [RopeKVCacheFusionPass(config)]
|
||||
|
||||
if self.pass_config.fuse_attn_quant:
|
||||
self.passes += [AttnFusionPass(config)]
|
||||
|
||||
if self.pass_config.enable_qk_norm_rope_fusion:
|
||||
self.passes += [SplitCoalescingPass(config)]
|
||||
self.passes += [QKNormRoPEFusionPass(config)]
|
||||
|
||||
# needs a functional graph
|
||||
self.post_cleanup = PostCleanupPass(config)
|
||||
self.fix_functionalization = FixFunctionalizationPass(config)
|
||||
|
||||
def add(self, pass_: InductorPass) -> None:
|
||||
assert isinstance(pass_, InductorPass)
|
||||
self.passes.append(pass_)
|
||||
|
||||
def uuid(self) -> str:
|
||||
"""
|
||||
The PostGradPassManager is set as a custom pass in the Inductor and
|
||||
affects compilation caching. Its uuid depends on the UUIDs of all
|
||||
dependent passes and the pass config. See InductorPass for more info.
|
||||
"""
|
||||
passes = []
|
||||
|
||||
state: dict[str, Any] = {"pass_config": self.pass_config.compute_hash()}
|
||||
for pass_ in self.passes:
|
||||
passes.append(pass_.uuid())
|
||||
passes.append(self.fix_functionalization.uuid())
|
||||
|
||||
# Include the compile range in the uuid to ensure that inductor
|
||||
# recompiles the graph for the new dynamic compile range.
|
||||
state["compile_range"] = str(get_pass_context().compile_range)
|
||||
state["passes"] = passes
|
||||
return InductorPass.hash_dict(state)
|
||||
0
vllm/compilation/passes/utility/__init__.py
Normal file
301
vllm/compilation/passes/utility/fix_functionalization.py
Normal file
@@ -0,0 +1,301 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import operator
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..fx_utils import is_func
|
||||
from ..vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FixFunctionalizationPass(VllmInductorPass):
|
||||
"""
|
||||
This pass defunctionalizes certain nodes to avoid redundant tensor copies.
|
||||
After this pass, DCE (dead-code elimination) should never be run,
|
||||
as de-functionalized nodes may appear as dead code.
|
||||
|
||||
To add new nodes to defunctionalize, add to the if-elif chain in __call__.
|
||||
"""
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: torch.fx.Graph) -> None:
|
||||
# XPU does not support auto-functionalization yet.
|
||||
# Will enable this when switch to vllm-xpu-kernels.
|
||||
if current_platform.is_xpu():
|
||||
logger.debug(
|
||||
"XPU platform does not support fix functionalizationpass currently."
|
||||
)
|
||||
return
|
||||
|
||||
self.nodes_to_remove: list[torch.fx.Node] = []
|
||||
count = 0
|
||||
for node in graph.nodes:
|
||||
if not is_func(node, auto_functionalized):
|
||||
continue # Avoid deep if-elif nesting
|
||||
|
||||
kwargs = node.kwargs
|
||||
at_target = node.args[0]
|
||||
|
||||
if at_target == torch.ops._C.rotary_embedding.default:
|
||||
query = kwargs["query"]
|
||||
key = kwargs["key"]
|
||||
getitem_nodes = self.getitem_users(node)
|
||||
|
||||
if (
|
||||
is_func(query, operator.getitem)
|
||||
and is_func(key, operator.getitem)
|
||||
and query.args[0] == key.args[0]
|
||||
and is_func(query.args[0], torch.ops.aten.split_with_sizes.default)
|
||||
and all(
|
||||
is_func(user, torch.ops.aten.slice_scatter.default)
|
||||
for getitem_node in getitem_nodes.values()
|
||||
for user in getitem_node.users
|
||||
)
|
||||
):
|
||||
# Pattern where query and key are slices of an mm_node.
|
||||
# While functionalized, results at [1] and [2] are scattered
|
||||
# back into mm_node. So after de-functionalization, we can
|
||||
# just use mm_node directly.
|
||||
|
||||
mm_node = query.args[0].args[0]
|
||||
for user in getitem_nodes.values():
|
||||
for user_of_getitem in user.users:
|
||||
if is_func(
|
||||
user_of_getitem, torch.ops.aten.slice_scatter.default
|
||||
):
|
||||
user_of_getitem.replace_all_uses_with(mm_node)
|
||||
self._remove(user_of_getitem)
|
||||
self._remove(user)
|
||||
|
||||
self.insert_defunctionalized(graph, node)
|
||||
self._remove(node)
|
||||
|
||||
else:
|
||||
# Directly replace the auto_functionalize(rotary_embedding)
|
||||
# with the inplace rotary_embedding. In theory, we shouldn't
|
||||
# do this blindly, but in practice in vLLM it's ok. The best
|
||||
# solution is to use auto_functionalization_v2 and then use
|
||||
# inductor's builtin defunctionalization (reinplacing) pass.
|
||||
mutated_args = {1: "query", 2: "key"}
|
||||
self.defunctionalize(graph, node, mutated_args)
|
||||
|
||||
# rms_norm replacements avoid the most copies for LLaMa.
|
||||
elif at_target == torch.ops._C.fused_add_rms_norm.default:
|
||||
mutated_args = {1: "input", 2: "residual"}
|
||||
self.defunctionalize(graph, node, mutated_args)
|
||||
elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default: # noqa: E501
|
||||
mutated_args = {1: "result", 2: "residual"}
|
||||
self.defunctionalize(graph, node, mutated_args)
|
||||
elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default: # noqa: E501
|
||||
mutated_args = {1: "result", 2: "scale", 3: "residual"}
|
||||
self.defunctionalize(graph, node, mutated_args)
|
||||
elif at_target in [
|
||||
torch.ops._C.rms_norm.default,
|
||||
torch.ops._C.rms_norm_static_fp8_quant.default,
|
||||
]:
|
||||
mutated_args = {1: "result"}
|
||||
self.defunctionalize(graph, node, mutated_args)
|
||||
elif (
|
||||
hasattr(torch.ops.vllm, "flashinfer_trtllm_fused_allreduce_norm")
|
||||
and at_target
|
||||
== torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default
|
||||
):
|
||||
mutated_args = {
|
||||
1: "allreduce_in",
|
||||
2: "residual",
|
||||
3: "norm_out",
|
||||
4: "quant_out",
|
||||
5: "scale_out",
|
||||
}
|
||||
self.defunctionalize(graph, node, mutated_args)
|
||||
# For some reason we need to specify the args for both
|
||||
# silu_and_mul and silu_and_mul_quant. The kwargs
|
||||
# pathway gets the wrong answer.
|
||||
elif at_target == torch.ops._C.silu_and_mul.default:
|
||||
mutated_args = {1: "result"}
|
||||
self.defunctionalize(
|
||||
graph, node, mutated_args, args=("result", "input")
|
||||
)
|
||||
elif at_target == torch.ops._C.silu_and_mul_quant.default:
|
||||
mutated_args = {1: "result"}
|
||||
self.defunctionalize(
|
||||
graph, node, mutated_args, args=("result", "input", "scale")
|
||||
)
|
||||
elif (
|
||||
hasattr(torch.ops._C, "silu_and_mul_nvfp4_quant")
|
||||
and at_target == torch.ops._C.silu_and_mul_nvfp4_quant.default
|
||||
):
|
||||
mutated_args = {1: "result", 2: "result_block_scale"}
|
||||
self.defunctionalize(
|
||||
graph,
|
||||
node,
|
||||
mutated_args,
|
||||
args=(
|
||||
"result",
|
||||
"result_block_scale",
|
||||
"input",
|
||||
"input_global_scale",
|
||||
),
|
||||
)
|
||||
# Defunctionalize fused_qk_norm_rope to remove higher-order wrapper.
|
||||
elif at_target == torch.ops._C.fused_qk_norm_rope.default:
|
||||
mutated_args = {1: "qkv"}
|
||||
args = (
|
||||
"qkv",
|
||||
"num_heads_q",
|
||||
"num_heads_k",
|
||||
"num_heads_v",
|
||||
"head_dim",
|
||||
"eps",
|
||||
"q_weight",
|
||||
"k_weight",
|
||||
"cos_sin_cache",
|
||||
"is_neox",
|
||||
"position_ids",
|
||||
)
|
||||
self.defunctionalize(graph, node, mutated_args=mutated_args, args=args)
|
||||
elif (
|
||||
hasattr(torch.ops.vllm, "fused_rope_and_unified_kv_cache_update")
|
||||
and at_target
|
||||
== torch.ops.vllm.fused_rope_and_unified_kv_cache_update.default
|
||||
):
|
||||
mutated_args = {
|
||||
1: "query",
|
||||
2: "key",
|
||||
}
|
||||
self.defunctionalize(graph, node, mutated_args=mutated_args)
|
||||
# only used for test_functionalization::TestFunctionWithMutatedArgsAndReturn
|
||||
elif (
|
||||
hasattr(torch.ops.vllm, "function_with_mutated_args_and_return")
|
||||
and at_target
|
||||
== torch.ops.vllm.function_with_mutated_args_and_return.default
|
||||
):
|
||||
mutated_args = {1: "x"}
|
||||
self.defunctionalize(graph, node, mutated_args=mutated_args)
|
||||
else:
|
||||
continue # skip the count
|
||||
|
||||
count += 1
|
||||
|
||||
self.dump_graph(graph, "before_cleanup")
|
||||
|
||||
# Remove the nodes all at once
|
||||
count_removed = len(self.nodes_to_remove)
|
||||
for node in self.nodes_to_remove:
|
||||
graph.erase_node(node)
|
||||
|
||||
logger.debug(
|
||||
"De-functionalized %s nodes, removed %s nodes", count, count_removed
|
||||
)
|
||||
self.nodes_to_remove.clear()
|
||||
|
||||
def _remove(self, node_or_nodes: torch.fx.Node | Iterable[torch.fx.Node]) -> None:
|
||||
"""
|
||||
Stage a node (or nodes) for removal at the end of the pass.
|
||||
"""
|
||||
if isinstance(node_or_nodes, torch.fx.Node):
|
||||
self.nodes_to_remove.append(node_or_nodes)
|
||||
else:
|
||||
self.nodes_to_remove.extend(node_or_nodes)
|
||||
|
||||
def defunctionalize(
|
||||
self,
|
||||
graph: torch.fx.Graph,
|
||||
node: torch.fx.Node,
|
||||
mutated_args: dict[int, torch.fx.Node | str],
|
||||
args: tuple[torch.fx.Node | str, ...] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
De-functionalize a node by replacing it with a call to the original.
|
||||
It also replaces the getitem users with the mutated arguments.
|
||||
See replace_users_with_mutated_args and insert_defunctionalized.
|
||||
"""
|
||||
self.replace_users_with_mutated_args(node, mutated_args)
|
||||
self.insert_defunctionalized(graph, node, args=args)
|
||||
self._remove(node)
|
||||
|
||||
def replace_users_with_mutated_args(
|
||||
self, node: torch.fx.Node, mutated_args: dict[int, torch.fx.Node | str]
|
||||
) -> None:
|
||||
"""
|
||||
Replace mutated getitem users of the auto-functionalized node with the
|
||||
mutated arguments.
|
||||
:param node: The auto-functionalized node
|
||||
:param mutated_args: The mutated arguments, indexed by getitem index.
|
||||
If the value of an arg is a string, `node.kwargs[arg]` is used.
|
||||
"""
|
||||
for idx, user in self.getitem_users(node).items():
|
||||
# Some functionalized nodes may return both a result at getitem[0]
|
||||
# as well as mutated args at getitem[1:...]
|
||||
if idx == 0:
|
||||
assert idx not in mutated_args, (
|
||||
f"result at getitem[0] should not be in mutated_args for {node}"
|
||||
)
|
||||
continue
|
||||
arg = mutated_args[idx]
|
||||
arg = node.kwargs[arg] if isinstance(arg, str) else arg
|
||||
user.replace_all_uses_with(arg)
|
||||
self._remove(user)
|
||||
|
||||
def getitem_users(self, node: torch.fx.Node) -> dict[int, torch.fx.Node]:
|
||||
"""
|
||||
Returns the operator.getitem users of the auto-functionalized node,
|
||||
indexed by the index they are getting.
|
||||
"""
|
||||
users = {}
|
||||
for user in node.users:
|
||||
if is_func(user, operator.getitem):
|
||||
idx = user.args[1]
|
||||
users[idx] = user
|
||||
return users
|
||||
|
||||
def insert_defunctionalized(
|
||||
self,
|
||||
graph: torch.fx.Graph,
|
||||
node: torch.fx.Node,
|
||||
args: tuple[torch.fx.Node | str, ...] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Insert a new defunctionalized node into the graph before node.
|
||||
If one of the kwargs is 'out', provide args directly,
|
||||
as node.kwargs cannot be used.
|
||||
See https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351
|
||||
|
||||
:param graph: Graph to insert the defunctionalized node into
|
||||
:param node: The auto-functionalized node to defunctionalize
|
||||
:param args: If we cannot use kwargs, specify args directly.
|
||||
If an arg is a string, `node.kwargs[arg]` is used.
|
||||
""" # noqa: E501
|
||||
assert is_func(node, auto_functionalized), (
|
||||
f"node must be auto-functionalized, is {node} instead"
|
||||
)
|
||||
|
||||
# Create a new call to the original function
|
||||
with graph.inserting_before(node):
|
||||
function = node.args[0]
|
||||
if args is None:
|
||||
fn_node = graph.call_function(function, kwargs=node.kwargs)
|
||||
else:
|
||||
# Args passed as strings refer to items in node.kwargs
|
||||
args = tuple(
|
||||
node.kwargs[arg] if isinstance(arg, str) else arg for arg in args
|
||||
)
|
||||
fn_node = graph.call_function(function, args=args)
|
||||
|
||||
# If the function returns a value as well as mutating args inplace,
|
||||
# the functionalized node will have a getitem[0] user that holds this value
|
||||
# Replace getitem[0] user of the auto-functionalized node
|
||||
# with the new defunctionalized node directly if it exists
|
||||
users = self.getitem_users(node)
|
||||
if 0 in users:
|
||||
user = users[0]
|
||||
user.replace_all_uses_with(fn_node)
|
||||
self._remove(user)
|
||||
130
vllm/compilation/passes/utility/noop_elimination.py
Normal file
@@ -0,0 +1,130 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch.fx
|
||||
from torch import SymInt
|
||||
from torch.fx.experimental.symbolic_shapes import statically_known_true
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from ..fx_utils import is_func
|
||||
from ..vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class NoOpEliminationPass(VllmInductorPass):
|
||||
"""
|
||||
This is an inductor pass that removes redundant reshape/slice operations.
|
||||
It is required for RMSNorm-quant fusion to work properly.
|
||||
That's because apply_fp8_linear adds a reshape, which is redundant
|
||||
in the 2D case. Additionally, torch's internal no-op elimination pass does
not handle certain slice variants.
|
||||
|
||||
Cases handled:
|
||||
1. A chain of reshapes is equivalent to the last reshape called on the
|
||||
base tensor (input of the first reshape).
|
||||
2. A reshape that produces the shape of the input is redundant
|
||||
3. A slice that produces the shape of the input is redundant
|
||||
|
||||
Example graph 1:
|
||||
mul_1: "f16[s0, 4096]" = ...
|
||||
view_1: "f16[s0, 128, 32]" = torch.reshape(mul_1, [-1, 128, 32])
|
||||
view_2: "f16[s0, 4096]" = torch.reshape(view_2, [-1, 4096])
|
||||
view_3: "f16[s0, 128, 32]" = torch.reshape(view_3, [-1, 128, 32])
|
||||
|
||||
Can be replaced with:
|
||||
mul_1: "f16[s0, 4096]" = ...
|
||||
view_3: "f16[s0, 128, 32]" = ...
|
||||
|
||||
Example graph 2:
|
||||
getitem_1: "f16[s0, 4096]" = ...
|
||||
view_1: "f16[s0, 4096]" = torch.reshape(getitem_1, [-1, 4096])
|
||||
at = auto_functionalized(static_scaled_fp8_quant, input = view_1, ...)
|
||||
out: "f8e4m3fn[s0, 4096]" = at[1]
|
||||
|
||||
Can be replaced with:
|
||||
getitem_1: "f16[s0, 4096]" = ...
|
||||
at = auto_functionalized(static_scaled_fp8_quant, input = getitem_1, ...)
|
||||
out: "f8e4m3fn[s0, 4096]" = at[1]
|
||||
|
||||
Example graph 3:
|
||||
arg0: "s0" = SymInt(s0)
|
||||
scaled_mm: "f16[s0, 4096]" = ...
|
||||
slice_1: "f16[s0, 4096]" = torch.slice(scaled_mm, -1, 0, arg0)
|
||||
at = auto_functionalized(fused_add_rms_norm, input = slice_1, ...)
|
||||
out: "f16[s0, 4096]" = torch.slice_scatter(scaled_mm, at[1], 0, 0, arg0)
|
||||
|
||||
Can be replaced with:
|
||||
arg0: "s0" = SymInt(s0)
|
||||
scaled_mm: "f16[s0, 4096]" = ...
|
||||
at = auto_functionalized(fused_add_rms_norm, input = scaled_mm, ...)
|
||||
out: "f16[s0, 4096]" = at[1]
|
||||
"""
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: torch.fx.Graph) -> None:
|
||||
count = 0
|
||||
# Remove no-op reshapes/views:
|
||||
for node in graph.nodes:
|
||||
if is_func(node, torch.ops.aten.reshape.default):
|
||||
# Case 1: rewrite reshape chains to reshapes on the base tensor
|
||||
input = node.args[0]
|
||||
# If the input is a reshape, rebind to that node
|
||||
if is_func(input, torch.ops.aten.reshape.default):
|
||||
# The new input is guaranteed not to be a reshape,
|
||||
# because we process nodes in order
|
||||
node.update_arg(0, input.args[0])
|
||||
if len(input.users) == 0:
|
||||
graph.erase_node(input)
|
||||
count += 1
|
||||
|
||||
# remove reshape/slice if it produces the original shape
|
||||
if is_func(node, torch.ops.aten.reshape.default) or is_func(
|
||||
node, torch.ops.aten.slice.Tensor
|
||||
):
|
||||
input = node.args[0]
|
||||
input_shape = input.meta["val"].shape
|
||||
output_shape = node.meta["val"].shape
|
||||
if self.all_dims_equivalent(input_shape, output_shape):
|
||||
node.replace_all_uses_with(input)
|
||||
graph.erase_node(node)
|
||||
count += 1
|
||||
elif is_func(node, torch.ops.aten.slice_scatter.default):
|
||||
base, view, dim_index, start, end = node.args[:5]
|
||||
base_shape = base.meta["val"].shape
|
||||
view_shape = view.meta["val"].shape
|
||||
|
||||
if self.all_dims_equivalent(base_shape, view_shape):
|
||||
node.replace_all_uses_with(view)
|
||||
graph.erase_node(node)
|
||||
count += 1
|
||||
|
||||
logger.debug("Removed %s no-op reshapes and slices", count)
|
||||
|
||||
# ---------------------- Shape comparison helpers ----------------------
|
||||
def dims_equivalent(self, dim: int | SymInt, i_dim: int | SymInt) -> bool:
|
||||
"""
|
||||
This function checks if two dimensions are equivalent.
|
||||
:param dim: The dimension arg to reshape/slice
|
||||
:param i_dim: The corresponding dimension in the input tensor
|
||||
:return: Are the dimensions equivalent?
|
||||
|
||||
There are two cases in which the dimensions are equivalent:
|
||||
1. The dimensions are equal (both integers)
|
||||
2. The dimensions both correspond to the same SymInt
|
||||
"""
|
||||
# statically_known_true covers both cases (equal ints or the same SymInt)
return statically_known_true(dim == i_dim)  # type: ignore[no-any-return]
|
||||
|
||||
def all_dims_equivalent(
|
||||
self, dims: Iterable[int | SymInt], i_dims: Iterable[int | SymInt]
|
||||
) -> bool:
|
||||
dims_ = list(dims)
|
||||
i_dims_ = list(i_dims)
|
||||
if len(dims_) != len(i_dims_):
|
||||
# Different ranks can't be equivalent
|
||||
return False
|
||||
return all(self.dims_equivalent(s, i_s) for s, i_s in zip(dims_, i_dims_))
|
||||
21
vllm/compilation/passes/utility/post_cleanup.py
Normal file
@@ -0,0 +1,21 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from torch import fx
|
||||
|
||||
from ..vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
|
||||
class PostCleanupPass(VllmInductorPass):
|
||||
"""
|
||||
This pass performs cleanup after custom passes.
|
||||
It topologically sorts the graph and removes unused nodes.
|
||||
This is needed because the pattern matcher does not guarantee producing
|
||||
a topologically sorted graph, and there may be unused nodes left around.
|
||||
"""
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph) -> None:
|
||||
from torch._inductor.pattern_matcher import stable_topological_sort
|
||||
|
||||
stable_topological_sort(graph)
|
||||
graph.eliminate_dead_code()
|
||||
138
vllm/compilation/passes/utility/scatter_split_replace.py
Normal file
@@ -0,0 +1,138 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Replace ``slice_scatter`` and ``split_with_sizes`` nodes with a single
|
||||
assignment if there are no users for the inplace tensor written to by
|
||||
the slice_scatter call.
|
||||
|
||||
The inplace rotary_embedding custom op takes in mutable query and key inputs
|
||||
that are split+getitem outputs of a single qkv tensor.
|
||||
When functionalized, we fetch the rotated query and key from the functionalized op
|
||||
using `getitem` calls. However, we also write to the qkv tensor inplace using a
|
||||
`slice_scatter`, then split the inplace tensor to get the output tensors again.
|
||||
Instead, if the inplace tensor has no subsequent users, we can just replace the
|
||||
`slice_scatter` and `split_with_sizes` nodes with the `getitem` calls.
|
||||
|
||||
This is already done in fix_functionalization::FixFunctionalizationPass, but
|
||||
writing a custom pass for it before defunctionalization allows matching against the
|
||||
qkv split+rotary_embedding subpattern as part of e.g. the RoPE+KVCache fusion pass.
|
||||
"""
|
||||
|
||||
import operator
|
||||
|
||||
import torch
|
||||
from torch import fx
|
||||
from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from ..fx_utils import is_func
|
||||
from ..vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class ScatterSplitReplacementPass(VllmInductorPass):
|
||||
"""Replace getitem+slice_scatter+split nodes with a single getitem when
|
||||
the inplace subtensor written to by the slice_scatter has no other users.
|
||||
|
||||
Here's an example graph with q_size = 512, kv_size = 64:
|
||||
split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(qkv, (512, 64, 64), -1)
|
||||
at = auto_functionalized(torch.ops._C.rotary_embedding.default(positions, q, k))
|
||||
q = operator.getitem(at, 1)
|
||||
k = operator.getitem(at, 2)
|
||||
torch.ops.aten.slice_scatter.default(qkv, q, [0, 512], -1)
|
||||
torch.ops.aten.slice_scatter.default(qkv, k, [512, 512 + 64], -1)
|
||||
split_with_sizes_2 = torch.ops.aten.split_with_sizes.default(qkv, (512, 64, 64), -1)
|
||||
q = operator.getitem(split_with_sizes_2, 0)
|
||||
k = operator.getitem(split_with_sizes_2, 1)
|
||||
v = operator.getitem(split_with_sizes_2, 2)
|
||||
|
||||
After this pass, this sequence of nodes is replaced with:
|
||||
split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(qkv, (512, 64, 64), -1)
|
||||
at = auto_functionalized(torch.ops._C.rotary_embedding.default(positions, q, k))
|
||||
q = operator.getitem(at, 1)
|
||||
k = operator.getitem(at, 2)
|
||||
v = operator.getitem(split_with_sizes_1, 2)
|
||||
"""
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph) -> None:
|
||||
count = 0
|
||||
|
||||
target_ops = [torch.ops._C.rotary_embedding.default]
|
||||
if hasattr(torch.ops.vllm, "rocm_aiter_triton_rotary_embedding"):
|
||||
target_ops.append(torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default)
|
||||
|
||||
for node in graph.nodes:
|
||||
if not is_func(node, auto_functionalized):
|
||||
continue
|
||||
|
||||
kwargs = node.kwargs
|
||||
at_target = node.args[0]
|
||||
|
||||
if at_target in target_ops:
|
||||
query = kwargs["query"]
|
||||
key = kwargs["key"]
|
||||
getitem_nodes = {}
|
||||
for user in node.users:
|
||||
if is_func(user, operator.getitem):
|
||||
getitem_nodes[user.args[1]] = user
|
||||
|
||||
if (
|
||||
is_func(query, operator.getitem)
|
||||
and is_func(key, operator.getitem)
|
||||
and query.args[0] == key.args[0]
|
||||
and is_func(query.args[0], torch.ops.aten.split_with_sizes.default)
|
||||
and all(
|
||||
is_func(user, torch.ops.aten.slice_scatter.default)
|
||||
for getitem_node in getitem_nodes.values()
|
||||
for user in getitem_node.users
|
||||
)
|
||||
):
|
||||
# Pattern where query and key are slices of a qkv tensor.
|
||||
# While functionalized, results at [1] and [2] are scattered
|
||||
# back into qkv, then split again to get query and key.
|
||||
# If the inplace tensor has no other users, we can replace
|
||||
# the slice_scatter+split nodes with the original results.
|
||||
for user in getitem_nodes[1].users:
|
||||
slice_scatter_1_node = user
|
||||
if not is_func(
|
||||
slice_scatter_1_node, torch.ops.aten.slice_scatter.default
|
||||
):
|
||||
continue
|
||||
|
||||
for user in getitem_nodes[2].users:
|
||||
slice_scatter_2_node = user
|
||||
if not is_func(
|
||||
slice_scatter_2_node, torch.ops.aten.slice_scatter.default
|
||||
):
|
||||
continue
|
||||
|
||||
for user in slice_scatter_2_node.users:
|
||||
split_node = user
|
||||
if not is_func(split_node, torch.ops.aten.split_with_sizes.default):
|
||||
continue
|
||||
|
||||
split_getitem_users = {}
|
||||
for user in split_node.users:
|
||||
if is_func(user, operator.getitem):
|
||||
split_getitem_users[user.args[1]] = user
|
||||
|
||||
# Replace query node
|
||||
split_getitem_users[0].replace_all_uses_with(getitem_nodes[1])
|
||||
graph.erase_node(split_getitem_users[0])
|
||||
# Replace key node
|
||||
split_getitem_users[1].replace_all_uses_with(getitem_nodes[2])
|
||||
graph.erase_node(split_getitem_users[1])
|
||||
# Redirect value node to original qkv tensor
|
||||
split_getitem_users[2].replace_input_with(split_node, query.args[0])
|
||||
|
||||
# Erase unused nodes
|
||||
graph.erase_node(split_node)
|
||||
graph.erase_node(slice_scatter_2_node)
|
||||
graph.erase_node(slice_scatter_1_node)
|
||||
|
||||
count += 1
|
||||
|
||||
logger.debug("Eliminated %d slice_scatter+split nodes", count)
|
||||
70
vllm/compilation/passes/utility/split_coalescing.py
Normal file
@@ -0,0 +1,70 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Coalesce duplicate ``split_with_sizes`` nodes that operate on the same
|
||||
input tensor with the same split sizes.
|
||||
|
||||
On certain hardware/dtype combinations (e.g. B200 + FP8) the Inductor
|
||||
graph may contain multiple ``split_with_sizes`` calls on the same tensor
|
||||
that CSE fails to merge. This pass detects and replaces the duplicates
|
||||
so that downstream pattern-matching passes (e.g. QK-Norm+RoPE fusion)
|
||||
see a single split node with all users attached.
|
||||
|
||||
See also:
|
||||
- vLLM #33295 (original issue)
|
||||
- PyTorch #174472 (upstream CSE gap)
|
||||
"""
|
||||
|
||||
import operator
|
||||
|
||||
import torch
|
||||
from torch import fx
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from ..fx_utils import is_func
|
||||
from ..vllm_inductor_pass import VllmInductorPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class SplitCoalescingPass(VllmInductorPass):
|
||||
"""Replace duplicate ``split_with_sizes`` nodes with a single canonical
|
||||
node when they share the same input tensor and split sizes."""
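
# Illustrative before/after (simplified):
#   s1 = aten.split_with_sizes(qkv, [512, 64, 64], -1); q = s1[0]
#   s2 = aten.split_with_sizes(qkv, [512, 64, 64], -1); k = s2[1]
# becomes
#   s1 = aten.split_with_sizes(qkv, [512, 64, 64], -1); q = s1[0]; k = s1[1]
# (all getitem users of s2 are re-pointed at s1 and s2 is erased).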
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph) -> None:
|
||||
count = 0
|
||||
|
||||
# Map from input tensor node -> list of split nodes seen so far.
|
||||
split_nodes: dict[fx.Node, list[fx.Node]] = {}
|
||||
|
||||
for node in graph.nodes:
|
||||
if not is_func(node, torch.ops.aten.split_with_sizes.default):
|
||||
continue
|
||||
if not all(is_func(user, operator.getitem) for user in node.users):
|
||||
continue
|
||||
|
||||
arg_node, split_sizes = node.args[:2]
|
||||
|
||||
if arg_node not in split_nodes:
|
||||
split_nodes[arg_node] = [node]
|
||||
continue
|
||||
|
||||
# Find existing node with same split_sizes
|
||||
canonical = next(
|
||||
(
|
||||
n
|
||||
for n in split_nodes[arg_node]
|
||||
if list(n.args[1]) == list(split_sizes)
|
||||
),
|
||||
None,
|
||||
)
|
||||
if canonical is not None:
|
||||
node.replace_all_uses_with(canonical)
|
||||
graph.erase_node(node)
|
||||
count += 1
|
||||
else:
|
||||
split_nodes[arg_node].append(node)
|
||||
|
||||
logger.debug("Coalesced %d duplicate split_with_sizes nodes", count)
|
||||
180
vllm/compilation/passes/vllm_inductor_pass.py
Normal file
@@ -0,0 +1,180 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
import operator
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
from torch._dynamo.utils import lazy_format_graph_code
|
||||
from torch._inductor.pattern_matcher import PatternMatcherPass, PatternPrettyPrinter
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .inductor_pass import InductorPass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InductorCompilationConfig:
|
||||
splitting_ops: list[str] | None = None
|
||||
use_inductor_graph_partition: bool = False
|
||||
|
||||
|
||||
class VllmInductorPass(InductorPass):
|
||||
"""
|
||||
An inductor pass with access to vLLM PassConfig.
|
||||
It provides timing, logging, and dumping utilities.
|
||||
"""
|
||||
|
||||
dump_prefix: ClassVar[int | None] = None
|
||||
"""Keep track of pass index for debug dump ordering."""
|
||||
|
||||
def __init__(self, config: VllmConfig):
|
||||
# Get only the necessary CompilationConfig fields for the inductor pass, since
# the full `CompilationConfig` contains a pointer to the model, which is unsafe.
|
||||
self.compilation_config = InductorCompilationConfig(
|
||||
splitting_ops=config.compilation_config.splitting_ops,
|
||||
use_inductor_graph_partition=config.compilation_config.use_inductor_graph_partition,
|
||||
)
|
||||
self.pass_config = config.compilation_config.pass_config
|
||||
self.model_dtype = config.model_config.dtype if config.model_config else None
|
||||
self.device: str | None = (
|
||||
config.device_config.device if config.device_config else None
|
||||
)
|
||||
self.pass_name = self.__class__.__name__
|
||||
|
||||
@staticmethod
|
||||
def time_and_log(
|
||||
call_fn: Callable[["VllmInductorPass", torch.fx.Graph], None],
|
||||
) -> Callable[["VllmInductorPass", torch.fx.Graph], None]:
|
||||
@functools.wraps(call_fn)
|
||||
def wrapped(self: VllmInductorPass, graph: torch.fx.Graph) -> None:
|
||||
self.begin()
|
||||
self.dump_graph(graph, "before")
|
||||
call_fn(self, graph)
|
||||
self.dump_graph(graph, "after")
|
||||
self.end_and_log()
|
||||
|
||||
return wrapped
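
# Example (from this module's usage): decorate a pass's __call__ so the graph
# is dumped before/after and the duration is logged:
#   @VllmInductorPass.time_and_log
#   def __call__(self, graph: torch.fx.Graph) -> None: ...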
|
||||
|
||||
def dump_graph(self, graph: torch.fx.Graph, stage: str) -> None:
|
||||
i = VllmInductorPass.dump_prefix
|
||||
i_str = "" if i is None else f".{i}"
|
||||
lazy_format_graph_code(
|
||||
f"post_grad{i_str}.{self.pass_name}.{stage}", graph.owning_module
|
||||
)
|
||||
|
||||
def begin(self) -> None:
|
||||
self._start_time = time.perf_counter_ns()
|
||||
|
||||
def end_and_log(self) -> None:
|
||||
self._end_time = time.perf_counter_ns()
|
||||
duration_ms = float(self._end_time - self._start_time) / 1.0e6
|
||||
logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms)
|
||||
|
||||
|
||||
class VllmPatternMatcherPass(VllmInductorPass):
|
||||
"""
|
||||
A VllmInductorPass that uses the Inductor pattern matcher.
|
||||
Its main use is providing the dump_patterns utility that dumps the
|
||||
Inductor pattern matcher patterns into a file, which greatly aids debugging.
|
||||
|
||||
TODO(luka) move more utilities to this pass.
|
||||
"""
|
||||
|
||||
matched_count: int = 0
|
||||
"""The number of matched patterns in the pass."""
|
||||
|
||||
_OP_OVERLOAD_PATTERN: ClassVar[re.Pattern] = re.compile(
|
||||
r"<OpOverload\(op='([^']*)', overload='([^']*)'\)>"
|
||||
)
|
||||
|
||||
def _replace_op_overloads(self, string: str) -> str:
|
||||
"""Replace <OpOverload(..., ...)> with nicer formulations"""
|
||||
return str(
|
||||
self._OP_OVERLOAD_PATTERN.sub(
|
||||
lambda m: f"torch.ops.{m.group(1)}.{m.group(2)}",
|
||||
string,
|
||||
)
|
||||
)
|
||||
|
||||
def dump_patterns(self, config: VllmConfig, pm_pass: PatternMatcherPass) -> None:
|
||||
"""
|
||||
If debug dumping is enabled, dump the Inductor pattern-matcher patterns
|
||||
into the debug_dump_path folder next to the dumped fx graphs.
|
||||
|
||||
This method does its best to print something that looks like Python code
|
||||
for easier debugging and potentially navigation. If any errors appear in
|
||||
the output, please add to this method.
|
||||
|
||||
TODO(luka): use pattern object to manually produce pattern graph
|
||||
"""
|
||||
debug_dump_path = config.compile_debug_dump_path()
|
||||
if not debug_dump_path:
|
||||
return
|
||||
|
||||
debug_dump_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
from vllm.utils.system_utils import unique_filepath
|
||||
|
||||
file_path = unique_filepath(
|
||||
lambda i: debug_dump_path / f"patterns.{self.pass_name}.{i}.py"
|
||||
)
|
||||
|
||||
with file_path.open("w") as f:
|
||||
print(
|
||||
f"# This file was produced by VllmPatternMatcherPass."
|
||||
f"dump_patterns for {self.pass_name}.\n"
|
||||
f"# It does its best to produce valid-Python-looking code but"
|
||||
f" please add to dump_patterns if there are any errors.\n\n"
|
||||
f"from torch._higher_order_ops.auto_functionalize import "
|
||||
f"auto_functionalized as auto_functionalized\n"
|
||||
f"from torch._inductor.pattern_matcher import *\n"
|
||||
f"vllm = torch.ops.vllm",
|
||||
file=f,
|
||||
)
|
||||
|
||||
for node, patterns in pm_pass.patterns.items():
|
||||
# fix the operator.getitem repr
|
||||
if node[1] == operator.getitem:
|
||||
node_repr = f"({repr(node[0])}, operator.getitem)"
|
||||
else:
|
||||
node_repr = repr(node)
|
||||
|
||||
node_repr = self._replace_op_overloads(node_repr)
|
||||
|
||||
print(f"\n\n# Patterns for op: {node_repr}", file=f)
|
||||
for i, pattern in enumerate(patterns):
|
||||
# reserve auto_functionalized ahead of time
|
||||
pp = PatternPrettyPrinter()
|
||||
pp.namespace.create_name("auto_functionalized", None)
|
||||
|
||||
# Assemble pattern
|
||||
out_node = pp.pretty_print(pattern.pattern)
|
||||
pattern_repr = "\n".join(
|
||||
[f"def pattern_{i}():"]
|
||||
+ [
|
||||
f"{pp.memoized_objs_names[key]} = "
|
||||
f"{pp.memoized_objs_pp[key]}"
|
||||
for key in pp.memoized_objs_names
|
||||
]
|
||||
+ [f"return {out_node}"]
|
||||
).replace("\n", "\n ")
|
||||
|
||||
pattern_repr = self._replace_op_overloads(pattern_repr)
|
||||
print(f"{pattern_repr}\n", file=f)
|
||||
|
||||
|
||||
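# Rough shape of a dumped patterns file (illustrative sketch only; the pass
# name "RMSNormFusionPass" and the op shown are hypothetical placeholders):
#
#   # This file was produced by VllmPatternMatcherPass.dump_patterns for RMSNormFusionPass.
#   from torch._higher_order_ops.auto_functionalize import auto_functionalized as auto_functionalized
#   from torch._inductor.pattern_matcher import *
#   vllm = torch.ops.vllm
#
#   # Patterns for op: torch.ops.vllm.rms_norm.default
#   def pattern_0():
#       ...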
class PrinterInductorPass(VllmInductorPass):
|
||||
def __init__(self, name: str, config: VllmConfig) -> None:
|
||||
super().__init__(config)
|
||||
self.name = name
|
||||
|
||||
def __call__(self, graph: torch.fx.Graph) -> None:
|
||||
self.dump_graph(graph, self.name)
|
||||
343
vllm/compilation/piecewise_backend.py
Normal file
@@ -0,0 +1,343 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import dataclasses
|
||||
import io
|
||||
import json
|
||||
import pickle
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from pickle import Pickler
|
||||
from typing import Any
|
||||
|
||||
import torch._functorch.config
|
||||
import torch.fx as fx
|
||||
from torch._inductor.runtime.triton_heuristics import CachingAutotuner
|
||||
from torch._logging._internal import trace_structured
|
||||
|
||||
from vllm.compilation.backends import VllmBackend
|
||||
from vllm.compilation.monitor import end_monitoring_torch_compile
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.utils import Range
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class RangeEntry:
|
||||
compile_range: Range
|
||||
compiled: bool = False
|
||||
runnable: Callable[..., Any] = None # type: ignore
|
||||
|
||||
|
||||
class PiecewiseBackend:
|
||||
def __init__(
|
||||
self,
|
||||
graph: fx.GraphModule | None,
|
||||
vllm_config: VllmConfig,
|
||||
piecewise_compile_index: int,
|
||||
total_piecewise_compiles: int,
|
||||
sym_shape_indices: list[int],
|
||||
vllm_backend: VllmBackend,
|
||||
returns_tuple: bool,
|
||||
compiled_runnables: dict[str, Callable[..., Any]] | None = None,
|
||||
submod_name: str = "",
|
||||
):
|
||||
"""
|
||||
The backend for piecewise compilation.
|
||||
It mainly handles the compilation of static shapes and
|
||||
dispatching based on runtime shape.
|
||||
|
||||
We will compile `self.graph` once for the general shape,
|
||||
and then compile for different shapes specified in
|
||||
`compilation_config.compile_sizes`.
|
||||
|
||||
This class supports two mutually exclusive modes:
|
||||
1. Compilation (graph is set, compiled_runnables is None):
|
||||
Used during initial compilation when we have the FX graph
|
||||
and need to compile it for each shape range.
|
||||
2. Precompilation (graph is None, compiled_runnables is set):
|
||||
Used when loading from cache/AOT artifacts where we already
|
||||
have pre-compiled callables and don't need the original graph.
|
||||
|
||||
Exactly one of graph or compiled_runnables must be provided.
|
||||
"""
|
||||
assert bool(graph is not None) ^ bool(compiled_runnables is not None), (
|
||||
"exactly one of graph and compiled_runnables should be set."
|
||||
)
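# Illustrative construction of the two modes (hypothetical call sites;
# `split_gm` and `loaded_fns` are placeholder names, not real variables):
#   PiecewiseBackend(graph=split_gm, ..., compiled_runnables=None)     # compilation mode
#   PiecewiseBackend(graph=None, ..., compiled_runnables=loaded_fns)   # precompilation mode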
|
||||
|
||||
self.graph = graph
|
||||
self.vllm_config = vllm_config
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
self.piecewise_compile_index = piecewise_compile_index
|
||||
self.total_piecewise_compiles = total_piecewise_compiles
|
||||
self.vllm_backend = vllm_backend
|
||||
self.compiled_runnables = compiled_runnables
|
||||
self.submod_name = submod_name
|
||||
|
||||
self.is_first_graph = piecewise_compile_index == 0
|
||||
self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1
|
||||
|
||||
self.is_full_graph = total_piecewise_compiles == 1
|
||||
self.is_encoder_compilation = vllm_backend.is_encoder
|
||||
|
||||
self.compile_ranges = self.compilation_config.get_compile_ranges()
|
||||
if self.is_encoder_compilation:
|
||||
# For encoder compilation we use the max int32 value
|
||||
# to set the upper bound of the compile ranges
|
||||
max_int32 = 2**31 - 1
|
||||
last_compile_range = self.compile_ranges[-1]
|
||||
assert (
|
||||
last_compile_range.end
|
||||
== vllm_config.scheduler_config.max_num_batched_tokens
|
||||
)
|
||||
self.compile_ranges[-1] = Range(
|
||||
start=last_compile_range.start, end=max_int32
|
||||
)
|
||||
|
||||
log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}"
|
||||
logger.debug_once(log_string)
|
||||
|
||||
self.compile_sizes = self.compilation_config.compile_sizes
|
||||
log_string = f"PiecewiseBackend: compile_sizes: {self.compile_sizes}"
|
||||
logger.debug_once(log_string)
|
||||
|
||||
self.sym_shape_indices = sym_shape_indices
|
||||
self.returns_tuple = returns_tuple
|
||||
|
||||
# the entries for the ranges that we need to compile and dispatch to
self.range_entries: dict[Range, RangeEntry] = {}

# to_be_compiled_ranges tracks the remaining ranges to compile;
# it is updated during the compilation process, so we copy compile_ranges here
|
||||
self.to_be_compiled_ranges: set[Range] = set(self.compile_ranges)
|
||||
|
||||
# We only keep compilation management inside this class directly.
|
||||
if self.compile_sizes is not None:
|
||||
for size in self.compile_sizes:
|
||||
if isinstance(size, str):
|
||||
assert size == "cudagraph_capture_sizes"
|
||||
raise NotImplementedError(
|
||||
"cudagraph_capture_sizes not supported in compile_sizes."
|
||||
"This should be handled in `post_init_cudagraph_sizes`."
|
||||
)
|
||||
else:
|
||||
assert isinstance(size, int)
|
||||
range = Range(start=size, end=size)
|
||||
if range not in self.compile_ranges:
|
||||
self.range_entries[range] = RangeEntry(
|
||||
compile_range=range,
|
||||
)
|
||||
self.to_be_compiled_ranges.add(range)
|
||||
|
||||
for range in self.compile_ranges:
|
||||
self.range_entries[range] = RangeEntry(
|
||||
compile_range=range,
|
||||
)
|
||||
|
||||
# Track whether we've logged the graph for this subgraph (only log once)
|
||||
self._graph_logged = False
|
||||
|
||||
# get the on_compilation_complete callback from context...
|
||||
# PiecewiseBackend is created during the first call,
|
||||
# which is when the context is set (see compilation/decorators.py)
|
||||
from vllm.compilation.backends import _on_compilation_complete_callback
|
||||
|
||||
self.on_compilation_complete = _on_compilation_complete_callback.get()
|
||||
|
||||
def get_compiled_graph_wrapper(
|
||||
self, compiled_graph: Callable[..., Any]
|
||||
) -> Callable[..., Any]:
|
||||
def compiled_graph_wrapper(*args: Any) -> Any:
|
||||
graph_output = compiled_graph(*args)
|
||||
# unpack the tuple if needed
|
||||
# TODO(rzou): the implication is that we're not
|
||||
# reading the python bytecode correctly in vLLM?
|
||||
if self.returns_tuple or not isinstance(graph_output, (tuple, list)):
|
||||
return graph_output
|
||||
else:
|
||||
return graph_output[0]
|
||||
|
||||
return compiled_graph_wrapper
|
||||
|
||||
def check_for_ending_compilation(self) -> None:
|
||||
if self.is_last_graph and not self.to_be_compiled_ranges:
|
||||
# no remaining ranges to compile
|
||||
# save the hash of the inductor graph for the next run
|
||||
time_before_saving = time.perf_counter()
|
||||
self.vllm_backend.compiler_manager.save_to_file()
|
||||
elapsed = time.perf_counter() - time_before_saving
|
||||
if elapsed > 1:
|
||||
logger.info_once(
|
||||
"Saved compiler manager cache in %.2f seconds.",
|
||||
elapsed,
|
||||
scope="local",
|
||||
)
|
||||
|
||||
end_monitoring_torch_compile(self.vllm_config)
|
||||
# Call the completion callback (e.g., to save AOT compiled function)
|
||||
if self.on_compilation_complete is not None:
|
||||
self.on_compilation_complete()
|
||||
|
||||
def to_bytes(self) -> dict[str, bytes]:
|
||||
class StandaloneCompiledArtifactsPickler(Pickler):
|
||||
def reducer_override(self, obj: object) -> Any:
|
||||
if isinstance(obj, CachingAutotuner):
|
||||
obj.prepare_for_pickle()
|
||||
return pickle.loads, (
|
||||
pickle.dumps(
|
||||
obj,
|
||||
),
|
||||
)
|
||||
return NotImplemented
|
||||
|
||||
def serialize(fn: Callable[..., Any]) -> bytes:
|
||||
assert hasattr(fn, "serialize"), "fn must have serialize method"
|
||||
with torch._functorch.config.patch("bundled_autograd_cache", True):
|
||||
entry = fn.serialize()
|
||||
|
||||
f = io.BytesIO()
|
||||
StandaloneCompiledArtifactsPickler(f).dump(entry)
|
||||
result = f.getvalue()
|
||||
return result
|
||||
|
||||
out = {}
|
||||
|
||||
for range_key, entry in self.range_entries.items():
|
||||
if not entry.compiled:
|
||||
logger.debug(
|
||||
"entry with range %s not compiled, so cannot get its bytes",
|
||||
range_key,
|
||||
)
|
||||
continue
|
||||
if hasattr(entry.runnable, "serialize"):
|
||||
out[str(range_key)] = serialize(entry.runnable)
|
||||
|
||||
return out
|
||||
|
||||
def _fakify_args(self, args: tuple[Any, ...]) -> list[Any]:
|
||||
# We need to pass fake example_inputs, otherwise torch.compile
# will fakify the example_inputs itself, potentially causing a non-dynamic
# dimension to be duck-shaped to other existing shapes that have hints
# matching their values.
# This is a problem because it can lead to unintended specializations:
# if the new, wrongly dynamic dim is specialized,
# it will force specializing the whole shape.
# torch.compile probably should not accept
# non-fake tensors as example inputs!
# See issue https://github.com/vllm-project/vllm/issues/27899
|
||||
fake_example_inputs = []
|
||||
assert self.graph is not None
|
||||
for node in self.graph.graph.nodes:
|
||||
# All placeholders come first
|
||||
if node.op == "placeholder":
|
||||
fake_example_inputs.append(node.meta["example_value"])
|
||||
else:
|
||||
break
|
||||
assert len(fake_example_inputs) == len(args)
|
||||
return fake_example_inputs
|
||||
|
||||
def _log_compile_start(self, compile_range: Range):
|
||||
"""Log compilation event for TORCH_TRACE/tlparse."""
|
||||
is_cudagraph_size = (
|
||||
self.compile_sizes is not None and compile_range.start in self.compile_sizes
|
||||
)
|
||||
subgraph_index = self.piecewise_compile_index
|
||||
submod_name = self.submod_name
|
||||
trace_structured(
|
||||
"artifact",
|
||||
metadata_fn=lambda: {
|
||||
"name": "vllm_piecewise_compile_start",
|
||||
"encoding": "json",
|
||||
},
|
||||
payload_fn=lambda: json.dumps(
|
||||
{
|
||||
"piecewise_index": subgraph_index,
|
||||
"submod_name": submod_name,
|
||||
"total_piecewise_compiles": self.total_piecewise_compiles,
|
||||
"compile_range_start": compile_range.start,
|
||||
"compile_range_end": compile_range.end,
|
||||
"is_single_size": compile_range.is_single_size(),
|
||||
"is_cudagraph_capture_size": is_cudagraph_size,
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
# Log the graph dump only once per subgraph (not once per size)
# to reduce log file size; the graph code is the same for all sizes.
|
||||
if not self._graph_logged:
|
||||
self._graph_logged = True
|
||||
assert self.graph is not None
|
||||
trace_structured(
|
||||
"graph_dump",
|
||||
metadata_fn=lambda: {
|
||||
"name": f"vllm_{submod_name}",
|
||||
},
|
||||
payload_fn=lambda: self.graph.print_readable(print_output=False),
|
||||
)
|
||||
|
||||
def _maybe_compile_for_range_entry(
|
||||
self, range_entry: RangeEntry, args: tuple[Any, ...]
|
||||
) -> Any:
|
||||
if not range_entry.compiled:
|
||||
if self.compiled_runnables is not None:
|
||||
range_entry.runnable = self.get_compiled_graph_wrapper(
|
||||
self.compiled_runnables[str(range_entry.compile_range)]
|
||||
)
|
||||
else:
|
||||
self._log_compile_start(range_entry.compile_range)
|
||||
|
||||
# args are real arguments
|
||||
# fakify for range, real args for concrete size.
|
||||
# For concrete size, we clear the shape env in
|
||||
# compiler_manager.compile() so no need to fakify.
|
||||
args_list = (
|
||||
self._fakify_args(args)
|
||||
if not range_entry.compile_range.is_single_size()
|
||||
else list(args)
|
||||
)
|
||||
|
||||
with (
|
||||
torch._functorch.config.patch("bundled_autograd_cache", True),
|
||||
):
|
||||
range_entry.runnable = self.vllm_backend.compiler_manager.compile(
|
||||
self.graph,
|
||||
args_list,
|
||||
self.vllm_backend.inductor_config,
|
||||
self.compilation_config,
|
||||
compile_range=range_entry.compile_range,
|
||||
graph_index=self.piecewise_compile_index,
|
||||
num_graphs=self.total_piecewise_compiles,
|
||||
)
|
||||
|
||||
range_entry.compiled = True
|
||||
self.to_be_compiled_ranges.remove(range_entry.compile_range)
|
||||
|
||||
self.check_for_ending_compilation()
|
||||
|
||||
def _find_range_for_shape(self, runtime_shape: int) -> RangeEntry | None:
|
||||
# First we try to find the range entry for the concrete compile size
|
||||
# If not found, we search for the range entry
|
||||
# that contains the runtime shape.
|
||||
if self.compile_sizes is None:
|
||||
return None
|
||||
|
||||
if runtime_shape in self.compile_sizes:
|
||||
return self.range_entries[Range(start=runtime_shape, end=runtime_shape)]
|
||||
else:
|
||||
for range in self.compile_ranges:
|
||||
if runtime_shape in range:
|
||||
return self.range_entries[range]
|
||||
return None
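# Dispatch example (illustrative values): if compile_sizes contains 8 and
# compile_ranges is [Range(start=1, end=8192)], a runtime shape of 8 resolves
# to the exact-size entry Range(8, 8), while a runtime shape of 100 falls back
# to the containing range entry Range(1, 8192).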
|
||||
|
||||
def __call__(self, *args: Any) -> Any:
|
||||
runtime_shape = args[self.sym_shape_indices[0]]
|
||||
range_entry = self._find_range_for_shape(runtime_shape)
|
||||
|
||||
assert range_entry is not None, (
|
||||
f"Shape: {runtime_shape} out of considered ranges: {self.compile_ranges}"
|
||||
)
|
||||
|
||||
self._maybe_compile_for_range_entry(range_entry, args)
|
||||
return range_entry.runnable(*args)
|
||||
321
vllm/compilation/wrapper.py
Normal file
@@ -0,0 +1,321 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
import sys
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Callable, Generator
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from types import CodeType
|
||||
from typing import Any, ParamSpec, TypeVar
|
||||
|
||||
import torch
|
||||
import torch._C._dynamo.guards
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config
|
||||
from vllm.config.compilation import DynamicShapesType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.nvtx_pytorch_hooks import layerwise_nvtx_marker_context
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
R = TypeVar("R")
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
def _noop_add_global_state_guard(
|
||||
self: torch._C._dynamo.guards.GuardManager, *args: Any, **kwargs: Any
|
||||
) -> None:
|
||||
"""No-op to skip the GLOBAL_STATE guard entirely"""
|
||||
pass
|
||||
|
||||
|
||||
def _noop_add_torch_function_mode_stack_guard(
|
||||
self: torch._C._dynamo.guards.GuardManager, *args: Any, **kwargs: Any
|
||||
) -> None:
|
||||
"""No-op to skip the TORCH_FUNCTION_MODE_STACK guard entirely"""
|
||||
pass
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _compilation_context() -> Generator[None, None, None]:
|
||||
"""Context manager for compilation settings and patches.
|
||||
|
||||
This manager:
|
||||
1. Sets higher dynamo cache limits for compilation (needed for
qwen2_5_vl; see test_qwen2_5_vl_evs_functionality).
Generally, a recompilation can happen whenever we use a new
backend instance in torch.compile.
2. Patches out add_global_state_guard to skip GLOBAL_STATE guards.
|
||||
3. Patches out add_torch_function_mode_stack_guard to skip
|
||||
TORCH_FUNCTION_MODE_STACK guards.
|
||||
4. Restores everything when compilation completes
|
||||
"""
|
||||
# Save original values
|
||||
original_global_state_guard = (
|
||||
torch._C._dynamo.guards.GuardManager.add_global_state_guard
|
||||
)
|
||||
original_torch_function_mode_stack_guard = (
|
||||
torch._C._dynamo.guards.GuardManager.add_torch_function_mode_stack_guard
|
||||
)
|
||||
original_cache_size = torch._dynamo.config.cache_size_limit
|
||||
original_accumulated_cache = torch._dynamo.config.accumulated_cache_size_limit
|
||||
|
||||
try:
|
||||
# Set higher cache limits for compilation
|
||||
torch._dynamo.config.cache_size_limit = 2048
|
||||
torch._dynamo.config.accumulated_cache_size_limit = 8192
|
||||
|
||||
# Patch guard manager
|
||||
torch._C._dynamo.guards.GuardManager.add_global_state_guard = (
|
||||
_noop_add_global_state_guard
|
||||
)
|
||||
torch._C._dynamo.guards.GuardManager.add_torch_function_mode_stack_guard = (
|
||||
_noop_add_torch_function_mode_stack_guard
|
||||
)
|
||||
yield
|
||||
finally:
|
||||
# Restore original values
|
||||
torch._C._dynamo.guards.GuardManager.add_global_state_guard = (
|
||||
original_global_state_guard
|
||||
)
|
||||
torch._C._dynamo.guards.GuardManager.add_torch_function_mode_stack_guard = (
|
||||
original_torch_function_mode_stack_guard
|
||||
)
|
||||
torch._dynamo.config.cache_size_limit = original_cache_size
|
||||
torch._dynamo.config.accumulated_cache_size_limit = original_accumulated_cache
|
||||
|
||||
|
||||
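# Minimal usage sketch (illustrative only; `compiled_callable` stands in for
# the torch.compile result created in the wrapper's __init__ below). The
# patched guard installation and the raised cache limits are active only
# while the call that triggers compilation runs:
#
#   with _compilation_context():
#       out = compiled_callable(*args)  # first call triggers compilation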
class TorchCompileWithNoGuardsWrapper:
|
||||
"""
|
||||
A wrapper class for torch.compile; it ensures that all guards are dropped
when CompilationMode is not CompilationMode.STOCK_TORCH_COMPILE.
When guards are dropped, the first time __call__ is invoked, a single
compilation is triggered. Dynamo should never trace again after that,
since we drop all guards.
|
||||
"""
|
||||
|
||||
def check_invariants_and_forward(self, *args: Any, **kwargs: Any) -> Any:
|
||||
assert hasattr(self, "_check_shape_invariants")
|
||||
self._check_shape_invariants(*args, **kwargs)
|
||||
|
||||
return self.forward(*args, **kwargs)
|
||||
|
||||
def _call_with_optional_nvtx_range(
|
||||
self, callable_fn: Callable[P, R], *args: P.args, **kwargs: P.kwargs
|
||||
) -> Any:
|
||||
if self.layerwise_nvtx_tracing_enabled:
|
||||
args_list = list(args)
|
||||
kwargs_dict = dict(kwargs)
|
||||
with layerwise_nvtx_marker_context(
|
||||
"Torch Compiled Module (input):{}".format(self.__class__.__name__),
|
||||
self,
|
||||
in_tensor=args_list,
|
||||
kwargs=kwargs_dict,
|
||||
) as ctx:
|
||||
ctx.result = callable_fn(*args, **kwargs)
|
||||
return ctx.result
|
||||
return callable_fn(*args, **kwargs)
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.compiled = False
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.vllm_config = vllm_config
|
||||
mode = vllm_config.compilation_config.mode
|
||||
self.layerwise_nvtx_tracing_enabled = (
|
||||
vllm_config.observability_config.enable_layerwise_nvtx_tracing
|
||||
)
|
||||
if mode is None:
|
||||
raise RuntimeError("Compilation mode cannot be NO_COMPILATION")
|
||||
|
||||
backend = vllm_config.compilation_config.init_backend(vllm_config)
|
||||
options = {}
|
||||
|
||||
if isinstance(backend, str) and backend == "inductor":
|
||||
options = vllm_config.compilation_config.inductor_compile_config
|
||||
|
||||
self.first_compile = True
|
||||
self.evaluate_guards = (
|
||||
vllm_config.compilation_config.dynamic_shapes_config.evaluate_guards
|
||||
)
|
||||
|
||||
ds_type = vllm_config.compilation_config.dynamic_shapes_config.type
|
||||
|
||||
if mode != CompilationMode.STOCK_TORCH_COMPILE:
|
||||
# Drop all the guards.
|
||||
if self.evaluate_guards:
|
||||
assert not envs.VLLM_USE_BYTECODE_HOOK, (
|
||||
"compilation_config.dynamic_shapes_config.evaluate_guards "
|
||||
"requires VLLM_USE_BYTECODE_HOOK=0. "
|
||||
)
|
||||
|
||||
options["guard_filter_fn"] = lambda x: [
|
||||
entry.guard_type == "SHAPE_ENV" for entry in x
|
||||
]
|
||||
else:
|
||||
options["guard_filter_fn"] = lambda x: [False for _ in x]
|
||||
|
||||
compiled_ptr: Any = self.forward
|
||||
# Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False
|
||||
|
||||
if ds_type == DynamicShapesType.UNBACKED:
|
||||
# The reason is that the bytecode hook calls torch._dynamo.eval_frame.
# remove_from_cache(self.original_code_object()) to force a new
# re-compilation, and if we used
# compiled_ptr = self.check_invariants_and_forward
# it would reset all cache entries.
|
||||
assert not envs.VLLM_USE_BYTECODE_HOOK, (
|
||||
"UNBACKED dynamic shapes requires VLLM_USE_BYTECODE_HOOK=0. "
|
||||
)
|
||||
assert not self.evaluate_guards, "UNBACKED dynamic shapes do not add guards"
|
||||
|
||||
compiled_ptr = self.check_invariants_and_forward
|
||||
|
||||
aot_context = nullcontext()
|
||||
if envs.VLLM_USE_AOT_COMPILE:
|
||||
if hasattr(torch._dynamo.config, "enable_aot_compile"):
|
||||
aot_context = torch._dynamo.config.patch(enable_aot_compile=True)
|
||||
else:
|
||||
msg = "torch._dynamo.config.enable_aot_compile is not "
|
||||
msg += "available. AOT compile is disabled and please "
|
||||
msg += "upgrade PyTorch version to use AOT compile."
|
||||
logger.warning(msg)
|
||||
|
||||
with aot_context:
|
||||
self._compiled_callable = torch.compile(
|
||||
compiled_ptr,
|
||||
fullgraph=True,
|
||||
dynamic=False,
|
||||
backend=backend,
|
||||
options=options,
|
||||
)
|
||||
|
||||
if envs.VLLM_USE_BYTECODE_HOOK and mode != CompilationMode.STOCK_TORCH_COMPILE:
|
||||
torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)
|
||||
self._compiled_bytecode: CodeType | None = None
|
||||
|
||||
def aot_compile(self, *args: Any, **kwargs: Any) -> Any:
|
||||
if not hasattr(self._compiled_callable, "aot_compile"):
|
||||
raise RuntimeError(
|
||||
"aot_compile is not supported by the current configuration. "
|
||||
"Please make sure torch.compile is enabled with the latest "
|
||||
f"version of PyTorch (current using torch: {torch.__version__})"
|
||||
)
|
||||
return self._compiled_callable.aot_compile((args, kwargs))
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
if envs.VLLM_USE_BYTECODE_HOOK:
|
||||
if (
|
||||
self.vllm_config.compilation_config.mode
|
||||
== CompilationMode.STOCK_TORCH_COMPILE
|
||||
):
|
||||
return self._compiled_callable(*args, **kwargs)
|
||||
|
||||
if not self._compiled_bytecode:
|
||||
# Make sure a compilation is triggered by clearing the dynamo cache.
|
||||
torch._dynamo.eval_frame.remove_from_cache(self.original_code_object())
|
||||
return self._call_with_optional_nvtx_range(
|
||||
self._compiled_callable, *args, **kwargs
|
||||
)
|
||||
else:
|
||||
with self._dispatch_to_compiled_code():
|
||||
return self._call_with_optional_nvtx_range(
|
||||
self.forward, *args, **kwargs
|
||||
)
|
||||
else:
|
||||
ctx = (
|
||||
nullcontext()
|
||||
if self.first_compile or not self.evaluate_guards
|
||||
else torch.compiler.set_stance("fail_on_recompile")
|
||||
)
|
||||
self.first_compile = False
|
||||
with _compilation_context(), ctx:
|
||||
return self._call_with_optional_nvtx_range(
|
||||
self._compiled_callable, *args, **kwargs
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, *args: Any, **kwargs: Any) -> Any: ...
|
||||
|
||||
def original_code_object(self) -> CodeType:
|
||||
"""Return the original code object of the forward method."""
|
||||
return self.__class__.forward.__code__
|
||||
|
||||
def bytecode_hook(self, old_code: CodeType, new_code: CodeType) -> None:
|
||||
"""Hook to save the compiled bytecode for direct execution."""
|
||||
if old_code is not self.original_code_object():
|
||||
return
|
||||
# code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25
|
||||
frame = sys._getframe()
|
||||
while frame and frame.f_back:
|
||||
frame = frame.f_back
|
||||
code_name = frame.f_code.co_name
|
||||
file_name = frame.f_code.co_filename.split(os.path.sep)[-1]
|
||||
if code_name == "_compile" and file_name == "convert_frame.py":
|
||||
break
|
||||
frame = frame.f_locals["frame"]
|
||||
assert frame.f_code == old_code
|
||||
|
||||
if frame.f_locals["self"] is not self:
|
||||
return
|
||||
|
||||
self._compiled_bytecode = new_code
|
||||
|
||||
path = self.vllm_config.compile_debug_dump_path()
|
||||
if path:
|
||||
decompiled_file = path / "transformed_code.py"
|
||||
if not decompiled_file.exists():
|
||||
try:
|
||||
# usually the decompilation will succeed for most models,
|
||||
# as we guarantee a full-graph compilation in Dynamo.
|
||||
# but there's no 100% guarantee, since decompilation is
|
||||
# not a reversible process.
|
||||
import depyf
|
||||
|
||||
src = depyf.decompile(new_code)
|
||||
|
||||
with open(decompiled_file, "w") as f:
|
||||
f.write(src)
|
||||
|
||||
logger.debug("Dynamo transformed code saved to %s", decompiled_file)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if (
|
||||
self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
|
||||
and "update" in new_code.co_names
|
||||
):
|
||||
import depyf
|
||||
|
||||
src = depyf.decompile(new_code)
|
||||
msg = (
|
||||
"Assigning / modifying buffers of nn.Module during forward pass is not "
|
||||
"allowed when using cudagraph inside the compiler because it will "
|
||||
"cause silent errors. Please use eager mode or fix the code. The "
|
||||
"following code contains clues about which buffer is being modified "
|
||||
f"(please search for the usage of the function `update`):\n{src}"
|
||||
)
|
||||
raise RuntimeError(msg)
|
||||
|
||||
@contextmanager
|
||||
def _dispatch_to_compiled_code(self) -> Generator[None, None, None]:
|
||||
# noqa: E501
|
||||
"""
|
||||
Context manager to dispatch to internally compiled code for torch<2.8.
|
||||
Why does this work? Because Dynamo guarantees that the compiled
|
||||
bytecode has exactly the same arguments, cell variables, and free
|
||||
variables as the original code. Therefore we can directly switch
|
||||
the code object in the function and call it.
|
||||
|
||||
See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details.
|
||||
""" # noqa: E501 line too long
|
||||
original = self.original_code_object()
|
||||
assert self._compiled_bytecode is not None
|
||||
self.__class__.forward.__code__ = self._compiled_bytecode
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.__class__.forward.__code__ = original
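# The same code-object swap in isolation (standalone illustrative sketch,
# unrelated to vLLM internals):
#
#   def f(x):
#       return x + 1
#
#   def g(x):
#       return x + 2
#
#   f.__code__, saved = g.__code__, f.__code__
#   assert f(1) == 3  # f now runs g's bytecode
#   f.__code__ = saved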
|
||||
130
vllm/config/__init__.py
Normal file
@@ -0,0 +1,130 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.config.attention import AttentionConfig
|
||||
from vllm.config.cache import CacheConfig
|
||||
from vllm.config.compilation import (
|
||||
CompilationConfig,
|
||||
CompilationMode,
|
||||
CUDAGraphMode,
|
||||
PassConfig,
|
||||
)
|
||||
from vllm.config.device import DeviceConfig
|
||||
from vllm.config.ec_transfer import ECTransferConfig
|
||||
from vllm.config.kernel import KernelConfig
|
||||
from vllm.config.kv_events import KVEventsConfig
|
||||
from vllm.config.kv_transfer import KVTransferConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.config.lora import LoRAConfig
|
||||
from vllm.config.model import (
|
||||
ModelConfig,
|
||||
iter_architecture_defaults,
|
||||
str_dtype_to_torch_dtype,
|
||||
try_match_architecture_defaults,
|
||||
)
|
||||
from vllm.config.multimodal import MultiModalConfig
|
||||
from vllm.config.observability import ObservabilityConfig
|
||||
from vllm.config.offload import (
|
||||
OffloadBackend,
|
||||
OffloadConfig,
|
||||
PrefetchOffloadConfig,
|
||||
UVAOffloadConfig,
|
||||
)
|
||||
from vllm.config.parallel import EPLBConfig, ParallelConfig
|
||||
from vllm.config.pooler import PoolerConfig
|
||||
from vllm.config.profiler import ProfilerConfig
|
||||
from vllm.config.scheduler import SchedulerConfig
|
||||
from vllm.config.speculative import SpeculativeConfig
|
||||
from vllm.config.speech_to_text import SpeechToTextConfig
|
||||
from vllm.config.structured_outputs import StructuredOutputsConfig
|
||||
from vllm.config.utils import (
|
||||
ConfigType,
|
||||
SupportsMetricsInfo,
|
||||
config,
|
||||
get_attr_docs,
|
||||
is_init_field,
|
||||
replace,
|
||||
update_config,
|
||||
)
|
||||
from vllm.config.vllm import (
|
||||
VllmConfig,
|
||||
get_cached_compilation_config,
|
||||
get_current_vllm_config,
|
||||
get_current_vllm_config_or_none,
|
||||
get_layers_from_vllm_config,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.config.weight_transfer import WeightTransferConfig
|
||||
|
||||
# __all__ should only contain classes and functions.
|
||||
# Types and globals should be imported from their respective modules.
|
||||
__all__ = [
|
||||
# From vllm.config.attention
|
||||
"AttentionConfig",
|
||||
# From vllm.config.cache
|
||||
"CacheConfig",
|
||||
# From vllm.config.compilation
|
||||
"CompilationConfig",
|
||||
"CompilationMode",
|
||||
"CUDAGraphMode",
|
||||
"PassConfig",
|
||||
# From vllm.config.device
|
||||
"DeviceConfig",
|
||||
# From vllm.config.ec_transfer
|
||||
"ECTransferConfig",
|
||||
# From vllm.config.kernel
|
||||
"KernelConfig",
|
||||
# From vllm.config.kv_events
|
||||
"KVEventsConfig",
|
||||
# From vllm.config.kv_transfer
|
||||
"KVTransferConfig",
|
||||
# From vllm.config.load
|
||||
"LoadConfig",
|
||||
# From vllm.config.lora
|
||||
"LoRAConfig",
|
||||
# From vllm.config.model
|
||||
"ModelConfig",
|
||||
"iter_architecture_defaults",
|
||||
"str_dtype_to_torch_dtype",
|
||||
"try_match_architecture_defaults",
|
||||
# From vllm.config.multimodal
|
||||
"MultiModalConfig",
|
||||
# From vllm.config.observability
|
||||
"ObservabilityConfig",
|
||||
# From vllm.config.offload
|
||||
"OffloadBackend",
|
||||
"OffloadConfig",
|
||||
"PrefetchOffloadConfig",
|
||||
"UVAOffloadConfig",
|
||||
# From vllm.config.parallel
|
||||
"EPLBConfig",
|
||||
"ParallelConfig",
|
||||
# From vllm.config.pooler
|
||||
"PoolerConfig",
|
||||
# From vllm.config.scheduler
|
||||
"SchedulerConfig",
|
||||
# From vllm.config.speculative
|
||||
"SpeculativeConfig",
|
||||
# From vllm.config.speech_to_text
|
||||
"SpeechToTextConfig",
|
||||
# From vllm.config.structured_outputs
|
||||
"StructuredOutputsConfig",
|
||||
# From vllm.config.profiler
|
||||
"ProfilerConfig",
|
||||
# From vllm.config.utils
|
||||
"ConfigType",
|
||||
"SupportsMetricsInfo",
|
||||
"config",
|
||||
"get_attr_docs",
|
||||
"is_init_field",
|
||||
"replace",
|
||||
"update_config",
|
||||
# From vllm.config.vllm
|
||||
"VllmConfig",
|
||||
"get_cached_compilation_config",
|
||||
"get_current_vllm_config",
|
||||
"get_current_vllm_config_or_none",
|
||||
"set_current_vllm_config",
|
||||
"get_layers_from_vllm_config",
|
||||
"WeightTransferConfig",
|
||||
]
|
||||
69
vllm/config/attention.py
Normal file
@@ -0,0 +1,69 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import field_validator
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
|
||||
@config
|
||||
class AttentionConfig:
|
||||
"""Configuration for attention mechanisms in vLLM."""
|
||||
|
||||
backend: AttentionBackendEnum | None = None
|
||||
"""Attention backend to use. If None, will be selected automatically."""
|
||||
|
||||
flash_attn_version: Literal[2, 3] | None = None
|
||||
"""Force vllm to use a specific flash-attention version (2 or 3).
|
||||
Only valid when using the flash-attention backend."""
|
||||
|
||||
use_prefill_decode_attention: bool = False
|
||||
"""Use separate prefill and decode kernels for attention instead of
|
||||
the unified triton kernel."""
|
||||
|
||||
flash_attn_max_num_splits_for_cuda_graph: int = 32
|
||||
"""Flash Attention max number splits for cuda graph decode."""
|
||||
|
||||
use_cudnn_prefill: bool = False
|
||||
"""Whether to use cudnn prefill."""
|
||||
|
||||
use_trtllm_ragged_deepseek_prefill: bool = True
|
||||
"""Whether to use TRTLLM ragged deepseek prefill."""
|
||||
|
||||
use_trtllm_attention: bool | None = None
|
||||
"""If set to True/False, use or don't use the TRTLLM attention backend
|
||||
in flashinfer. If None, auto-detect the attention backend in flashinfer."""
|
||||
|
||||
disable_flashinfer_prefill: bool = False
|
||||
"""Whether to disable flashinfer prefill."""
|
||||
|
||||
disable_flashinfer_q_quantization: bool = False
|
||||
"""If set, when using fp8 kv, do not quantize Q to fp8."""
|
||||
|
||||
use_prefill_query_quantization: bool = False
|
||||
"""If set, quantize query for attention in prefill."""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
Provide a hash that uniquely identifies all the configs
|
||||
that affect the structure of the computation
|
||||
graph from input ids/embeddings to the final hidden states,
|
||||
excluding anything before input ids/embeddings and after
|
||||
the final hidden states.
|
||||
"""
|
||||
from vllm.config.utils import get_hash_factors, hash_factors
|
||||
|
||||
ignored_factors: list[str] = []
|
||||
factors = get_hash_factors(self, ignored_factors)
|
||||
return hash_factors(factors)
|
||||
|
||||
@field_validator("backend", mode="before")
|
||||
@classmethod
|
||||
def validate_backend_before(cls, value: Any) -> Any:
|
||||
"""Enable parsing of the `backend` enum type from string."""
|
||||
if isinstance(value, str):
|
||||
return AttentionBackendEnum[value.upper()]
|
||||
return value
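# Illustrative usage (the enum member name is an assumption): the validator
# above allows AttentionConfig(backend="flash_attn"), which is parsed via
# AttentionBackendEnum["FLASH_ATTN"].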
|
||||
250
vllm/config/cache.py
Normal file
@@ -0,0 +1,250 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import math
|
||||
from dataclasses import field
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
from pydantic import Field, SkipValidation, field_validator
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.mem_constants import GiB_bytes
|
||||
from vllm.utils.mem_utils import format_gib, get_cpu_memory
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config.parallel import ParallelConfig
|
||||
else:
|
||||
ParallelConfig = Any
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
BlockSize = Literal[1, 8, 16, 32, 64, 128, 256]
|
||||
CacheDType = Literal[
|
||||
"auto",
|
||||
"bfloat16",
|
||||
"fp8",
|
||||
"fp8_e4m3",
|
||||
"fp8_e5m2",
|
||||
"fp8_inc",
|
||||
"fp8_ds_mla",
|
||||
]
|
||||
MambaDType = Literal["auto", "float32", "float16"]
|
||||
MambaCacheMode = Literal["all", "align", "none"]
|
||||
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor", "xxhash", "xxhash_cbor"]
|
||||
KVOffloadingBackend = Literal["native", "lmcache"]
|
||||
|
||||
|
||||
@config
|
||||
class CacheConfig:
|
||||
"""Configuration for the KV cache."""
|
||||
|
||||
block_size: SkipValidation[BlockSize] = None # type: ignore[assignment]
|
||||
"""Size of a contiguous cache block in number of tokens. On CUDA devices,
|
||||
only block sizes up to 32 are supported.
|
||||
|
||||
This config has no static default. If left unspecified by the user, it will
|
||||
be set in `Platform.check_and_update_config()` based on the current
|
||||
platform."""
|
||||
gpu_memory_utilization: float = Field(default=0.9, gt=0, le=1)
|
||||
"""The fraction of GPU memory to be used for the model executor, which can
|
||||
range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory
|
||||
utilization. If unspecified, will use the default value of 0.9. This is a
|
||||
per-instance limit, and only applies to the current vLLM instance. It does
|
||||
not matter if you have another vLLM instance running on the same GPU. For
|
||||
example, if you have two vLLM instances running on the same GPU, you can
|
||||
set the GPU memory utilization to 0.5 for each instance."""
|
||||
swap_space: float = Field(default=4, ge=0)
|
||||
"""Size of the CPU swap space per GPU (in GiB)."""
|
||||
cache_dtype: CacheDType = "auto"
|
||||
"""Data type for kv cache storage. If "auto", will use model data type.
|
||||
CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports
|
||||
fp8 (=fp8_e4m3). Intel Gaudi (HPU) supports fp8 (using fp8_inc).
|
||||
Some models (namely DeepSeekV3.2) default to fp8; set this to bfloat16 to use
bfloat16 instead. bfloat16 is an invalid option for models that do not default
to fp8.
|
||||
"""
|
||||
is_attention_free: bool = False
|
||||
"""Whether the model is attention-free. This is primarily set in
|
||||
`ModelConfig` and that value should be manually duplicated here."""
|
||||
num_gpu_blocks_override: int | None = None
|
||||
"""Number of GPU blocks to use. This overrides the profiled `num_gpu_blocks`
|
||||
if specified. Does nothing if `None`. Used for testing preemption."""
|
||||
sliding_window: int | None = None
|
||||
"""Sliding window size for the KV cache. This is primarily set in
|
||||
`ModelConfig` and that value should be manually duplicated here."""
|
||||
enable_prefix_caching: bool = True
|
||||
"""Whether to enable prefix caching."""
|
||||
prefix_caching_hash_algo: PrefixCachingHashAlgo = "sha256"
|
||||
"""Set the hash algorithm for prefix caching:\n
|
||||
- "sha256" uses Pickle for object serialization before hashing. This is the
|
||||
current default, as SHA256 is the most secure choice to avoid potential
|
||||
hash collisions.\n
|
||||
- "sha256_cbor" provides a reproducible, cross-language compatible hash. It
|
||||
serializes objects using canonical CBOR and hashes them with SHA-256.\n
|
||||
- "xxhash" uses Pickle serialization with xxHash (128-bit) for faster,
|
||||
non-cryptographic hashing. Requires the optional ``xxhash`` package.
|
||||
IMPORTANT: Use of a hashing algorithm that is not considered
|
||||
cryptographically secure theoretically increases the risk of hash collisions,
|
||||
which can cause undefined behavior or even leak private information in
|
||||
multi-tenant environments. Even if collisions are still very unlikely, it is
|
||||
important to consider your security risk tolerance against the performance
|
||||
benefits before turning this on.\n
|
||||
- "xxhash_cbor" combines canonical CBOR serialization with xxHash for
|
||||
reproducible hashing. Requires the optional ``xxhash`` package."""
|
||||
cpu_offload_gb: float = Field(default=0, ge=0)
|
||||
"""The space in GiB to offload to CPU, per GPU. Default is 0, which means
|
||||
no offloading. Intuitively, this argument can be seen as a virtual way to
|
||||
increase the GPU memory size. For example, if you have one 24 GB GPU and
|
||||
set this to 10, virtually you can think of it as a 34 GB GPU. Then you can
|
||||
load a 13B model with BF16 weights, which requires at least 26 GB of GPU memory.
|
||||
Note that this requires fast CPU-GPU interconnect, as part of the model is
|
||||
loaded from CPU memory to GPU memory on the fly in each model forward pass.
|
||||
|
||||
DEPRECATED: This field is deprecated and will be removed in v0.16.
|
||||
Please use OffloadConfig.uva.cpu_offload_gb instead.
|
||||
"""
|
||||
cpu_offload_params: set[str] = Field(default_factory=set)
|
||||
"""The set of parameter name segments to target for CPU offloading.
|
||||
|
||||
DEPRECATED: This field is deprecated and will be removed in v0.16.
|
||||
Please use OffloadConfig.uva.cpu_offload_params instead.
|
||||
"""
|
||||
calculate_kv_scales: bool = False
|
||||
"""This enables dynamic calculation of `k_scale` and `v_scale` when
|
||||
kv_cache_dtype is fp8. If `False`, the scales will be loaded from the model
|
||||
checkpoint if available. Otherwise, the scales will default to 1.0."""
|
||||
cpu_kvcache_space_bytes: int | None = None
|
||||
"""(CPU backend only) CPU key-value cache space."""
|
||||
mamba_page_size_padded: int | None = None
|
||||
""" Optional override for mamba page size; used by hybrid mamba/attention
|
||||
models to ensure exact alignment with attention page size."""
|
||||
mamba_block_size: int | None = Field(default=None, gt=0)
|
||||
"""Size of a contiguous cache block in number of tokens for mamba cache.
|
||||
Can be set only when prefix caching is enabled.
|
||||
Value must be a multiple of 8 to align with causal_conv1d kernel."""
|
||||
mamba_cache_dtype: MambaDType = "auto"
|
||||
"""The data type to use for the Mamba cache (both the conv as well as the
|
||||
ssm state). If set to 'auto', the data type will be inferred from the model
|
||||
config."""
|
||||
mamba_ssm_cache_dtype: MambaDType = "auto"
|
||||
"""The data type to use for the Mamba cache (ssm state only, conv state will
|
||||
still be controlled by mamba_cache_dtype). If set to 'auto', the data type
|
||||
for the ssm state will be determined by mamba_cache_dtype."""
|
||||
mamba_cache_mode: MambaCacheMode = "none"
|
||||
"""The cache strategy for Mamba layers.
|
||||
- "none": set when prefix caching is disabled.
|
||||
- "all": cache the mamba state of all tokens at position i * block_size. This is
|
||||
the default behavior (for models that support it) when prefix caching is
|
||||
enabled.
|
||||
- "align": only cache the mamba state of the last token of each scheduler step and
|
||||
when the token is at position i * block_size.
|
||||
"""
|
||||
|
||||
# Will be set after profiling.
|
||||
num_gpu_blocks: int | None = field(default=None, init=False)
|
||||
"""The number of blocks to allocate for GPU memory."""
|
||||
num_cpu_blocks: int | None = field(default=None, init=False)
|
||||
"""The number of blocks to allocate for CPU memory."""
|
||||
|
||||
kv_sharing_fast_prefill: bool = False
|
||||
"""This feature is work in progress and no prefill optimization takes place
|
||||
with this flag enabled currently.
|
||||
|
||||
In some KV sharing setups, e.g. YOCO (https://arxiv.org/abs/2405.05254),
|
||||
some layers can skip tokens corresponding to prefill. This flag enables
|
||||
attention metadata for eligible layers to be overridden with metadata
|
||||
necessary for implementing this optimization in some models (e.g. Gemma3n)
|
||||
"""
|
||||
|
||||
kv_cache_memory_bytes: int | None = None
|
||||
"""Size of KV Cache per GPU in bytes. By default, this is set to None
|
||||
and vllm can automatically infer the kv cache size based on
|
||||
gpu_memory_utilization. However, users may want to manually specify
|
||||
the kv cache memory size. kv_cache_memory_bytes allows more fine-grain
|
||||
control of how much memory gets used when compared with using
|
||||
gpu_memory_utilization. Note that kv_cache_memory_bytes
|
||||
(when not-None) ignores gpu_memory_utilization"""
|
||||
|
||||
kv_offloading_size: float | None = None
|
||||
"""Size of the KV cache offloading buffer in GiB. When TP > 1, this is
|
||||
the total buffer size summed across all TP ranks. By default, this is set
|
||||
to None, which means no KV offloading is enabled. When set, vLLM will
|
||||
enable KV cache offloading to CPU using the kv_offloading_backend."""
|
||||
|
||||
kv_offloading_backend: KVOffloadingBackend = "native"
|
||||
"""The backend to use for KV cache offloading. Supported backends include
|
||||
'native' (vLLM native CPU offloading), 'lmcache'.
|
||||
KV offloading is only activated when kv_offloading_size is set."""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
ensure that it is included in the factors list if
|
||||
it affects the computation graph.
|
||||
|
||||
Provide a hash that uniquely identifies all the configs
|
||||
that affect the structure of the computation
|
||||
graph from input ids/embeddings to the final hidden states,
|
||||
excluding anything before input ids/embeddings and after
|
||||
the final hidden states.
|
||||
"""
|
||||
ignored_factors = {
|
||||
# Runtime/derived knobs that don't affect compiled graph shape
|
||||
"gpu_memory_utilization",
|
||||
"swap_space",
|
||||
"is_attention_free",
|
||||
"num_gpu_blocks_override",
|
||||
"enable_prefix_caching",
|
||||
"prefix_caching_hash_algo",
|
||||
"cpu_kvcache_space_bytes",
|
||||
"mamba_page_size_padded",
|
||||
# Post-init/derived counters
|
||||
"num_gpu_blocks",
|
||||
"num_cpu_blocks",
|
||||
# WIP feature toggle not impacting compiled graph shape
|
||||
"kv_sharing_fast_prefill",
|
||||
}
|
||||
|
||||
from vllm.config.utils import get_hash_factors, hash_factors
|
||||
|
||||
factors = get_hash_factors(self, ignored_factors)
|
||||
return hash_factors(factors)
|
||||
|
||||
def metrics_info(self):
|
||||
# convert cache_config to dict(key: str, value: str) for prometheus
|
||||
# metrics info
|
||||
return {key: str(value) for key, value in self.__dict__.items()}
|
||||
|
||||
@field_validator("cache_dtype", mode="after")
|
||||
@classmethod
|
||||
def _validate_cache_dtype(cls, cache_dtype: CacheDType) -> CacheDType:
|
||||
if cache_dtype.startswith("fp8"):
|
||||
logger.info(
|
||||
"Using fp8 data type to store kv cache. It reduces the GPU "
|
||||
"memory footprint and boosts the performance. "
|
||||
"Meanwhile, it may cause accuracy drop without a proper "
|
||||
"scaling factor."
|
||||
)
|
||||
return cache_dtype
|
||||
|
||||
def verify_with_parallel_config(
|
||||
self,
|
||||
parallel_config: ParallelConfig,
|
||||
) -> None:
|
||||
swap_space_bytes = math.ceil(self.swap_space * GiB_bytes)
|
||||
total_cpu_memory = get_cpu_memory()
|
||||
# FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel
|
||||
# group are in the same node. However, the GPUs may span multiple nodes.
|
||||
num_gpus_per_node = parallel_config.tensor_parallel_size
|
||||
cpu_memory_usage = swap_space_bytes * num_gpus_per_node
|
||||
|
||||
msg = (
|
||||
f"{format_gib(cpu_memory_usage)} GiB out of the "
|
||||
f"{format_gib(total_cpu_memory)} GiB total CPU memory "
|
||||
"is allocated for the swap space."
|
||||
)
|
||||
if cpu_memory_usage > 0.7 * total_cpu_memory:
|
||||
raise ValueError("Too large swap space. " + msg)
|
||||
elif cpu_memory_usage > 0.4 * total_cpu_memory:
|
||||
logger.warning("Possibly too large swap space. %s", msg)
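# Worked example (illustrative numbers): with swap_space=4 GiB and
# tensor_parallel_size=8, cpu_memory_usage is 32 GiB; on a host with 64 GiB of
# CPU memory that is 50% of the total, which exceeds the 40% threshold and
# triggers the warning above (the hard error only fires past 70%).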
|
||||
1196
vllm/config/compilation.py
Normal file
File diff suppressed because it is too large
73
vllm/config/device.py
Normal file
@@ -0,0 +1,73 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import field
|
||||
from typing import Any, Literal
|
||||
|
||||
import torch
|
||||
from pydantic import ConfigDict, SkipValidation
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"]
|
||||
|
||||
|
||||
@config(config=ConfigDict(arbitrary_types_allowed=True))
|
||||
class DeviceConfig:
|
||||
"""Configuration for the device to use for vLLM execution."""
|
||||
|
||||
device: SkipValidation[Device | torch.device | None] = "auto"
|
||||
"""Device type for vLLM execution.
|
||||
This parameter is deprecated and will be
|
||||
removed in a future release.
|
||||
It will now be set automatically based
|
||||
on the current platform."""
|
||||
device_type: str = field(init=False)
|
||||
"""Device type from the current platform. This is set in
|
||||
`__post_init__`."""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
ensure that it is included in the factors list if
|
||||
it affects the computation graph.
|
||||
|
||||
Provide a hash that uniquely identifies all the configs
|
||||
that affect the structure of the computation
|
||||
graph from input ids/embeddings to the final hidden states,
|
||||
excluding anything before input ids/embeddings and after
|
||||
the final hidden states.
|
||||
"""
|
||||
# no factors to consider.
|
||||
# the device/platform information will be summarized
|
||||
# by torch/vllm automatically.
|
||||
factors: list[Any] = []
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
def __post_init__(self):
|
||||
if self.device == "auto":
|
||||
# Automated device type detection
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
self.device_type = current_platform.device_type
|
||||
if not self.device_type:
|
||||
raise RuntimeError(
|
||||
"Failed to infer device type, please set "
|
||||
"the environment variable `VLLM_LOGGING_LEVEL=DEBUG` "
|
||||
"to turn on verbose logging to help debug the issue."
|
||||
)
|
||||
else:
|
||||
# Device type is assigned explicitly
|
||||
if isinstance(self.device, str):
|
||||
self.device_type = self.device
|
||||
elif isinstance(self.device, torch.device):
|
||||
self.device_type = self.device.type
|
||||
|
||||
# Some device types require processing inputs on CPU
|
||||
if self.device_type in ["tpu"]:
|
||||
self.device = None
|
||||
else:
|
||||
# Set device with device type
|
||||
self.device = torch.device(self.device_type)
|
||||
107
vllm/config/ec_transfer.py
Normal file
@@ -0,0 +1,107 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import hashlib
|
||||
import uuid
|
||||
from dataclasses import field
|
||||
from typing import Any, Literal, get_args
|
||||
|
||||
from vllm.config.utils import config
|
||||
|
||||
ECProducer = Literal["ec_producer", "ec_both"]
|
||||
ECConsumer = Literal["ec_consumer", "ec_both"]
|
||||
ECRole = Literal[ECProducer, ECConsumer]
|
||||
|
||||
|
||||
@config
|
||||
class ECTransferConfig:
|
||||
"""Configuration for distributed EC cache transfer."""
|
||||
|
||||
ec_connector: str | None = None
|
||||
"""The EC connector for vLLM to transmit EC caches between vLLM instances.
|
||||
"""
|
||||
|
||||
engine_id: str | None = None
|
||||
"""The engine id for EC transfers."""
|
||||
|
||||
ec_buffer_device: str | None = "cuda"
|
||||
"""The device used by ec connector to buffer the EC cache.
|
||||
Currently only support 'cuda'."""
|
||||
|
||||
ec_buffer_size: float = 1e9
|
||||
"""The buffer size for TorchDistributedConnector. Measured in number of
|
||||
bytes. Recommended value: 1e9 (about 1GB)."""
|
||||
|
||||
ec_role: ECRole | None = None
|
||||
"""Whether this vLLM instance produces, consumes EC cache, or both. Choices
|
||||
are 'ec_producer', 'ec_consumer', 'ec_both'."""
|
||||
|
||||
ec_rank: int | None = None
|
||||
"""The rank of this vLLM instance in the EC cache transfer. Typical value:
|
||||
0 for encoder, 1 for pd instance.
|
||||
Currently only 1P1D is supported."""
|
||||
|
||||
ec_parallel_size: int = 1
|
||||
"""The number of parallel instances for EC cache transfer. For
|
||||
PyNcclConnector, this should be 2."""
|
||||
|
||||
ec_ip: str = "127.0.0.1"
|
||||
"""The EC connector ip, used to build distributed connection."""
|
||||
|
||||
ec_port: int = 14579
|
||||
"""The EC connector port, used to build distributed connection."""
|
||||
|
||||
ec_connector_extra_config: dict[str, Any] = field(default_factory=dict)
|
||||
"""any extra config that the connector may need."""
|
||||
|
||||
ec_connector_module_path: str | None = None
|
||||
"""The Python module path to dynamically load the EC connector from.
|
||||
Only supported in V1."""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
ensure that it is included in the factors list if
|
||||
it affects the computation graph.
|
||||
|
||||
Provide a hash that uniquely identifies all the configs
|
||||
that affect the structure of the computation
|
||||
graph from input ids/embeddings to the final hidden states,
|
||||
excluding anything before input ids/embeddings and after
|
||||
the final hidden states.
|
||||
"""
|
||||
# no factors to consider.
|
||||
# this config will not affect the computation graph.
|
||||
factors: list[Any] = []
|
||||
hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.engine_id is None:
|
||||
self.engine_id = str(uuid.uuid4())
|
||||
|
||||
if self.ec_role is not None and self.ec_role not in get_args(ECRole):
|
||||
raise ValueError(
|
||||
f"Unsupported ec_role: {self.ec_role}. "
|
||||
f"Supported roles are {get_args(ECRole)}"
|
||||
)
|
||||
|
||||
if self.ec_connector is not None and self.ec_role is None:
|
||||
raise ValueError(
|
||||
"Please specify ec_role when ec_connector "
|
||||
f"is set, supported roles are {get_args(ECRole)}"
|
||||
)
|
||||
|
||||
@property
|
||||
def is_ec_transfer_instance(self) -> bool:
|
||||
return self.ec_connector is not None and self.ec_role in get_args(ECRole)
|
||||
|
||||
@property
|
||||
def is_ec_producer(self) -> bool:
|
||||
return self.ec_connector is not None and self.ec_role in get_args(ECProducer)
|
||||
|
||||
@property
|
||||
def is_ec_consumer(self) -> bool:
|
||||
return self.ec_connector is not None and self.ec_role in get_args(ECConsumer)
|
||||
|
||||
def get_from_extra_config(self, key, default) -> Any:
|
||||
return self.ec_connector_extra_config.get(key, default)
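# For example, with ec_connector set and ec_role="ec_both", both is_ec_producer
# and is_ec_consumer return True, since "ec_both" is a member of both the
# ECProducer and ECConsumer literals defined above.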
|
||||
76
vllm/config/kernel.py
Normal file
@@ -0,0 +1,76 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
MoEBackend = Literal[
|
||||
"auto",
|
||||
"triton",
|
||||
"deep_gemm",
|
||||
"cutlass",
|
||||
"flashinfer_trtllm",
|
||||
"flashinfer_cutlass",
|
||||
"flashinfer_cutedsl",
|
||||
"marlin",
|
||||
"aiter",
|
||||
]
|
||||
|
||||
|
||||
@config
|
||||
class KernelConfig:
|
||||
"""Configuration for kernel selection and warmup behavior."""
|
||||
|
||||
enable_flashinfer_autotune: bool = Field(default=None)
|
||||
"""If True, run FlashInfer autotuning during kernel warmup."""
|
||||
|
||||
moe_backend: MoEBackend = "auto"
|
||||
"""Backend for MoE expert computation kernels. Available options:
|
||||
|
||||
- "auto": Automatically select the best backend based on model and hardware\n
|
||||
- "triton": Use Triton-based fused MoE kernels\n
|
||||
- "deep_gemm": Use DeepGEMM kernels (FP8 block-quantized only)\n
|
||||
- "cutlass": Use vLLM CUTLASS kernels\n
|
||||
- "flashinfer_trtllm": Use FlashInfer with TRTLLM-GEN kernels\n
|
||||
- "flashinfer_cutlass": Use FlashInfer with CUTLASS kernels\n
|
||||
- "flashinfer_cutedsl": Use FlashInfer with CuteDSL kernels (FP4 only)\n
|
||||
- "marlin": Use Marlin kernels (weight-only quantization)\n
|
||||
- "aiter": Use AMD AITer kernels (ROCm only)"""
|
||||
|
||||
@field_validator("moe_backend", mode="before")
|
||||
@classmethod
|
||||
def _normalize_moe_backend(cls, value: Any) -> Any:
|
||||
if isinstance(value, str):
|
||||
return value.lower().replace("-", "_")
|
||||
return value
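# For example, the normalizer above maps "FlashInfer-TRTLLM" to
# "flashinfer_trtllm" before the MoEBackend Literal validation runs.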
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
ensure that it is included in the factors list if
|
||||
it affects the computation graph.
|
||||
|
||||
Provide a hash that uniquely identifies all the configs
|
||||
that affect the structure of the computation
|
||||
graph from input ids/embeddings to the final hidden states,
|
||||
excluding anything before input ids/embeddings and after
|
||||
the final hidden states.
|
||||
"""
|
||||
# no factors to consider.
|
||||
# this config will not affect the computation graph.
|
||||
factors: list[Any] = []
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
@field_validator("enable_flashinfer_autotune", mode="wrap")
|
||||
@classmethod
|
||||
def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
|
||||
"""Skip validation if the value is `None` when initialization is delayed."""
|
||||
if value is None:
|
||||
return value
|
||||
return handler(value)
|
||||
54
vllm/config/kv_events.py
Normal file
54
vllm/config/kv_events.py
Normal file
@@ -0,0 +1,54 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from vllm.config.utils import config
|
||||
|
||||
|
||||
@config
|
||||
class KVEventsConfig:
|
||||
"""Configuration for KV event publishing."""
|
||||
|
||||
enable_kv_cache_events: bool = False
|
||||
"""If True, enable KV cache events for tracking block storage and removal.
|
||||
Events can be published externally by zmq using the event publisher config.
|
||||
"""
|
||||
|
||||
publisher: Literal["null", "zmq"] = Field(default=None)
|
||||
"""The publisher to use for publishing kv events. Can be "null", "zmq".
|
||||
"""
|
||||
|
||||
endpoint: str = "tcp://*:5557"
|
||||
"""The zmq endpoint to use for publishing kv events.
|
||||
"""
|
||||
|
||||
replay_endpoint: str | None = None
|
||||
"""The zmq endpoint to use for replaying kv events.
|
||||
"""
|
||||
|
||||
buffer_steps: int = 10_000
|
||||
"""The number of steps to cache for replay endpoint. Will only save
|
||||
events from the last N steps for the replay endpoint.
|
||||
"""
|
||||
|
||||
hwm: int = 100_000
|
||||
"""The zmq high water mark for the event publisher. After queueing N events,
|
||||
events will start dropping if the consumer is not keeping up.
|
||||
"""
|
||||
|
||||
max_queue_size: int = 100_000
|
||||
"""The maximum number of events to queue while waiting for publishing.
|
||||
"""
|
||||
|
||||
topic: str = ""
|
||||
"""The topic to use for the event publisher. Consumers can subscribe to
|
||||
this topic to receive events.
|
||||
"""
|
||||
|
||||
def __post_init__(self):
|
||||
if self.publisher is None:
|
||||
self.publisher = "zmq" if self.enable_kv_cache_events else "null"
|
||||
116
vllm/config/kv_transfer.py
Normal file
116
vllm/config/kv_transfer.py
Normal file
@@ -0,0 +1,116 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import uuid
|
||||
from dataclasses import field
|
||||
from typing import Any, Literal, get_args
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
KVProducer = Literal["kv_producer", "kv_both"]
|
||||
KVConsumer = Literal["kv_consumer", "kv_both"]
|
||||
KVRole = Literal[KVProducer, KVConsumer]
|
||||
|
||||
|
||||
@config
|
||||
class KVTransferConfig:
|
||||
"""Configuration for distributed KV cache transfer."""
|
||||
|
||||
kv_connector: str | None = None
|
||||
"""The KV connector for vLLM to transmit KV caches between vLLM instances.
|
||||
"""
|
||||
|
||||
engine_id: str | None = None
|
||||
"""The engine id for KV transfers."""
|
||||
|
||||
kv_buffer_device: str = "cuda"
|
||||
"""The device used by kv connector to buffer the KV cache. Choices are
|
||||
'cuda' and 'cpu'."""
|
||||
|
||||
kv_buffer_size: float = 1e9
|
||||
"""The buffer size for TorchDistributedConnector. Measured in number of
|
||||
bytes. Recommended value: 1e9 (about 1GB)."""
|
||||
|
||||
kv_role: KVRole | None = None
|
||||
"""Whether this vLLM instance produces, consumes KV cache, or both. Choices
|
||||
are 'kv_producer', 'kv_consumer', and 'kv_both'."""
|
||||
|
||||
kv_rank: int | None = None
|
||||
"""The rank of this vLLM instance in the KV cache transfer. Typical value:
|
||||
0 for prefill instance, 1 for decode instance.
|
||||
Currently only 1P1D is supported."""
|
||||
|
||||
kv_parallel_size: int = 1
|
||||
"""The number of parallel instances for KV cache transfer. For
|
||||
P2pNcclConnector, this should be 2."""
|
||||
|
||||
kv_ip: str = "127.0.0.1"
|
||||
"""The KV connector ip, used to build distributed connection."""
|
||||
|
||||
kv_port: int = 14579
|
||||
"""The KV connector port, used to build distributed connection."""
|
||||
|
||||
kv_connector_extra_config: dict[str, Any] = field(default_factory=dict)
|
||||
"""any extra config that the connector may need."""
|
||||
|
||||
kv_connector_module_path: str | None = None
|
||||
"""The Python module path to dynamically load the KV connector from.
|
||||
Only supported in V1."""
|
||||
|
||||
enable_permute_local_kv: bool = False
|
||||
"""Experiment feature flag to enable HND to NHD KV Transfer"""
|
||||
|
||||
kv_load_failure_policy: Literal["recompute", "fail"] = "fail"
|
||||
"""Policy for handling KV cache load failures.
|
||||
'recompute': reschedule the request to recompute failed blocks
|
||||
'fail': immediately fail the request with an error finish reason (default)"""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
ensure that it is included in the factors list if
|
||||
it affects the computation graph.
|
||||
|
||||
Provide a hash that uniquely identifies all the configs
|
||||
that affect the structure of the computation
|
||||
graph from input ids/embeddings to the final hidden states,
|
||||
excluding anything before input ids/embeddings and after
|
||||
the final hidden states.
|
||||
"""
|
||||
# no factors to consider.
|
||||
# this config will not affect the computation graph.
|
||||
factors: list[Any] = []
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.engine_id is None:
|
||||
self.engine_id = str(uuid.uuid4())
|
||||
|
||||
if self.kv_role is not None and self.kv_role not in get_args(KVRole):
|
||||
raise ValueError(
|
||||
f"Unsupported kv_role: {self.kv_role}. "
|
||||
f"Supported roles are {get_args(KVRole)}"
|
||||
)
|
||||
|
||||
if self.kv_connector is not None and self.kv_role is None:
|
||||
raise ValueError(
|
||||
"Please specify kv_role when kv_connector "
|
||||
f"is set, supported roles are {get_args(KVRole)}"
|
||||
)
|
||||
|
||||
@property
|
||||
def is_kv_transfer_instance(self) -> bool:
|
||||
return self.kv_connector is not None and self.kv_role in get_args(KVRole)
|
||||
|
||||
@property
|
||||
def is_kv_producer(self) -> bool:
|
||||
return self.kv_connector is not None and self.kv_role in get_args(KVProducer)
|
||||
|
||||
@property
|
||||
def is_kv_consumer(self) -> bool:
|
||||
return self.kv_connector is not None and self.kv_role in get_args(KVConsumer)
|
||||
|
||||
def get_from_extra_config(self, key, default) -> Any:
|
||||
return self.kv_connector_extra_config.get(key, default)
|
||||
122
vllm/config/load.py
Normal file
122
vllm/config/load.py
Normal file
@@ -0,0 +1,122 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.model_loader import LoadFormats
|
||||
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
||||
else:
|
||||
LoadFormats = Any
|
||||
TensorizerConfig = Any
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@config
|
||||
class LoadConfig:
|
||||
"""Configuration for loading the model weights."""
|
||||
|
||||
load_format: str | LoadFormats = "auto"
|
||||
"""The format of the model weights to load:\n
|
||||
- "auto" will try to load the weights in the safetensors format and fall
|
||||
back to the pytorch bin format if safetensors format is not available.\n
|
||||
- "pt" will load the weights in the pytorch bin format.\n
|
||||
- "safetensors" will load the weights in the safetensors format.\n
|
||||
- "npcache" will load the weights in pytorch format and store a numpy cache
|
||||
to speed up the loading.\n
|
||||
- "dummy" will initialize the weights with random values, which is mainly
|
||||
for profiling.\n
|
||||
- "tensorizer" will use CoreWeave's tensorizer library for fast weight
|
||||
loading. See the Tensorize vLLM Model script in the Examples section for
|
||||
more information.\n
|
||||
- "runai_streamer" will load the Safetensors weights using Run:ai Model
|
||||
Streamer.\n
|
||||
- "runai_streamer_sharded" will load weights from pre-sharded checkpoint
|
||||
files using Run:ai Model Streamer.\n
|
||||
- "bitsandbytes" will load the weights using bitsandbytes quantization.\n
|
||||
- "sharded_state" will load weights from pre-sharded checkpoint files,
|
||||
supporting efficient loading of tensor-parallel models.\n
|
||||
- "gguf" will load weights from GGUF format files (details specified in
|
||||
https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).\n
|
||||
- "mistral" will load weights from consolidated safetensors files used by
|
||||
Mistral models.
|
||||
- Other custom values can be supported via plugins."""
|
||||
download_dir: str | None = None
|
||||
"""Directory to download and load the weights, default to the default
|
||||
cache directory of Hugging Face."""
|
||||
safetensors_load_strategy: str = "lazy"
|
||||
"""Specifies the loading strategy for safetensors weights.
|
||||
- "lazy" (default): Weights are memory-mapped from the file. This enables
|
||||
on-demand loading and is highly efficient for models on local storage.
|
||||
- "eager": The entire file is read into CPU memory upfront before loading.
|
||||
This is recommended for models on network filesystems (e.g., Lustre, NFS)
|
||||
as it avoids inefficient random reads, significantly speeding up model
|
||||
initialization. However, it uses more CPU RAM.
|
||||
- "torchao": Weights are loaded in upfront and then reconstructed
|
||||
into torchao tensor subclasses. This is used when the checkpoint
|
||||
was quantized using torchao and saved using safetensors.
|
||||
Needs torchao >= 0.14.0
|
||||
"""
|
||||
model_loader_extra_config: dict | TensorizerConfig = Field(default_factory=dict)
|
||||
"""Extra config for model loader. This will be passed to the model loader
|
||||
corresponding to the chosen load_format."""
|
||||
device: str | None = None
|
||||
"""Device to which model weights will be loaded, default to
|
||||
device_config.device"""
|
||||
ignore_patterns: list[str] | str = Field(default_factory=lambda: ["original/**/*"])
|
||||
"""The list of patterns to ignore when loading the model. Default to
|
||||
"original/**/*" to avoid repeated loading of llama's checkpoints."""
|
||||
use_tqdm_on_load: bool = True
|
||||
"""Whether to enable tqdm for showing progress bar when loading model
|
||||
weights."""
|
||||
pt_load_map_location: str | dict[str, str] = "cpu"
|
||||
"""
|
||||
pt_load_map_location: the map location for loading pytorch checkpoint, to
|
||||
support loading checkpoints can only be loaded on certain devices like
|
||||
"cuda", this is equivalent to {"": "cuda"}. Another supported format is
|
||||
mapping from different devices like from GPU 1 to GPU 0:
|
||||
{"cuda:1": "cuda:0"}. Note that when passed from command line, the strings
|
||||
in dictionary needs to be double quoted for json parsing. For more details,
|
||||
see original doc for `map_location` in https://pytorch.org/docs/stable/generated/torch.load.html
|
||||
"""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
ensure that it is included in the factors list if
|
||||
it affects the computation graph.
|
||||
|
||||
Provide a hash that uniquely identifies all the configs
|
||||
that affect the structure of the computation
|
||||
graph from input ids/embeddings to the final hidden states,
|
||||
excluding anything before input ids/embeddings and after
|
||||
the final hidden states.
|
||||
"""
|
||||
# no factors to consider.
|
||||
# this config will not affect the computation graph.
|
||||
factors: list[Any] = []
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
@field_validator("load_format", mode="after")
|
||||
def _lowercase_load_format(cls, load_format: str) -> str:
|
||||
return load_format.lower()
|
||||
|
||||
@field_validator("ignore_patterns", mode="after")
|
||||
def _validate_ignore_patterns(
|
||||
cls, ignore_patterns: list[str] | str
|
||||
) -> list[str] | str:
|
||||
if ignore_patterns != ["original/**/*"] and len(ignore_patterns) > 0:
|
||||
logger.info(
|
||||
"Ignoring the following patterns when downloading weights: %s",
|
||||
ignore_patterns,
|
||||
)
|
||||
|
||||
return ignore_patterns
|
||||
107
vllm/config/lora.py
Normal file
107
vllm/config/lora.py
Normal file
@@ -0,0 +1,107 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
import torch
|
||||
from pydantic import ConfigDict, Field, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config.cache import CacheConfig
|
||||
else:
|
||||
ModelConfig = Any
|
||||
CacheConfig = Any
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
LoRADType = Literal["auto", "float16", "bfloat16"]
|
||||
MaxLoRARanks = Literal[1, 8, 16, 32, 64, 128, 256, 320, 512]
|
||||
LoRAExtraVocabSize = Literal[256, 512]
|
||||
|
||||
|
||||
@config(config=ConfigDict(arbitrary_types_allowed=True))
|
||||
class LoRAConfig:
|
||||
"""Configuration for LoRA."""
|
||||
|
||||
max_lora_rank: MaxLoRARanks = 16
|
||||
"""Max LoRA rank."""
|
||||
max_loras: int = Field(default=1, ge=1)
|
||||
"""Max number of LoRAs in a single batch."""
|
||||
fully_sharded_loras: bool = False
|
||||
"""By default, only half of the LoRA computation is sharded with tensor
|
||||
parallelism. Enabling this will use the fully sharded layers. At high
|
||||
sequence length, max rank or tensor parallel size, this is likely faster.
|
||||
"""
|
||||
max_cpu_loras: int | None = None
|
||||
"""Maximum number of LoRAs to store in CPU memory. Must be >= than
|
||||
`max_loras`."""
|
||||
lora_dtype: torch.dtype | LoRADType = "auto"
|
||||
"""Data type for LoRA. If auto, will default to base model dtype."""
|
||||
default_mm_loras: dict[str, str] | None = None
|
||||
"""Dictionary mapping specific modalities to LoRA model paths; this field
|
||||
is only applicable to multimodal models and should be leveraged when a
|
||||
model always expects a LoRA to be active when a given modality is present.
|
||||
Note that currently, if a request provides multiple additional
|
||||
modalities, each of which have their own LoRA, we do NOT apply
|
||||
default_mm_loras because we currently only support one lora adapter
|
||||
per prompt. When run in offline mode, the lora IDs for n modalities
|
||||
will be automatically assigned to 1-n with the names of the modalities
|
||||
in alphabetic order."""
|
||||
enable_tower_connector_lora: bool = False
|
||||
"""If `True`, LoRA support for the tower (vision encoder) and connector
|
||||
of multimodal models will be enabled. This is an experimental feature and
|
||||
currently only supports some MM models such as the Qwen VL series. The default
|
||||
is False."""
|
||||
specialize_active_lora: bool = False
|
||||
"""Whether to construct lora kernel grid by the number of active LoRA adapters.
|
||||
When set to True, separate cuda graphs will be captured for different counts
|
||||
of active LoRAs (powers of 2 up to max_loras), which can improve performance
|
||||
for variable LoRA usage patterns at the cost of increased startup time and
|
||||
memory usage. Only takes effect when cudagraph_specialize_lora is True.
|
||||
"""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
ensure that it is included in the factors list if
|
||||
it affects the computation graph.
|
||||
|
||||
Provide a hash that uniquely identifies all the configs
|
||||
that affect the structure of the computation
|
||||
graph from input ids/embeddings to the final hidden states,
|
||||
excluding anything before input ids/embeddings and after
|
||||
the final hidden states.
|
||||
"""
|
||||
factors: list[Any] = []
|
||||
factors.append(self.max_lora_rank)
|
||||
factors.append(self.max_loras)
|
||||
factors.append(self.fully_sharded_loras)
|
||||
factors.append(self.lora_dtype)
|
||||
factors.append(self.enable_tower_connector_lora)
|
||||
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_lora_config(self) -> Self:
|
||||
if self.max_cpu_loras is None:
|
||||
self.max_cpu_loras = self.max_loras
|
||||
elif self.max_cpu_loras < self.max_loras:
|
||||
raise ValueError(
|
||||
f"max_cpu_loras ({self.max_cpu_loras}) must be >= "
|
||||
f"max_loras ({self.max_loras})."
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def verify_with_model_config(self, model_config: ModelConfig):
|
||||
if self.lora_dtype in (None, "auto"):
|
||||
self.lora_dtype = model_config.dtype
|
||||
elif isinstance(self.lora_dtype, str):
|
||||
self.lora_dtype = getattr(torch, self.lora_dtype)
|
||||
2056
vllm/config/model.py
Normal file
2056
vllm/config/model.py
Normal file
File diff suppressed because it is too large
Load Diff
57
vllm/config/model_arch.py
Normal file
57
vllm/config/model_arch.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any
|
||||
|
||||
from pydantic import ConfigDict
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
|
||||
class ModelArchitectureConfig:
|
||||
"""
|
||||
Configuration for model architecture that required by vLLM runtime
|
||||
"""
|
||||
|
||||
architectures: list[str] | None
|
||||
"""List of model architecture class names (e.g., ['LlamaForCausalLM']).
|
||||
It can be None upon calling `vllm_config.with_hf_config(config.text_config)`"""
|
||||
|
||||
model_type: str
|
||||
"""Model type identifier (e.g., 'llama', 'gpt_oss')."""
|
||||
|
||||
text_model_type: str | None
|
||||
"""Text model type identifier (e.g., 'llama4_text')."""
|
||||
|
||||
hidden_size: int
|
||||
"""Hidden size of the model."""
|
||||
|
||||
total_num_hidden_layers: int
|
||||
"""Number of hidden layers in the model."""
|
||||
|
||||
total_num_attention_heads: int
|
||||
"""Number of attention heads in the model."""
|
||||
|
||||
head_size: int
|
||||
"""Head dimension of the model."""
|
||||
|
||||
vocab_size: int
|
||||
"""Vocabulary size of the model."""
|
||||
|
||||
total_num_kv_heads: int
|
||||
"""Number of key value heads in the model."""
|
||||
|
||||
num_experts: int
|
||||
"""Number of experts in the model."""
|
||||
|
||||
quantization_config: dict[str, Any] | None
|
||||
"""Quantization configuration dictionary containing quantization parameters."""
|
||||
|
||||
is_deepseek_mla: bool
|
||||
"""Whether the model is a DeepSeek MLA model."""
|
||||
|
||||
derived_max_model_len_and_key: tuple[float, str | None]
|
||||
"""Derived maximum model length and key from the hf config."""
|
||||
281
vllm/config/multimodal.py
Normal file
281
vllm/config/multimodal.py
Normal file
@@ -0,0 +1,281 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Literal, TypeAlias, TypedDict, final
|
||||
|
||||
from pydantic import ConfigDict, Field, field_validator, model_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.utils.hashing import safe_hash
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseDummyOptions:
|
||||
"""Base options for generating dummy data during profiling."""
|
||||
|
||||
count: int = Field(999, ge=0)
|
||||
|
||||
|
||||
@dataclass(config=ConfigDict(extra="forbid"))
|
||||
class VideoDummyOptions(BaseDummyOptions):
|
||||
"""Options for generating dummy video data during profiling."""
|
||||
|
||||
num_frames: int | None = Field(None, gt=0)
|
||||
width: int | None = Field(None, gt=0)
|
||||
height: int | None = Field(None, gt=0)
|
||||
|
||||
|
||||
@dataclass(config=ConfigDict(extra="forbid"))
|
||||
class ImageDummyOptions(BaseDummyOptions):
|
||||
"""Options for generating dummy image data during profiling."""
|
||||
|
||||
width: int | None = Field(None, gt=0)
|
||||
height: int | None = Field(None, gt=0)
|
||||
|
||||
|
||||
@dataclass(config=ConfigDict(extra="forbid"))
|
||||
class AudioDummyOptions(BaseDummyOptions):
|
||||
"""Options for generating dummy audio data during profiling."""
|
||||
|
||||
length: int | None = Field(None, gt=0)
|
||||
|
||||
|
||||
@final
|
||||
class MultiModalDummyOptionsBuiltins(TypedDict, total=False):
|
||||
"""Type annotations for modality types predefined by vLLM."""
|
||||
|
||||
image: ImageDummyOptions
|
||||
"""Options for dummy images."""
|
||||
|
||||
video: VideoDummyOptions
|
||||
"""Options for dummy videos."""
|
||||
|
||||
audio: AudioDummyOptions
|
||||
"""Options for dummy audios."""
|
||||
|
||||
|
||||
MMEncoderTPMode = Literal["weights", "data"]
|
||||
MMCacheType = Literal["shm", "lru"]
|
||||
MMDummyOptions: TypeAlias = dict[str, BaseDummyOptions]
|
||||
"""
|
||||
A dictionary containing an entry for each modality type of dummy data.
|
||||
|
||||
The built-in modalities are defined by
|
||||
[`MultiModalDummyOptionsBuiltins`][vllm.config.multimodal.MultiModalDummyOptionsBuiltins].
|
||||
"""
|
||||
|
||||
|
||||
@config
|
||||
class MultiModalConfig:
|
||||
"""Controls the behavior of multimodal models."""
|
||||
|
||||
language_model_only: bool = False
|
||||
"""If True, disables all multimodal inputs by setting all modality limits to 0.
|
||||
Equivalent to setting `--limit-mm-per-prompt` to 0 for every modality."""
|
||||
limit_per_prompt: MMDummyOptions = Field(default_factory=dict)
|
||||
"""The maximum number of input items and options allowed per
|
||||
prompt for each modality.
|
||||
|
||||
Defaults to 999 for each modality.
|
||||
|
||||
Legacy format (count only):
|
||||
{"image": 16, "video": 2}
|
||||
|
||||
Configurable format (with options):
|
||||
{"video": {"count": 1, "num_frames": 32, "width": 512, "height": 512},
|
||||
"image": {"count": 5, "width": 512, "height": 512}}
|
||||
|
||||
Mixed format (combining both):
|
||||
{"image": 16, "video": {"count": 1, "num_frames": 32, "width": 512,
|
||||
"height": 512}}
|
||||
"""
|
||||
enable_mm_embeds: bool = False
|
||||
"""If `True`, enables passing multimodal embeddings:
|
||||
for `LLM` class, this refers to tensor inputs under `multi_modal_data`;
|
||||
for the OpenAI-compatible server, this refers to chat messages with content
|
||||
`"type": "*_embeds"`.
|
||||
|
||||
When enabled with `--limit-mm-per-prompt` set to 0 for a modality,
|
||||
precomputed embeddings skip count validation for that modality,
|
||||
saving memory by not loading encoder modules while still enabling
|
||||
embeddings as an input. Limits greater than 0 still apply to embeddings.
|
||||
|
||||
WARNING: The vLLM engine may crash if incorrect shape of embeddings is passed.
|
||||
Only enable this flag for trusted users!"""
|
||||
media_io_kwargs: dict[str, dict[str, Any]] = Field(default_factory=dict)
|
||||
"""Additional args passed to process media inputs, keyed by modalities.
|
||||
For example, to set num_frames for video, set
|
||||
`--media-io-kwargs '{"video": {"num_frames": 40} }'`"""
|
||||
mm_processor_kwargs: dict[str, object] | None = None
|
||||
"""Arguments to be forwarded to the model's processor for multi-modal data,
|
||||
e.g., image processor. Overrides for the multi-modal processor obtained
|
||||
from `transformers.AutoProcessor.from_pretrained`.
|
||||
|
||||
The available overrides depend on the model that is being run.
|
||||
|
||||
For example, for Phi-3-Vision:
|
||||
`{"num_crops": 4}`."""
|
||||
mm_processor_cache_gb: float = Field(default=4, ge=0)
|
||||
"""The size (in GiB) of the multi-modal processor cache, which is used to
|
||||
avoid re-processing past multi-modal inputs.
|
||||
|
||||
This cache is duplicated for each API process and engine core process,
|
||||
resulting in a total memory usage of
|
||||
`mm_processor_cache_gb * (api_server_count + data_parallel_size)`.
|
||||
|
||||
Set to `0` to disable this cache completely (not recommended)."""
|
||||
mm_processor_cache_type: MMCacheType = "lru"
|
||||
"""Type of cache to use for the multi-modal preprocessor/mapper. If `shm`,
|
||||
use shared memory FIFO cache. If `lru`, use mirrored LRU cache."""
|
||||
mm_shm_cache_max_object_size_mb: int = Field(default=128, ge=0)
|
||||
"""Size limit (in MiB) for each object stored in the multi-modal processor
|
||||
shared memory cache. Only effective when `mm_processor_cache_type` is
|
||||
`"shm"`."""
|
||||
mm_encoder_only: bool = False
|
||||
"""
|
||||
When enabled, skips the language component of the model.
|
||||
|
||||
This is usually only valid in disaggregated Encoder process.
|
||||
"""
|
||||
mm_encoder_tp_mode: MMEncoderTPMode = "weights"
|
||||
"""Indicates how to optimize multi-modal encoder inference using tensor
|
||||
parallelism (TP).
|
||||
|
||||
- `"weights"`: Within the same vLLM engine, split the weights of
|
||||
each layer across TP ranks. (default TP behavior)\n
|
||||
- `"data"`: Within the same vLLM engine, split the batched input data
|
||||
across TP ranks to process the data in parallel, while hosting
|
||||
the full weights on each TP rank.
|
||||
This batch-level DP is not to be confused with API request-level
|
||||
DP (which is controlled by `--data-parallel-size`).
|
||||
This is only supported on a per-model basis and falls back to
|
||||
`"weights"` if the encoder does not support DP."""
|
||||
mm_encoder_attn_backend: AttentionBackendEnum | None = None
|
||||
"""Optional override for the multi-modal encoder attention backend when
|
||||
using vision transformers. Accepts any value from
|
||||
`vllm.v1.attention.backends.registry.AttentionBackendEnum` (e.g. `FLASH_ATTN`)."""
|
||||
interleave_mm_strings: bool = False
|
||||
"""Enable fully interleaved support for multimodal prompts, while using
|
||||
--chat-template-content-format=string."""
|
||||
skip_mm_profiling: bool = False
|
||||
"""When enabled, skips multimodal memory profiling and only profiles with
|
||||
language backbone model during engine initialization.
|
||||
|
||||
This reduces engine startup time but shifts the responsibility to users for
|
||||
estimating the peak memory usage of the activation of multimodal encoder and
|
||||
embedding cache."""
|
||||
video_pruning_rate: float | None = Field(default=None, ge=0.0, lt=1.0)
|
||||
"""Sets pruning rate for video pruning via Efficient Video Sampling.
|
||||
Value sits in range [0;1) and determines fraction of media tokens
|
||||
from each video to be pruned.
|
||||
"""
|
||||
|
||||
@field_validator("limit_per_prompt", mode="before")
|
||||
@classmethod
|
||||
def _validate_limit_per_prompt(
|
||||
cls,
|
||||
value: dict[str, int | dict[str, int]],
|
||||
) -> MMDummyOptions:
|
||||
out: MMDummyOptions = {}
|
||||
|
||||
for k, v in value.items():
|
||||
# Handle legacy format where only count is specified
|
||||
if isinstance(v, int):
|
||||
v = {"count": v}
|
||||
|
||||
# Convert to the appropriate DummyOptions subclass
|
||||
if k == "video":
|
||||
out[k] = VideoDummyOptions(**v)
|
||||
elif k == "image":
|
||||
out[k] = ImageDummyOptions(**v)
|
||||
elif k == "audio":
|
||||
out[k] = AudioDummyOptions(**v)
|
||||
else:
|
||||
out[k] = BaseDummyOptions(**v)
|
||||
|
||||
return out
|
||||
|
||||
@field_validator("mm_encoder_attn_backend", mode="before")
|
||||
@classmethod
|
||||
def _validate_mm_encoder_attn_backend(
|
||||
cls, value: str | AttentionBackendEnum | None
|
||||
) -> AttentionBackendEnum | None:
|
||||
if isinstance(value, str) and value.upper() == "XFORMERS":
|
||||
raise ValueError(
|
||||
"Attention backend 'XFORMERS' has been removed (See PR #29262 for "
|
||||
"details). Please select a supported attention backend."
|
||||
)
|
||||
|
||||
if value is None or isinstance(value, AttentionBackendEnum):
|
||||
return value
|
||||
|
||||
assert isinstance(value, str), (
|
||||
"mm_encoder_attn_backend must be a string or an AttentionBackendEnum."
|
||||
)
|
||||
return AttentionBackendEnum[value.upper()]
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_multimodal_config(self):
|
||||
if self.mm_processor_cache_type != "shm" and (
|
||||
self.mm_shm_cache_max_object_size_mb
|
||||
!= MultiModalConfig.mm_shm_cache_max_object_size_mb
|
||||
):
|
||||
raise ValueError(
|
||||
"'mm_shm_cache_max_object_size_mb' should only be set when "
|
||||
"'mm_processor_cache_type' is 'shm'."
|
||||
)
|
||||
return self
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
ensure that it is included in the factors list if
|
||||
it affects the computation graph.
|
||||
|
||||
Provide a hash that uniquely identifies all the configs
|
||||
that affect the structure of the computation
|
||||
graph from input ids/embeddings to the final hidden states,
|
||||
excluding anything before input ids/embeddings and after
|
||||
the final hidden states.
|
||||
"""
|
||||
factors: list[Any] = [
|
||||
self.mm_encoder_attn_backend.name
|
||||
if self.mm_encoder_attn_backend is not None
|
||||
else None,
|
||||
self.mm_encoder_tp_mode,
|
||||
]
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
def get_limit_per_prompt(self, modality: str) -> int:
|
||||
"""
|
||||
Get the maximum number of input items allowed per prompt
|
||||
for the given modality (backward compatible).
|
||||
"""
|
||||
if self.language_model_only:
|
||||
return 0
|
||||
|
||||
limit_data = self.limit_per_prompt.get(modality)
|
||||
|
||||
if limit_data is None:
|
||||
# Unspecified modality is set to 999 by default
|
||||
return 999
|
||||
|
||||
return limit_data.count
|
||||
|
||||
def merge_mm_processor_kwargs(
|
||||
self,
|
||||
inference_kwargs: Mapping[str, object],
|
||||
) -> dict[str, object]:
|
||||
"""
|
||||
Get the keyword arguments to pass to the multi-modal processor
|
||||
according to the extra arguments passed during inference.
|
||||
"""
|
||||
kwargs = self.mm_processor_kwargs or {}
|
||||
return kwargs | dict(inference_kwargs)
|
||||
|
||||
def is_multimodal_pruning_enabled(self):
|
||||
return self.video_pruning_rate is not None and self.video_pruning_rate > 0
|
||||
152
vllm/config/observability.py
Normal file
152
vllm/config/observability.py
Normal file
@@ -0,0 +1,152 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from functools import cached_property
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from packaging.version import parse
|
||||
from pydantic import Field, field_validator, model_validator
|
||||
|
||||
from vllm import version
|
||||
from vllm.config.utils import config
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
DetailedTraceModules = Literal["model", "worker", "all"]
|
||||
|
||||
|
||||
@config
|
||||
class ObservabilityConfig:
|
||||
"""Configuration for observability - metrics and tracing."""
|
||||
|
||||
show_hidden_metrics_for_version: str | None = None
|
||||
"""Enable deprecated Prometheus metrics that have been hidden since the
|
||||
specified version. For example, if a previously deprecated metric has been
|
||||
hidden since the v0.7.0 release, you use
|
||||
`--show-hidden-metrics-for-version=0.7` as a temporary escape hatch while
|
||||
you migrate to new metrics. The metric is likely to be removed completely
|
||||
in an upcoming release."""
|
||||
|
||||
@cached_property
|
||||
def show_hidden_metrics(self) -> bool:
|
||||
"""Check if the hidden metrics should be shown."""
|
||||
if self.show_hidden_metrics_for_version is None:
|
||||
return False
|
||||
return version._prev_minor_version_was(self.show_hidden_metrics_for_version)
|
||||
|
||||
otlp_traces_endpoint: str | None = None
|
||||
"""Target URL to which OpenTelemetry traces will be sent."""
|
||||
|
||||
collect_detailed_traces: list[DetailedTraceModules] | None = None
|
||||
"""It makes sense to set this only if `--otlp-traces-endpoint` is set. If
|
||||
set, it will collect detailed traces for the specified modules. This
|
||||
involves use of possibly costly and or blocking operations and hence might
|
||||
have a performance impact.
|
||||
|
||||
Note that collecting detailed timing information for each request can be
|
||||
expensive."""
|
||||
|
||||
kv_cache_metrics: bool = False
|
||||
"""Enable KV cache residency metrics (lifetime, idle time, reuse gaps).
|
||||
Uses sampling to minimize overhead.
|
||||
Requires log stats to be enabled (i.e., --disable-log-stats not set)."""
|
||||
|
||||
kv_cache_metrics_sample: float = Field(default=0.01, gt=0, le=1)
|
||||
"""Sampling rate for KV cache metrics (0.0, 1.0]. Default 0.01 = 1% of blocks."""
|
||||
|
||||
cudagraph_metrics: bool = False
|
||||
"""Enable CUDA graph metrics (number of padded/unpadded tokens, runtime cudagraph
|
||||
dispatch modes, and their observed frequencies at every logging interval)."""
|
||||
|
||||
enable_layerwise_nvtx_tracing: bool = False
|
||||
"""Enable layerwise NVTX tracing. This traces the execution of each layer or
|
||||
module in the model and attach informations such as input/output shapes to
|
||||
nvtx range markers. Noted that this doesn't work with CUDA graphs enabled."""
|
||||
|
||||
enable_mfu_metrics: bool = False
|
||||
"""Enable Model FLOPs Utilization (MFU) metrics."""
|
||||
|
||||
enable_mm_processor_stats: bool = False
|
||||
"""Enable collection of timing statistics for multimodal processor operations.
|
||||
This is for internal use only (e.g., benchmarks) and is not exposed as a CLI
|
||||
argument."""
|
||||
|
||||
enable_logging_iteration_details: bool = False
|
||||
"""Enable detailed logging of iteration details.
|
||||
If set, vllm EngineCore will log iteration details
|
||||
This includes number of context/generation requests and tokens
|
||||
and the elapsed cpu time for the iteration."""
|
||||
|
||||
@cached_property
|
||||
def collect_model_forward_time(self) -> bool:
|
||||
"""Whether to collect model forward time for the request."""
|
||||
return self.collect_detailed_traces is not None and (
|
||||
"model" in self.collect_detailed_traces
|
||||
or "all" in self.collect_detailed_traces
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def collect_model_execute_time(self) -> bool:
|
||||
"""Whether to collect model execute time for the request."""
|
||||
return self.collect_detailed_traces is not None and (
|
||||
"worker" in self.collect_detailed_traces
|
||||
or "all" in self.collect_detailed_traces
|
||||
)
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
ensure that it is included in the factors list if
|
||||
it affects the computation graph.
|
||||
|
||||
Provide a hash that uniquely identifies all the configs
|
||||
that affect the structure of the computation
|
||||
graph from input ids/embeddings to the final hidden states,
|
||||
excluding anything before input ids/embeddings and after
|
||||
the final hidden states.
|
||||
"""
|
||||
# no factors to consider.
|
||||
# this config will not affect the computation graph.
|
||||
factors: list[Any] = []
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
@field_validator("show_hidden_metrics_for_version")
|
||||
@classmethod
|
||||
def _validate_show_hidden_metrics_for_version(cls, value: str | None) -> str | None:
|
||||
if value is not None:
|
||||
# Raises an exception if the string is not a valid version.
|
||||
parse(value)
|
||||
return value
|
||||
|
||||
@field_validator("otlp_traces_endpoint")
|
||||
@classmethod
|
||||
def _validate_otlp_traces_endpoint(cls, value: str | None) -> str | None:
|
||||
if value is not None:
|
||||
from vllm.tracing import is_tracing_available, otel_import_error_traceback
|
||||
|
||||
if not is_tracing_available():
|
||||
raise ValueError(
|
||||
"OpenTelemetry is not available. Unable to configure "
|
||||
"'otlp_traces_endpoint'. Ensure OpenTelemetry packages are "
|
||||
f"installed. Original error:\n{otel_import_error_traceback}"
|
||||
)
|
||||
return value
|
||||
|
||||
@field_validator("collect_detailed_traces")
|
||||
@classmethod
|
||||
def _validate_collect_detailed_traces(
|
||||
cls, value: list[DetailedTraceModules] | None
|
||||
) -> list[DetailedTraceModules] | None:
|
||||
"""Handle the legacy case where users might provide a comma-separated
|
||||
string instead of a list of strings."""
|
||||
if value is not None and len(value) == 1 and "," in value[0]:
|
||||
value = cast(list[DetailedTraceModules], value[0].split(","))
|
||||
return value
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_tracing_config(self):
|
||||
if self.collect_detailed_traces and not self.otlp_traces_endpoint:
|
||||
raise ValueError(
|
||||
"collect_detailed_traces requires `--otlp-traces-endpoint` to be set."
|
||||
)
|
||||
return self
|
||||
153
vllm/config/offload.py
Normal file
153
vllm/config/offload.py
Normal file
@@ -0,0 +1,153 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Configuration for model weight offloading."""
|
||||
|
||||
import warnings
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field, model_validator
|
||||
|
||||
from vllm.config.utils import config
|
||||
|
||||
OffloadBackend = Literal["auto", "uva", "prefetch"]
|
||||
|
||||
|
||||
@config
|
||||
class UVAOffloadConfig:
|
||||
"""Configuration for UVA (Unified Virtual Addressing) CPU offloading.
|
||||
|
||||
Uses zero-copy access from CPU-pinned memory. Simple but requires
|
||||
fast CPU-GPU interconnect.
|
||||
"""
|
||||
|
||||
cpu_offload_gb: float = Field(default=0, ge=0)
|
||||
"""The space in GiB to offload to CPU, per GPU. Default is 0, which means
|
||||
no offloading. Intuitively, this argument can be seen as a virtual way to
|
||||
increase the GPU memory size. For example, if you have one 24 GB GPU and
|
||||
set this to 10, virtually you can think of it as a 34 GB GPU. Then you can
|
||||
load a 13B model with BF16 weight, which requires at least 26GB GPU memory.
|
||||
Note that this requires fast CPU-GPU interconnect, as part of the model is
|
||||
loaded from CPU memory to GPU memory on the fly in each model forward pass.
|
||||
This uses UVA (Unified Virtual Addressing) for zero-copy access.
|
||||
"""
|
||||
|
||||
cpu_offload_params: set[str] = Field(default_factory=set)
|
||||
"""The set of parameter name segments to target for CPU offloading.
|
||||
Unmatched parameters are not offloaded. If this set is empty, parameters
|
||||
are offloaded non-selectively until the memory limit defined by
|
||||
`cpu_offload_gb` is reached.
|
||||
Examples:
|
||||
- For parameter name "mlp.experts.w2_weight":
|
||||
- "experts" or "experts.w2_weight" will match.
|
||||
- "expert" or "w2" will NOT match (must be exact segments).
|
||||
This allows distinguishing parameters like "w2_weight" and "w2_weight_scale".
|
||||
"""
|
||||
|
||||
|
||||
@config
|
||||
class PrefetchOffloadConfig:
|
||||
"""Configuration for prefetch-based CPU offloading.
|
||||
|
||||
Groups layers and uses async H2D prefetch to hide transfer latency.
|
||||
"""
|
||||
|
||||
offload_group_size: int = Field(default=0, ge=0)
|
||||
"""Group every N layers together. Offload last `offload_num_in_group`
|
||||
layers of each group. Default is 0 (disabled).
|
||||
Example: group_size=8, num_in_group=2 offloads layers 6,7,14,15,22,23,...
|
||||
Unlike cpu_offload_gb, this uses explicit async prefetching to hide transfer
|
||||
latency.
|
||||
"""
|
||||
|
||||
offload_num_in_group: int = Field(default=1, ge=1)
|
||||
"""Number of layers to offload per group.
|
||||
Must be <= offload_group_size. Default is 1."""
|
||||
|
||||
offload_prefetch_step: int = Field(default=1, ge=0)
|
||||
"""Number of layers to prefetch ahead.
|
||||
Higher values hide more latency but use more GPU memory. Default is 1."""
|
||||
|
||||
offload_params: set[str] = Field(default_factory=set)
|
||||
"""The set of parameter name segments to target for prefetch offloading.
|
||||
Unmatched parameters are not offloaded. If this set is empty, ALL
|
||||
parameters of each offloaded layer are offloaded.
|
||||
Uses segment matching: "w13_weight" matches "mlp.experts.w13_weight"
|
||||
but not "mlp.experts.w13_weight_scale".
|
||||
"""
|
||||
|
||||
|
||||
@config
|
||||
class OffloadConfig:
|
||||
"""Configuration for model weight offloading to reduce GPU memory usage."""
|
||||
|
||||
offload_backend: OffloadBackend = "auto"
|
||||
"""The backend for weight offloading. Options:
|
||||
- "auto": Selects based on which sub-config has non-default values
|
||||
(prefetch if offload_group_size > 0, uva if cpu_offload_gb > 0).
|
||||
- "uva": UVA (Unified Virtual Addressing) zero-copy offloading.
|
||||
- "prefetch": Async prefetch with group-based layer offloading.
|
||||
"""
|
||||
|
||||
uva: UVAOffloadConfig = Field(default_factory=UVAOffloadConfig)
|
||||
"""Parameters for UVA offloading backend."""
|
||||
|
||||
prefetch: PrefetchOffloadConfig = Field(default_factory=PrefetchOffloadConfig)
|
||||
"""Parameters for prefetch offloading backend."""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_offload_config(self) -> "OffloadConfig":
|
||||
"""Validate offload configuration constraints."""
|
||||
if self.offload_backend == "prefetch" or self.prefetch.offload_group_size > 0:
|
||||
if self.prefetch.offload_num_in_group > self.prefetch.offload_group_size:
|
||||
raise ValueError(
|
||||
f"offload_num_in_group ({self.prefetch.offload_num_in_group})"
|
||||
f" must be <= offload_group_size"
|
||||
f" ({self.prefetch.offload_group_size})"
|
||||
)
|
||||
if self.prefetch.offload_prefetch_step < 1:
|
||||
raise ValueError(
|
||||
f"offload_prefetch_step"
|
||||
f" ({self.prefetch.offload_prefetch_step})"
|
||||
f" must be >= 1 when prefetch offloading is enabled"
|
||||
f" (offload_group_size > 0)"
|
||||
)
|
||||
|
||||
# Warn if both backends have non-default values
|
||||
uva_active = self.uva.cpu_offload_gb > 0
|
||||
prefetch_active = self.prefetch.offload_group_size > 0
|
||||
if self.offload_backend == "uva" and prefetch_active:
|
||||
warnings.warn(
|
||||
"Prefetch offload fields are set but offload_backend='uva'. "
|
||||
"Prefetch settings will be ignored.",
|
||||
stacklevel=2,
|
||||
)
|
||||
elif self.offload_backend == "prefetch" and uva_active:
|
||||
warnings.warn(
|
||||
"UVA offload fields are set but offload_backend='prefetch'. "
|
||||
"UVA settings will be ignored.",
|
||||
stacklevel=2,
|
||||
)
|
||||
elif self.offload_backend == "auto" and uva_active and prefetch_active:
|
||||
warnings.warn(
|
||||
"Both UVA and prefetch offload fields are set with "
|
||||
"offload_backend='auto'. Prefetch backend will be selected. "
|
||||
"Set offload_backend explicitly to suppress this warning.",
|
||||
stacklevel=2,
|
||||
)
|
||||
return self
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
Provide a hash that uniquely identifies all the offload configs.
|
||||
|
||||
All fields are included because PrefetchOffloader patches module
|
||||
forwards and inserts custom ops (wait_prefetch, start_prefetch)
|
||||
into the computation graph. Changing any offload setting can
|
||||
alter which layers are hooked and how prefetch indices are
|
||||
computed, so the compilation cache must distinguish them.
|
||||
"""
|
||||
from vllm.config.utils import get_hash_factors, hash_factors
|
||||
|
||||
factors = get_hash_factors(self, ignored_factors=set())
|
||||
hash_str = hash_factors(factors)
|
||||
return hash_str
|
||||
713
vllm/config/parallel.py
Normal file
713
vllm/config/parallel.py
Normal file
@@ -0,0 +1,713 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
import torch
|
||||
from pydantic import Field, field_validator, model_validator
|
||||
from torch.distributed import ProcessGroup, ReduceOp
|
||||
from typing_extensions import Self
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config.utils import config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.network_utils import get_open_ports_list
|
||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.runtime_env import RuntimeEnv
|
||||
from ray.util.placement_group import PlacementGroup
|
||||
|
||||
from vllm.v1.executor import Executor
|
||||
else:
|
||||
RuntimeEnv = Any
|
||||
PlacementGroup = Any
|
||||
Executor = Any
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
ExpertPlacementStrategy = Literal["linear", "round_robin"]
|
||||
DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"]
|
||||
DataParallelBackend = Literal["ray", "mp"]
|
||||
EPLBPolicyOption = Literal["default"]
|
||||
All2AllBackend = Literal[
|
||||
"naive",
|
||||
"pplx",
|
||||
"deepep_high_throughput",
|
||||
"deepep_low_latency",
|
||||
"mori",
|
||||
"allgather_reducescatter",
|
||||
"flashinfer_all2allv",
|
||||
]
|
||||
|
||||
|
||||
@config
|
||||
class EPLBConfig:
|
||||
"""Configuration for Expert Parallel Load Balancing (EP)."""
|
||||
|
||||
window_size: int = 1000
|
||||
"""Window size for expert load recording."""
|
||||
step_interval: int = 3000
|
||||
"""
|
||||
Interval for rearranging experts in expert parallelism.
|
||||
|
||||
Note that if this is greater than the EPLB window size, only the metrics
|
||||
of the last `lb_window_size` steps will be used for rearranging experts.
|
||||
"""
|
||||
|
||||
num_redundant_experts: int = Field(default=0, ge=0)
|
||||
"""Number of redundant experts to use for expert parallelism."""
|
||||
|
||||
log_balancedness: bool = False
|
||||
"""
|
||||
Log the balancedness each step of expert parallelism.
|
||||
This is turned off by default since it will cause communication overhead.
|
||||
"""
|
||||
log_balancedness_interval: int = 1
|
||||
"""
|
||||
Interval for logging the balancedness.
|
||||
"""
|
||||
use_async: bool = False
|
||||
"""
|
||||
Whether to use non-blocking EPLB.
|
||||
"""
|
||||
|
||||
policy: EPLBPolicyOption = "default"
|
||||
"""The policy type for expert parallel load balancing (EPLB)."""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_eplb_config(self) -> Self:
|
||||
if self.use_async and self.policy != "default":
|
||||
raise ValueError("Async EPLB is only supported with the default policy.")
|
||||
if self.log_balancedness and self.log_balancedness_interval <= 0:
|
||||
raise ValueError("log_balancedness_interval must be greater than 0.")
|
||||
return self
|
||||
|
||||
|
||||
@config
|
||||
class ParallelConfig:
|
||||
"""Configuration for the distributed execution."""
|
||||
|
||||
pipeline_parallel_size: int = 1
|
||||
"""Number of pipeline parallel groups."""
|
||||
tensor_parallel_size: int = 1
|
||||
"""Number of tensor parallel groups."""
|
||||
prefill_context_parallel_size: int = 1
|
||||
"""Number of prefill context parallel groups."""
|
||||
data_parallel_size: int = 1
|
||||
"""Number of data parallel groups. MoE layers will be sharded according to
|
||||
the product of the tensor parallel size and data parallel size."""
|
||||
data_parallel_size_local: int = 1
|
||||
"""Number of local data parallel groups."""
|
||||
data_parallel_rank: int = 0
|
||||
"""Rank of the data parallel group."""
|
||||
data_parallel_rank_local: int | None = None
|
||||
"""Local rank of the data parallel group,
|
||||
set only in SPMD mode."""
|
||||
data_parallel_master_ip: str = "127.0.0.1"
|
||||
"""IP of the data parallel master."""
|
||||
data_parallel_rpc_port: int = 29550
|
||||
"""Port for data parallel messaging."""
|
||||
data_parallel_master_port: int = 29500
|
||||
"""Port of the data parallel master."""
|
||||
data_parallel_backend: DataParallelBackend = "mp"
|
||||
"""Backend to use for data parallel, either "mp" or "ray"."""
|
||||
data_parallel_external_lb: bool = False
|
||||
"""Whether to use "external" DP LB mode. Applies only to online serving
|
||||
and when data_parallel_size > 0. This is useful for a "one-pod-per-rank"
|
||||
wide-EP setup in Kubernetes. Set implicitly when --data-parallel-rank
|
||||
is provided explicitly to vllm serve."""
|
||||
data_parallel_hybrid_lb: bool = False
|
||||
"""Whether to use "hybrid" DP LB mode. Applies only to online serving
|
||||
and when data_parallel_size > 0. Enables running an AsyncLLM
|
||||
and API server on a "per-node" basis where vLLM load balances
|
||||
between local data parallel ranks, but an external LB balances
|
||||
between vLLM nodes/replicas. Set explicitly in conjunction with
|
||||
--data-parallel-start-rank."""
|
||||
is_moe_model: bool | None = None
|
||||
"""Whether the deployed model is MoE (if known)."""
|
||||
enable_expert_parallel: bool = False
|
||||
"""Use expert parallelism instead of tensor parallelism for MoE layers."""
|
||||
enable_eplb: bool = False
|
||||
"""Enable expert parallelism load balancing for MoE layers."""
|
||||
eplb_config: EPLBConfig = Field(default_factory=EPLBConfig)
|
||||
"""Expert parallelism configuration."""
|
||||
expert_placement_strategy: ExpertPlacementStrategy = "linear"
|
||||
"""The expert placement strategy for MoE layers:\n
|
||||
- "linear": Experts are placed in a contiguous manner. For example, with 4
|
||||
experts and 2 ranks, rank 0 will have experts [0, 1] and rank 1 will have
|
||||
experts [2, 3].\n
|
||||
- "round_robin": Experts are placed in a round-robin manner. For example,
|
||||
with 4 experts and 2 ranks, rank 0 will have experts [0, 2] and rank 1
|
||||
will have experts [1, 3]. This strategy can help improve load balancing
|
||||
for grouped expert models with no redundant experts."""
|
||||
all2all_backend: All2AllBackend = "allgather_reducescatter"
|
||||
"""All2All backend for MoE expert parallel communication. Available options:
|
||||
|
||||
- "naive": Naive all2all implementation using broadcasts\n
|
||||
- "allgather_reducescatter": All2all based on allgather and reducescatter\n
|
||||
- "pplx": Use pplx kernels\n
|
||||
- "deepep_high_throughput": Use deepep high-throughput kernels\n
|
||||
- "deepep_low_latency": Use deepep low-latency kernels\n
|
||||
- "mori": Use mori kernels\n
|
||||
- "flashinfer_all2allv": Use flashinfer alltoallv kernels for mnnvl"""
|
||||
|
||||
max_parallel_loading_workers: int | None = None
|
||||
"""Maximum number of parallel loading workers when loading model
|
||||
sequentially in multiple batches. To avoid RAM OOM when using tensor
|
||||
parallel and large models."""
|
||||
|
||||
disable_custom_all_reduce: bool = False
|
||||
"""Disable the custom all-reduce kernel and fall back to NCCL."""
|
||||
|
||||
enable_dbo: bool = False
|
||||
"""Enable dual batch overlap for the model executor."""
|
||||
ubatch_size: int = 0
|
||||
"""Number of ubatch size."""
|
||||
|
||||
dbo_decode_token_threshold: int = 32
|
||||
"""The threshold for dual batch overlap for batches only containing decodes.
|
||||
If the number of tokens in the request is greater than this threshold,
|
||||
microbatching will be used. Otherwise, the request will be processed in a
|
||||
single batch."""
|
||||
dbo_prefill_token_threshold: int = 512 # TODO(lucas): tune
|
||||
"""The threshold for dual batch overlap for batches that contain one or more
|
||||
prefills. If the number of tokens in the request is greater than this
|
||||
threshold, microbatching will be used. Otherwise, the request will be
|
||||
processed in a single batch."""
|
||||
|
||||
disable_nccl_for_dp_synchronization: bool | None = Field(default=None)
|
||||
"""Forces the dp synchronization logic in vllm/v1/worker/dp_utils.py
|
||||
to use Gloo instead of NCCL for its all reduce.
|
||||
|
||||
Defaults to True when async scheduling is enabled, False otherwise.
|
||||
"""
|
||||
|
||||
ray_workers_use_nsight: bool = False
|
||||
"""Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""
|
||||
|
||||
ray_runtime_env: RuntimeEnv | None = None
|
||||
"""Ray runtime environment to pass to distributed workers."""
|
||||
|
||||
placement_group: PlacementGroup | None = None
|
||||
"""ray distributed model workers placement group."""
|
||||
|
||||
distributed_executor_backend: (
|
||||
str | DistributedExecutorBackend | type[Executor] | None
|
||||
) = None
|
||||
"""Backend to use for distributed model workers, either "ray" or "mp"
|
||||
(multiprocessing). If the product of pipeline_parallel_size and tensor_parallel_size
|
||||
is less than or equal to the number of GPUs available, "mp" will be used to
|
||||
keep processing on a single host. Otherwise, an error will be raised. To use "mp"
|
||||
you must also set nnodes, and to use "ray" you must manually set
|
||||
distributed_executor_backend to "ray".
|
||||
|
||||
Note that tpu only support Ray for distributed inference."""
|
||||
|
||||
worker_cls: str = "auto"
|
||||
"""The full name of the worker class to use. If "auto", the worker class
|
||||
will be determined based on the platform."""
|
||||
sd_worker_cls: str = "auto"
|
||||
"""The full name of the worker class to use for speculative decoding.
|
||||
If "auto", the worker class will be determined based on the platform."""
|
||||
worker_extension_cls: str = ""
|
||||
"""The full name of the worker extension class to use. The worker extension
|
||||
class is dynamically inherited by the worker class. This is used to inject
|
||||
new attributes and methods to the worker class for use in collective_rpc
|
||||
calls."""
|
||||
master_addr: str = "127.0.0.1"
|
||||
"""distributed master address for multi-node distributed
|
||||
inference when distributed_executor_backend is mp."""
|
||||
master_port: int = 29501
|
||||
"""distributed master port for multi-node distributed
|
||||
inference when distributed_executor_backend is mp."""
|
||||
node_rank: int = 0
|
||||
"""distributed node rank for multi-node distributed
|
||||
inference when distributed_executor_backend is mp."""
|
||||
nnodes: int = 1
|
||||
"""num of nodes for multi-node distributed
|
||||
inference when distributed_executor_backend is mp."""
|
||||
|
||||
world_size: int = Field(init=False)
|
||||
"""world_size is TPxPP, it affects the number of workers we create."""
|
||||
|
||||
rank: int = 0
|
||||
"""Global rank in distributed setup."""
|
||||
|
||||
_data_parallel_master_port_list: list[int] = Field(default_factory=list)
|
||||
"""List of open port auto-queried for data parallel messaging.
|
||||
Set to be private as it's not intended to be configured by users.
|
||||
"""
|
||||
|
||||
decode_context_parallel_size: int = 1
|
||||
"""Number of decode context parallel groups, because the world size does
|
||||
not change by dcp, it simply reuse the GPUs of TP group, and tp_size
|
||||
needs to be divisible by dcp_size."""
|
||||
|
||||
dcp_kv_cache_interleave_size: int = 1
"""
Interleave size of kv_cache storage while using DCP.
dcp_kv_cache_interleave_size has been superseded by cp_kv_cache_interleave_size
and will be deprecated once PCP is fully supported.
"""
|
||||
cp_kv_cache_interleave_size: int = 1
"""Interleave size of kv_cache storage while using DCP or PCP.
With `total_cp_rank = pcp_rank * dcp_world_size + dcp_rank`
and `total_cp_world_size = pcp_world_size * dcp_world_size`,
interleave_size tokens are stored on total_cp_rank i,
then the next interleave_size tokens are stored on total_cp_rank i+1.
Interleave_size=1: token-level alignment, where token `i` is stored on
total_cp_rank `i % total_cp_world_size`.
Interleave_size=block_size: block-level alignment, where tokens are
first populated to the preceding ranks. Tokens are then stored
in (rank i+1, block j) only after (rank i, block j) is fully occupied.
Block_size should be greater than or equal to cp_kv_cache_interleave_size
and should be divisible by cp_kv_cache_interleave_size.
"""
|
||||
|
||||
data_parallel_index: int = Field(init=False)
|
||||
"""Equal to the data parallel rank but not used for torch process groups
|
||||
and not overridden for dense models."""
|
||||
|
||||
_api_process_count: int = Field(default=1, gt=0)
|
||||
"""
|
||||
The number of API processes initialized.
|
||||
|
||||
Note:
|
||||
This is an internal config that is only valid for and
|
||||
should only be set by API server scale-out.
|
||||
"""
|
||||
|
||||
_api_process_rank: int = Field(default=0, ge=-1)
|
||||
"""
|
||||
The rank of this API process, or `-1` for engine core processes
|
||||
under API server scale-out.
|
||||
|
||||
Note:
|
||||
This is an internal config that is only valid for and
|
||||
should only be set by API server scale-out.
|
||||
"""
|
||||
|
||||
@field_validator("disable_nccl_for_dp_synchronization", mode="wrap")
|
||||
@classmethod
|
||||
def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
|
||||
"""Skip validation if the value is `None` when initialisation is delayed."""
|
||||
return None if value is None else handler(value)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_parallel_config(self) -> Self:
|
||||
if self._api_process_rank >= self._api_process_count:
|
||||
raise ValueError(
|
||||
"Invalid value of `_api_process_rank`. "
|
||||
f"Expected to be `-1` or `[0, {self._api_process_count})`, "
|
||||
f"but found: {self._api_process_rank}"
|
||||
)
|
||||
|
||||
if self.data_parallel_size_local > self.data_parallel_size:
|
||||
raise ValueError(
|
||||
f"data_parallel_size_local ({self.data_parallel_size_local}) "
|
||||
f"must be <= data_parallel_size ({self.data_parallel_size})"
|
||||
)
|
||||
|
||||
if self.data_parallel_size <= 1 and self.data_parallel_external_lb:
|
||||
raise ValueError(
|
||||
"data_parallel_external_lb can only be set when data_parallel_size > 1"
|
||||
)
|
||||
|
||||
if self.enable_eplb:
|
||||
if not current_platform.is_cuda_alike():
|
||||
raise ValueError(
|
||||
"Expert parallelism load balancing is only supported on "
|
||||
"CUDA devices or ROCm devices now."
|
||||
)
|
||||
if not self.enable_expert_parallel:
|
||||
raise ValueError("enable_expert_parallel must be True to use EPLB.")
|
||||
if self.tensor_parallel_size * self.data_parallel_size <= 1:
|
||||
raise ValueError(
|
||||
"EPLB requires tensor_parallel_size or data_parallel_size "
|
||||
f"to be greater than 1, but got "
|
||||
f"TP={self.tensor_parallel_size},DP={self.data_parallel_size}."
|
||||
)
|
||||
else:
|
||||
if self.eplb_config.num_redundant_experts != 0:
|
||||
raise ValueError(
|
||||
"num_redundant_experts is set to "
|
||||
f"{self.eplb_config.num_redundant_experts} but EPLB is not "
|
||||
"enabled. Either enable EPLB or unset "
|
||||
"num_redundant_experts."
|
||||
)
|
||||
|
||||
# Note(hc): In the current implementation of decode context
# parallel (DCP), tp_size needs to be divisible by dcp_size,
# because the world size does not change with DCP; it simply
# reuses the GPUs of the TP group and splits one TP group into
# tp_size//dcp_size DCP groups.
|
||||
if self.tensor_parallel_size % self.decode_context_parallel_size != 0:
|
||||
raise ValueError(
|
||||
f"tp_size={self.tensor_parallel_size} must be divisible by"
|
||||
f"dcp_size={self.decode_context_parallel_size}."
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
@property
|
||||
def world_size_across_dp(self) -> int:
|
||||
"""world_size_across_dp is TPxPPxDP, it is the size of the world
|
||||
including data parallelism."""
|
||||
return self.world_size * self.data_parallel_size
|
||||
|
||||
@property
|
||||
def use_ubatching(self) -> bool:
|
||||
return self.enable_dbo or self.ubatch_size > 1
|
||||
|
||||
@property
|
||||
def num_ubatches(self) -> int:
|
||||
return 2 if self.enable_dbo else self.ubatch_size
|
||||
|
||||
@property
|
||||
def local_engines_only(self) -> bool:
|
||||
"""
|
||||
Client manages local+remote EngineCores in pure internal LB case.
|
||||
Client manages local EngineCores in hybrid and external LB case.
|
||||
"""
|
||||
return self.data_parallel_external_lb or self.data_parallel_hybrid_lb
|
||||
|
||||
def get_next_dp_init_port(self) -> int:
|
||||
"""
|
||||
We might need to initialize process groups related to data parallelism
in multiple places, e.g. both in the worker and in the engine, which
can live in different processes. To avoid port conflicts, we
pop a new port from the prepared port list each time we need to
initialize a new process group related to data parallelism.
|
||||
"""
|
||||
if self._data_parallel_master_port_list:
|
||||
answer = self._data_parallel_master_port_list.pop()
|
||||
else:
|
||||
answer = self.data_parallel_master_port
|
||||
self.data_parallel_master_port += 1
|
||||
|
||||
return answer
|
||||
|
||||
def stateless_init_dp_group(self) -> ProcessGroup:
|
||||
# NOTE: In high-concurrency scenarios multiple processes
|
||||
# can pick the same (currently free) port through a race
|
||||
# condition when calling `get_open_port()`. When the first
|
||||
# process binds the port the others will subsequently fail
|
||||
# with `torch.distributed.DistNetworkError: EADDRINUSE`.
|
||||
# To make the initialization more robust we retry a few times
|
||||
# with a fresh port whenever this specific error is observed.
|
||||
from torch.distributed import DistNetworkError
|
||||
|
||||
from vllm.distributed.utils import (
|
||||
stateless_init_torch_distributed_process_group,
|
||||
)
|
||||
|
||||
max_retries = 5
|
||||
last_exc: Exception | None = None
|
||||
for _ in range(max_retries):
|
||||
try:
|
||||
# use gloo since the engine process might not have cuda device
|
||||
return stateless_init_torch_distributed_process_group(
|
||||
self.data_parallel_master_ip,
|
||||
self.get_next_dp_init_port(),
|
||||
self.data_parallel_rank,
|
||||
self.data_parallel_size,
|
||||
backend=current_platform.dist_backend,
|
||||
)
|
||||
except DistNetworkError as e:
|
||||
# We only want to retry when the root cause is EADDRINUSE.
|
||||
if "EADDRINUSE" in str(e):
|
||||
logger.warning("Address already in use. Retrying with a new port.")
|
||||
last_exc = e
|
||||
continue # try again with a new port
|
||||
raise e
|
||||
|
||||
# If we get here all retries have failed.
|
||||
assert last_exc is not None
|
||||
raise last_exc
|
||||
|
||||
# The all_reduce at the end of attention (during o_proj) means that
|
||||
# inputs are replicated across each rank of the tensor parallel group.
|
||||
# If using expert-parallelism with DeepEP All2All ops, replicated
|
||||
# tokens results in useless duplicate computation and communication.
|
||||
#
|
||||
# In this case, ensure the input to the experts is sequence parallel
|
||||
# to avoid the excess work.
|
||||
#
|
||||
# Not needed for pplx-kernels as it can handle duplicate input tokens.
|
||||
@property
|
||||
def use_sequence_parallel_moe(self) -> bool:
|
||||
return (
|
||||
self.all2all_backend
|
||||
in (
|
||||
"allgather_reducescatter",
|
||||
"naive",
|
||||
"deepep_high_throughput",
|
||||
"deepep_low_latency",
|
||||
"mori",
|
||||
)
|
||||
and self.enable_expert_parallel
|
||||
and self.tensor_parallel_size > 1
|
||||
and self.data_parallel_size > 1
|
||||
)
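# [Editor's note] Illustrative sketch, not part of the original file: with
# all2all_backend="deepep_low_latency", enable_expert_parallel=True,
# tensor_parallel_size=2 and data_parallel_size=2, the property above is
# True, so the MoE input is kept sequence-parallel to avoid duplicate
# expert computation and communication across TP ranks.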
|
||||
|
||||
@property
|
||||
def node_rank_within_dp(self) -> int:
|
||||
return self.node_rank % self.nnodes_within_dp
|
||||
|
||||
@property
|
||||
def nnodes_within_dp(self) -> int:
|
||||
if self.nnodes == 1:
|
||||
return 1
|
||||
data_parallel_node_size = (
|
||||
self.data_parallel_size // self.data_parallel_size_local
|
||||
)
|
||||
return self.nnodes // data_parallel_node_size
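# [Editor's note] Illustrative sketch, not part of the original file: with
# nnodes=4, data_parallel_size=4 and data_parallel_size_local=2, we get
# data_parallel_node_size = 4 // 2 = 2, so nnodes_within_dp = 4 // 2 = 2
# and node_rank_within_dp = node_rank % 2.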
|
||||
|
||||
@property
|
||||
def local_world_size(self) -> int:
|
||||
return self.world_size // self.nnodes_within_dp
|
||||
|
||||
@staticmethod
|
||||
def has_unfinished_dp(dp_group: ProcessGroup, has_unfinished: bool) -> bool:
|
||||
tensor = torch.tensor([has_unfinished], dtype=torch.int32, device="cpu")
|
||||
# dp rank 0: has_unfinished_seqs=True
|
||||
# dp rank 1: has_unfinished_seqs=False
|
||||
# aggregated: has_unfinished_seqs=True
|
||||
# so this is an OR operation, i.e. MAX in integers
|
||||
torch.distributed.all_reduce(tensor, op=ReduceOp.MAX, group=dp_group)
|
||||
aggregated_has_unfinished = bool(tensor.item())
|
||||
return aggregated_has_unfinished
|
||||
|
||||
@staticmethod
|
||||
def sync_kv_cache_memory_size(dp_group: ProcessGroup, kv_cache_memory: int) -> int:
|
||||
if kv_cache_memory == -1:
|
||||
kv_cache_memory = torch.iinfo(torch.int64).max
|
||||
tensor = torch.tensor([kv_cache_memory], dtype=torch.int64, device="cpu")
|
||||
# we cannot use broadcast for stateless dp group since it depends
|
||||
# on global rank
|
||||
torch.distributed.all_reduce(tensor, op=ReduceOp.MIN, group=dp_group)
|
||||
return tensor.item()
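# [Editor's note] Illustrative sketch, not part of the original file: the -1
# sentinel is mapped to int64 max so it never wins the MIN reduction, e.g.
#   rank 0: kv_cache_memory = -1      -> 9223372036854775807
#   rank 1: kv_cache_memory = 8 GiB
#   all_reduce(MIN)                   -> 8 GiB on every rank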
|
||||
|
||||
def compute_hash(self):
|
||||
"""
|
||||
Provide a hash that uniquely identifies all the configs
|
||||
that affect the structure of the computation
|
||||
graph from input ids/embeddings to the final hidden states,
|
||||
excluding anything before input ids/embeddings and after
|
||||
the final hidden states.
|
||||
|
||||
This hash is also used for DP worker configuration validation
|
||||
to prevent hangs from mismatched collective communication patterns.
|
||||
"""
|
||||
ignored_factors = {
|
||||
# Derived/runtime topology, networking, or launch details
|
||||
"data_parallel_rank",
|
||||
"data_parallel_rank_local",
|
||||
"data_parallel_size_local",
|
||||
"data_parallel_index",
|
||||
"data_parallel_backend",
|
||||
"data_parallel_external_lb",
|
||||
"data_parallel_hybrid_lb",
|
||||
"data_parallel_master_ip",
|
||||
"data_parallel_master_port",
|
||||
"_data_parallel_master_port_list",
|
||||
"data_parallel_rpc_port",
|
||||
"rank",
|
||||
"master_addr",
|
||||
"master_port",
|
||||
"node_rank",
|
||||
"nnodes",
|
||||
"max_parallel_loading_workers",
|
||||
"disable_custom_all_reduce",
|
||||
"ray_workers_use_nsight",
|
||||
"ray_runtime_env",
|
||||
"placement_group",
|
||||
"distributed_executor_backend",
|
||||
"worker_cls",
|
||||
"sd_worker_cls",
|
||||
"worker_extension_cls",
|
||||
"_api_process_count",
|
||||
"_api_process_rank",
|
||||
}
|
||||
|
||||
from vllm.config.utils import get_hash_factors, hash_factors
|
||||
|
||||
factors = get_hash_factors(self, ignored_factors)
|
||||
return hash_factors(factors)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Continue with the rest of the initialization
|
||||
self.world_size = (
|
||||
self.pipeline_parallel_size
|
||||
* self.tensor_parallel_size
|
||||
* self.prefill_context_parallel_size
|
||||
)
|
||||
|
||||
if self.distributed_executor_backend == "external_launcher":
|
||||
logger.info("Using external launcher for distributed inference.")
|
||||
self.world_size *= self.data_parallel_size
|
||||
|
||||
if self.data_parallel_size > 1 or self.data_parallel_size_local == 0:
|
||||
# Data parallel was specified in the engine args.
|
||||
if self.distributed_executor_backend == "external_launcher":
|
||||
# For external launcher,
|
||||
# we need to set the data parallel rank automatically
|
||||
self.data_parallel_rank = int(os.environ["RANK"]) // (
|
||||
self.world_size // self.data_parallel_size
|
||||
)
|
||||
logger.info(
|
||||
"Set data_parallel_rank to %d automatically.",
|
||||
self.data_parallel_rank,
|
||||
)
|
||||
if not self._data_parallel_master_port_list:
|
||||
self._data_parallel_master_port_list = get_open_ports_list(5)
|
||||
self.data_parallel_master_port = self._data_parallel_master_port_list.pop()
|
||||
|
||||
if not (0 <= self.data_parallel_rank < self.data_parallel_size):
|
||||
raise ValueError(
|
||||
f"data_parallel_rank ({self.data_parallel_rank})"
|
||||
f" must be in the range [0, {self.data_parallel_size})"
|
||||
)
|
||||
else:
|
||||
# Otherwise fall back to env vars (e.g. for offline SPMD case).
|
||||
self.data_parallel_size = envs.VLLM_DP_SIZE
|
||||
self.data_parallel_rank = envs.VLLM_DP_RANK
|
||||
self.data_parallel_rank_local = envs.VLLM_DP_RANK_LOCAL
|
||||
self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
|
||||
self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT
|
||||
|
||||
if self.data_parallel_size > 1 and self.is_moe_model is False:
|
||||
raise ValueError(
|
||||
"Offline data parallel mode is not supported/useful"
|
||||
" for dense models."
|
||||
)
|
||||
|
||||
self.data_parallel_index = self.data_parallel_rank
|
||||
|
||||
if self.distributed_executor_backend == "external_launcher":
|
||||
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
|
||||
logger.info("Disabling V1 multiprocessing for external launcher.")
|
||||
|
||||
if self.distributed_executor_backend is None and self.world_size > 1:
|
||||
# We use multiprocessing by default if world_size fits on the
|
||||
# current node and we aren't in a ray placement group.
|
||||
|
||||
from vllm.v1.executor import ray_utils
|
||||
|
||||
backend: DistributedExecutorBackend = "mp"
|
||||
ray_found = ray_utils.ray_is_available()
|
||||
if current_platform.is_tpu() and envs.VLLM_XLA_USE_SPMD:
|
||||
backend = "uni"
|
||||
elif current_platform.is_cuda() and self.nnodes > 1:
|
||||
backend = "mp"
|
||||
elif (
|
||||
current_platform.is_cuda()
|
||||
and cuda_device_count_stateless() < self.world_size
|
||||
):
|
||||
gpu_count = cuda_device_count_stateless()
|
||||
raise ValueError(
|
||||
f"World size ({self.world_size}) is larger than the number of "
|
||||
f"available GPUs ({gpu_count}) in this node. If this is "
|
||||
"intentional and you are using:\n"
|
||||
"- ray, set '--distributed-executor-backend ray'.\n"
|
||||
"- multiprocessing, set '--nnodes' appropriately."
|
||||
)
|
||||
elif self.data_parallel_backend == "ray":
|
||||
logger.info(
|
||||
"Using ray distributed inference because "
|
||||
"data_parallel_backend is ray"
|
||||
)
|
||||
backend = "ray"
|
||||
elif ray_found:
|
||||
if self.placement_group:
|
||||
backend = "ray"
|
||||
else:
|
||||
from ray import is_initialized as ray_is_initialized
|
||||
|
||||
if ray_is_initialized():
|
||||
from ray.util import get_current_placement_group
|
||||
|
||||
if get_current_placement_group():
|
||||
backend = "ray"
|
||||
self.distributed_executor_backend = backend
|
||||
logger.debug("Defaulting to use %s for distributed inference", backend)
|
||||
|
||||
if self.distributed_executor_backend is None and self.world_size == 1:
|
||||
self.distributed_executor_backend = "uni"
|
||||
|
||||
if self.max_parallel_loading_workers is not None:
|
||||
logger.warning(
|
||||
"max_parallel_loading_workers is currently "
|
||||
"not supported and will be ignored."
|
||||
)
|
||||
allowed_backends = ("mp", "uni", "external_launcher")
|
||||
if (
|
||||
self.distributed_executor_backend not in allowed_backends
|
||||
and self.nnodes > 1
|
||||
):
|
||||
raise ValueError(
|
||||
"nnodes > 1 can only be set when distributed executor "
|
||||
"backend is mp, uni or external_launcher."
|
||||
)
|
||||
|
||||
@property
|
||||
def use_ray(self) -> bool:
|
||||
return self.distributed_executor_backend == "ray" or (
|
||||
isinstance(self.distributed_executor_backend, type)
|
||||
and getattr(self.distributed_executor_backend, "uses_ray", False)
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _verify_args(self) -> Self:
|
||||
# Lazy import to avoid circular import
|
||||
from vllm.v1.executor import Executor
|
||||
|
||||
# Enable batch invariance settings if requested
|
||||
if vllm_is_batch_invariant():
|
||||
self.disable_custom_all_reduce = True
|
||||
|
||||
if (
|
||||
self.distributed_executor_backend is not None
|
||||
and not isinstance(self.distributed_executor_backend, str)
|
||||
and not (
|
||||
isinstance(self.distributed_executor_backend, type)
|
||||
and issubclass(self.distributed_executor_backend, Executor)
|
||||
)
|
||||
):
|
||||
raise ValueError(
|
||||
"Unrecognized distributed executor backend "
|
||||
f"{self.distributed_executor_backend}. Supported "
|
||||
"values are 'ray', 'mp' 'uni', 'external_launcher', "
|
||||
" custom Executor subclass or its import path."
|
||||
)
|
||||
if self.use_ray:
|
||||
from vllm.v1.executor import ray_utils
|
||||
|
||||
ray_utils.assert_ray_available()
|
||||
|
||||
if not current_platform.use_custom_allreduce():
|
||||
self.disable_custom_all_reduce = True
|
||||
logger.debug(
|
||||
"Disabled the custom all-reduce kernel because it is not "
|
||||
"supported on current platform."
|
||||
)
|
||||
if self.nnodes > 1:
|
||||
self.disable_custom_all_reduce = True
|
||||
logger.debug(
|
||||
"Disabled the custom all-reduce since we are running on multi-node."
|
||||
)
|
||||
if self.ray_workers_use_nsight and not self.use_ray:
|
||||
raise ValueError(
|
||||
"Unable to use nsight profiling unless workers run with Ray."
|
||||
)
|
||||
|
||||
return self
|
||||
146
vllm/config/pooler.py
Normal file
146
vllm/config/pooler.py
Normal file
@@ -0,0 +1,146 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Literal, get_args
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
SequencePoolingType = Literal["CLS", "LAST", "MEAN"]
|
||||
SEQ_POOLING_TYPES: tuple[SequencePoolingType, ...] = get_args(SequencePoolingType)
|
||||
|
||||
TokenPoolingType = Literal["ALL", "STEP"]
|
||||
TOK_POOLING_TYPES: tuple[TokenPoolingType, ...] = get_args(TokenPoolingType)
|
||||
|
||||
|
||||
@config
|
||||
class PoolerConfig:
|
||||
"""Controls the behavior of output pooling in pooling models."""
|
||||
|
||||
pooling_type: SequencePoolingType | TokenPoolingType | None = None
|
||||
"""
|
||||
The pooling method used for pooling.
|
||||
|
||||
If set, `seq_pooling_type` or `tok_pooling_type` are automatically populated
|
||||
with this field. Alternatively, users can set `seq_pooling_type` and
|
||||
`tok_pooling_type` explicitly.
|
||||
|
||||
This field is mainly for user convenience. Internal code should always use
|
||||
`seq_pooling_type` or `tok_pooling_type` instead of `pooling_type`.
|
||||
"""
|
||||
|
||||
seq_pooling_type: SequencePoolingType | None = None
|
||||
"""
|
||||
The pooling method used for sequence pooling.
|
||||
"""
|
||||
|
||||
tok_pooling_type: TokenPoolingType | None = None
|
||||
"""
|
||||
The pooling method used for tokenwise pooling.
|
||||
"""
|
||||
|
||||
use_activation: bool | None = None
|
||||
"""
|
||||
Whether to apply activation function to the pooler outputs.
|
||||
`None` uses the pooler's default, which is `True` in most cases.
|
||||
"""
|
||||
|
||||
## for embedding models
|
||||
dimensions: int | None = None
|
||||
"""
|
||||
Reduce the dimensions of embeddings if model
|
||||
support matryoshka representation. Defaults to None.
|
||||
"""
|
||||
enable_chunked_processing: bool = False
|
||||
"""
|
||||
Whether to enable chunked processing for long inputs that exceed the model's
|
||||
maximum position embeddings. When enabled, long inputs will be split into
|
||||
chunks, processed separately, and then aggregated using weighted averaging.
|
||||
This allows embedding models to handle arbitrarily long text without CUDA
|
||||
errors. Defaults to False.
|
||||
"""
|
||||
max_embed_len: int | None = None
|
||||
"""
|
||||
Maximum input length allowed for embedding generation. When set, allows
inputs longer than max_model_len (up to max_embed_len) to be accepted for
embedding models. When an input exceeds max_embed_len, it will be handled
according to the original max_model_len validation logic.
Defaults to None (i.e. set to max_model_len).
|
||||
"""
|
||||
|
||||
## for classification models
|
||||
logit_bias: float | None = None
|
||||
"""
|
||||
If provided, apply classification logit biases. Defaults to None.
|
||||
"""
|
||||
|
||||
## for reward models
|
||||
step_tag_id: int | None = None
|
||||
"""
|
||||
If set, only the score corresponding to the `step_tag_id` in the
|
||||
generated sentence should be returned. Otherwise, the scores for all tokens
|
||||
are returned.
|
||||
"""
|
||||
returned_token_ids: list[int] | None = None
|
||||
"""
|
||||
A list of indices for the vocabulary dimensions to be extracted,
|
||||
such as the token IDs of `good_token` and `bad_token` in the
|
||||
`math-shepherd-mistral-7b-prm` model.
|
||||
"""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if pooling_type := self.pooling_type:
|
||||
if self.seq_pooling_type is not None:
|
||||
raise ValueError(
|
||||
"Cannot set both `pooling_type` and `seq_pooling_type`"
|
||||
)
|
||||
if self.tok_pooling_type is not None:
|
||||
raise ValueError(
|
||||
"Cannot set both `pooling_type` and `tok_pooling_type`"
|
||||
)
|
||||
|
||||
if pooling_type in SEQ_POOLING_TYPES:
|
||||
logger.debug(
|
||||
"Resolved `pooling_type=%r` to `seq_pooling_type=%r`.",
|
||||
pooling_type,
|
||||
pooling_type,
|
||||
)
|
||||
self.seq_pooling_type = pooling_type
|
||||
elif pooling_type in TOK_POOLING_TYPES:
|
||||
logger.debug(
|
||||
"Resolved `pooling_type=%r` to `tok_pooling_type=%r`.",
|
||||
pooling_type,
|
||||
pooling_type,
|
||||
)
|
||||
self.tok_pooling_type = pooling_type
|
||||
else:
|
||||
raise NotImplementedError(pooling_type)
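# [Editor's note] Illustrative sketch, not part of the original file: a
# convenience value such as PoolerConfig(pooling_type="MEAN") is resolved
# by __post_init__ to seq_pooling_type="MEAN" (since "MEAN" is in
# SEQ_POOLING_TYPES), while PoolerConfig(pooling_type="STEP") resolves to
# tok_pooling_type="STEP".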
|
||||
|
||||
def get_seq_pooling_type(self) -> SequencePoolingType:
|
||||
assert self.seq_pooling_type is not None, "Should be resolved by ModelConfig"
|
||||
return self.seq_pooling_type
|
||||
|
||||
def get_tok_pooling_type(self) -> TokenPoolingType:
|
||||
assert self.tok_pooling_type is not None, "Should be resolved by ModelConfig"
|
||||
return self.tok_pooling_type
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
ensure that it is included in the factors list if
|
||||
it affects the computation graph.
|
||||
|
||||
Provide a hash that uniquely identifies all the configs
|
||||
that affect the structure of the computation
|
||||
graph from input ids/embeddings to the final hidden states,
|
||||
excluding anything before input ids/embeddings and after
|
||||
the final hidden states.
|
||||
"""
|
||||
# no factors to consider.
|
||||
# this config will not affect the computation graph.
|
||||
factors: list[Any] = []
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
124
vllm/config/profiler.py
Normal file
124
vllm/config/profiler.py
Normal file
@@ -0,0 +1,124 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import Field, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
ProfilerKind = Literal["torch", "cuda"]
|
||||
|
||||
|
||||
def _is_uri_path(path: str) -> bool:
|
||||
"""Check if path is a URI (scheme://...), excluding Windows drive letters.
|
||||
|
||||
Supports custom URI schemes like gs://, s3://, hdfs://, etc.
|
||||
These paths should not be converted to absolute paths.
|
||||
"""
|
||||
if "://" in path:
|
||||
scheme = path.split("://")[0]
|
||||
# Windows drive letters are single characters (e.g., C://)
|
||||
# Valid URI schemes have more than one character
|
||||
return len(scheme) > 1
|
||||
return False
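# [Editor's note] Illustrative sketch, not part of the original file:
# expected classifications under the rule above:
#   _is_uri_path("gs://bucket/traces")  -> True   (multi-char scheme)
#   _is_uri_path("s3://bucket/traces")  -> True
#   _is_uri_path("C://Users/me/traces") -> False  (Windows drive letter)
#   _is_uri_path("/tmp/traces")         -> False  (no scheme)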
|
||||
|
||||
|
||||
@config
|
||||
class ProfilerConfig:
|
||||
"""Dataclass which contains profiler config for the engine."""
|
||||
|
||||
profiler: ProfilerKind | None = None
|
||||
"""Which profiler to use. Defaults to None. Options are:
|
||||
|
||||
- 'torch': Use PyTorch profiler.\n
|
||||
- 'cuda': Use CUDA profiler."""
|
||||
|
||||
torch_profiler_dir: str = ""
|
||||
"""Directory to save torch profiler traces. Both AsyncLLM's CPU traces and
|
||||
worker's traces (CPU & GPU) will be saved under this directory. Note that
|
||||
it must be an absolute path."""
|
||||
|
||||
torch_profiler_with_stack: bool = True
|
||||
"""If `True`, enables stack tracing in the torch profiler. Enabled by default."""
|
||||
|
||||
torch_profiler_with_flops: bool = False
|
||||
"""If `True`, enables FLOPS counting in the torch profiler. Disabled by default."""
|
||||
|
||||
torch_profiler_use_gzip: bool = True
|
||||
"""If `True`, saves torch profiler traces in gzip format. Enabled by default"""
|
||||
|
||||
torch_profiler_dump_cuda_time_total: bool = True
|
||||
"""If `True`, dumps total CUDA time in torch profiler traces. Enabled by default."""
|
||||
|
||||
torch_profiler_record_shapes: bool = False
|
||||
"""If `True`, records tensor shapes in the torch profiler. Disabled by default."""
|
||||
|
||||
torch_profiler_with_memory: bool = False
|
||||
"""If `True`, enables memory profiling in the torch profiler.
|
||||
Disabled by default."""
|
||||
|
||||
ignore_frontend: bool = False
|
||||
"""If `True`, disables the front-end profiling of AsyncLLM when using the
|
||||
'torch' profiler. This is needed to reduce overhead when using delay/limit options,
|
||||
since the front-end profiling does not track iterations and will capture the
|
||||
entire range.
|
||||
"""
|
||||
|
||||
delay_iterations: int = Field(default=0, ge=0)
|
||||
"""Number of engine iterations to skip before starting profiling.
|
||||
Defaults to 0, meaning profiling starts immediately after receiving /start_profile.
|
||||
"""
|
||||
|
||||
max_iterations: int = Field(default=0, ge=0)
|
||||
"""Maximum number of engine iterations to profile after starting profiling.
|
||||
Defaults to 0, meaning no limit.
|
||||
"""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
ensure that it is included in the factors list if
|
||||
it affects the computation graph.
|
||||
|
||||
Provide a hash that uniquely identifies all the configs
|
||||
that affect the structure of the computation
|
||||
graph from input ids/embeddings to the final hidden states,
|
||||
excluding anything before input ids/embeddings and after
|
||||
the final hidden states.
|
||||
"""
|
||||
# no factors to consider.
|
||||
# this config will not affect the computation graph.
|
||||
factors: list[Any] = []
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_profiler_config(self) -> Self:
|
||||
has_delay_or_limit = self.delay_iterations > 0 or self.max_iterations > 0
|
||||
if self.profiler == "torch" and has_delay_or_limit and not self.ignore_frontend:
|
||||
logger.warning_once(
|
||||
"Using 'torch' profiler with delay_iterations or max_iterations "
|
||||
"while ignore_frontend is False may result in high overhead."
|
||||
)
|
||||
|
||||
profiler_dir = self.torch_profiler_dir
|
||||
if profiler_dir and self.profiler != "torch":
|
||||
raise ValueError(
|
||||
"torch_profiler_dir is only applicable when profiler is set to 'torch'"
|
||||
)
|
||||
if self.profiler == "torch" and not profiler_dir:
|
||||
raise ValueError("torch_profiler_dir must be set when profiler is 'torch'")
|
||||
|
||||
# Support any URI scheme (gs://, s3://, hdfs://, etc.)
|
||||
# These paths should not be converted to absolute paths
|
||||
if profiler_dir and not _is_uri_path(profiler_dir):
|
||||
self.torch_profiler_dir = os.path.abspath(os.path.expanduser(profiler_dir))
|
||||
|
||||
return self
|
||||
300
vllm/config/scheduler.py
Normal file
300
vllm/config/scheduler.py
Normal file
@@ -0,0 +1,300 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import InitVar
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.hashing import safe_hash
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.interface import SchedulerInterface
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
RunnerType = Literal["generate", "pooling", "draft"]
|
||||
SchedulerPolicy = Literal["fcfs", "priority"]
|
||||
|
||||
|
||||
@config
|
||||
class SchedulerConfig:
|
||||
"""Scheduler configuration."""
|
||||
|
||||
max_model_len: InitVar[int]
|
||||
"""Maximum length of a sequence (including prompt and generated text).
|
||||
|
||||
Note: This is stored in the ModelConfig, and is used only here to
|
||||
provide fallbacks and validate other attributes."""
|
||||
|
||||
is_encoder_decoder: InitVar[bool]
|
||||
"""True if the model is an encoder-decoder model.
|
||||
|
||||
Note: This is stored in the ModelConfig, and is used only here to
|
||||
disable chunked prefill and prefix caching for encoder-decoder models.
|
||||
"""
|
||||
|
||||
DEFAULT_MAX_NUM_BATCHED_TOKENS: ClassVar[int] = 2048
|
||||
DEFAULT_MAX_NUM_SEQS: ClassVar[int] = 128
|
||||
|
||||
runner_type: RunnerType = "generate"
|
||||
"""The runner type to launch for the model."""
|
||||
|
||||
max_num_batched_tokens: int = Field(default=DEFAULT_MAX_NUM_BATCHED_TOKENS, ge=1)
|
||||
"""Maximum number of tokens that can be processed in a single iteration.
|
||||
|
||||
The default value here is mainly for convenience when testing.
|
||||
In real usage, this should be set in `EngineArgs.create_engine_config`.
|
||||
"""
|
||||
|
||||
max_num_scheduled_tokens: int | None = Field(default=None)
|
||||
"""Maximum number of tokens that the scheduler may issue in a single iteration.
|
||||
|
||||
This is usually equal to max_num_batched_tokens, but can be smaller in cases
|
||||
when the model might append tokens into the batch (such as speculative decoding).
|
||||
Defaults to max_num_batched_tokens."""
|
||||
|
||||
max_num_seqs: int = Field(default=DEFAULT_MAX_NUM_SEQS, ge=1)
|
||||
"""Maximum number of sequences to be processed in a single iteration.
|
||||
|
||||
The default value here is mainly for convenience when testing.
|
||||
In real usage, this should be set in `EngineArgs.create_engine_config`.
|
||||
"""
|
||||
|
||||
max_num_partial_prefills: int = Field(default=1, ge=1)
|
||||
"""For chunked prefill, the maximum number of sequences that can be
|
||||
partially prefilled concurrently."""
|
||||
|
||||
max_long_partial_prefills: int = Field(default=1, ge=1)
|
||||
"""For chunked prefill, the maximum number of prompts longer than
|
||||
long_prefill_token_threshold that will be prefilled concurrently. Setting
|
||||
this less than max_num_partial_prefills will allow shorter prompts to jump
|
||||
the queue in front of longer prompts in some cases, improving latency."""
|
||||
|
||||
long_prefill_token_threshold: int = 0
|
||||
"""For chunked prefill, a request is considered long if the prompt is
|
||||
longer than this number of tokens."""
|
||||
|
||||
enable_chunked_prefill: bool = True
|
||||
"""If True, prefill requests can be chunked based
|
||||
on the remaining `max_num_batched_tokens`.
|
||||
|
||||
The default value here is mainly for convenience when testing.
|
||||
In real usage, this should be set in `EngineArgs.create_engine_config`.
|
||||
"""
|
||||
|
||||
is_multimodal_model: bool = False
|
||||
"""True if the model is multimodal."""
|
||||
|
||||
# TODO (ywang96): Make this configurable.
|
||||
max_num_encoder_input_tokens: int = Field(init=False)
|
||||
"""Multimodal encoder compute budget, only used in V1.
|
||||
|
||||
NOTE: This is not currently configurable. It will be overridden by
|
||||
max_num_batched_tokens in case max multimodal embedding size is larger."""
|
||||
|
||||
# TODO (ywang96): Make this configurable.
|
||||
encoder_cache_size: int = Field(init=False)
|
||||
"""Multimodal encoder cache size, only used in V1.
|
||||
|
||||
NOTE: This is not currently configurable. It will be overridden by
|
||||
max_num_batched_tokens in case max multimodal embedding size is larger."""
|
||||
|
||||
policy: SchedulerPolicy = "fcfs"
"""The scheduling policy to use:\n
- "fcfs" means first come first served, i.e. requests are handled in order
of arrival.\n
- "priority" means requests are handled based on given priority (lower
value means earlier handling), with time of arrival deciding any ties."""
|
||||
|
||||
disable_chunked_mm_input: bool = False
"""If set to true and chunked prefill is enabled, we do not want to
partially schedule a multimodal item. Only used in V1.
This ensures that if a request has a mixed prompt
(like text tokens TTTT followed by image tokens IIIIIIIIII) where only
some image tokens can be scheduled (like TTTTIIIII, leaving IIIII),
it will be scheduled as TTTT in one step and IIIIIIIIII in the next."""
|
||||
|
||||
# scheduler class or path. "vllm.v1.core.sched.scheduler.Scheduler"
|
||||
# (default) or "mod.custom_class".
|
||||
scheduler_cls: str | type[object] | None = Field(default=None)
|
||||
"""The scheduler class to use. "vllm.v1.core.sched.scheduler.Scheduler" is
|
||||
the default scheduler. Can be a class directly or the path to a class of
|
||||
form "mod.custom_class"."""
|
||||
|
||||
disable_hybrid_kv_cache_manager: bool | None = None
"""If set to True, the KV cache manager will allocate the same size of KV cache
for all attention layers even if there are multiple types of attention layers,
like full attention and sliding window attention.
If set to None, the default value will be determined based on the environment
and starting configuration.
"""
|
||||
|
||||
async_scheduling: bool | None = Field(default=None)
|
||||
"""If set to False, disable async scheduling. Async scheduling helps to
|
||||
avoid gaps in GPU utilization, leading to better latency and throughput.
|
||||
"""
|
||||
|
||||
stream_interval: int = Field(default=1, ge=1)
|
||||
"""The interval (or buffer size) for streaming in terms of token length.
|
||||
A smaller value (1) makes streaming smoother by sending each token immediately,
|
||||
while a larger value (e.g., 10) reduces host overhead and may increase throughput
|
||||
by batching multiple tokens before sending."""
|
||||
|
||||
@staticmethod
|
||||
def default_factory(**kwargs):
|
||||
"""
|
||||
Factory method to create `SchedulerConfig` with default values for `InitVar`s.
|
||||
"""
|
||||
if "max_model_len" not in kwargs:
|
||||
kwargs["max_model_len"] = 8192
|
||||
if "is_encoder_decoder" not in kwargs:
|
||||
kwargs["is_encoder_decoder"] = False
|
||||
return SchedulerConfig(**kwargs)
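# [Editor's note] Illustrative sketch, not part of the original file:
# because max_model_len and is_encoder_decoder are InitVars with no
# defaults, callers (e.g. tests) can use the factory instead of the
# constructor:
#   cfg = SchedulerConfig.default_factory(max_num_seqs=64)
# which fills in max_model_len=8192 and is_encoder_decoder=False.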
|
||||
|
||||
def get_scheduler_cls(self) -> type["SchedulerInterface"]:
|
||||
if self.scheduler_cls is None:
|
||||
if self.async_scheduling:
|
||||
from vllm.v1.core.sched.async_scheduler import AsyncScheduler
|
||||
|
||||
return AsyncScheduler
|
||||
from vllm.v1.core.sched.scheduler import Scheduler
|
||||
|
||||
return Scheduler
|
||||
|
||||
# This warning can be removed once the Scheduler interface is
|
||||
# finalized and we can maintain support for scheduler classes that
|
||||
# implement it
|
||||
logger.warning_once(
|
||||
"Using custom scheduler class %s. This scheduler interface is "
|
||||
"not public and compatibility may not be maintained.",
|
||||
self.scheduler_cls,
|
||||
)
|
||||
if not isinstance(self.scheduler_cls, str):
|
||||
return cast(type["SchedulerInterface"], self.scheduler_cls)
|
||||
return resolve_obj_by_qualname(self.scheduler_cls)
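# [Editor's note] Illustrative sketch, not part of the original file: a
# custom scheduler can be supplied either as a class object or as a dotted
# path, e.g. scheduler_cls="my_pkg.sched.MyScheduler" (hypothetical name);
# the string form is resolved lazily via resolve_obj_by_qualname when
# get_scheduler_cls() is called.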
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
ensure that it is included in the factors list if
|
||||
it affects the computation graph.
|
||||
|
||||
Provide a hash that uniquely identifies all the configs
|
||||
that affect the structure of the computation
|
||||
graph from input ids/embeddings to the final hidden states,
|
||||
excluding anything before input ids/embeddings and after
|
||||
the final hidden states.
|
||||
"""
|
||||
factors: list[Any] = []
|
||||
|
||||
# max_num_batched_tokens need to be included in the hash due
|
||||
# to two reasons:
|
||||
# 1. LoRA creates static buffers based on max_num_batched_tokens.
|
||||
# The tensor sizes and strides get captured in the torch.compile
|
||||
# graph explicitly.
|
||||
# 2. Inductor decides whether using 32-bit or 64-bit indexing integer
|
||||
# based on the data sizes. `max_num_batched_tokens` has an
|
||||
# impact on that. For more details, please check
|
||||
# https://github.com/vllm-project/vllm/issues/29585
|
||||
factors.append(self.max_num_batched_tokens)
|
||||
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
@field_validator("scheduler_cls", "async_scheduling", mode="wrap")
|
||||
@classmethod
|
||||
def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
|
||||
"""Skip validation if the value is `None` when initialisation is delayed."""
|
||||
return None if value is None else handler(value)
|
||||
|
||||
def __post_init__(self, max_model_len: int, is_encoder_decoder: bool) -> None:
|
||||
if is_encoder_decoder:
|
||||
# Chunked prefill should be disabled for encoder-decoder models.
|
||||
self.disable_chunked_mm_input = True
|
||||
self.enable_chunked_prefill = False
|
||||
self.long_prefill_token_threshold = 0
|
||||
logger.info(
|
||||
"Encoder-decoder models do not support chunked prefill nor"
|
||||
" prefix caching; disabling both."
|
||||
)
|
||||
|
||||
self.max_num_encoder_input_tokens = self.max_num_batched_tokens
|
||||
self.encoder_cache_size = self.max_num_batched_tokens
|
||||
|
||||
if self.enable_chunked_prefill:
|
||||
logger.info(
|
||||
"Chunked prefill is enabled with max_num_batched_tokens=%d.",
|
||||
self.max_num_batched_tokens,
|
||||
)
|
||||
|
||||
if self.max_num_partial_prefills > 1:
|
||||
if self.long_prefill_token_threshold == 0:
|
||||
self.long_prefill_token_threshold = int(max_model_len * 0.04)
|
||||
|
||||
logger.info(
|
||||
"Concurrent partial prefills enabled with "
|
||||
"max_num_partial_prefills=%d, max_long_partial_prefills=%d, "
|
||||
"long_prefill_token_threshold=%d",
|
||||
self.max_num_partial_prefills,
|
||||
self.max_long_partial_prefills,
|
||||
self.long_prefill_token_threshold,
|
||||
)
|
||||
|
||||
self.verify_max_model_len(max_model_len)
|
||||
|
||||
def verify_max_model_len(self, max_model_len: int) -> Self:
|
||||
if (
|
||||
self.max_num_batched_tokens < max_model_len
|
||||
and not self.enable_chunked_prefill
|
||||
):
|
||||
raise ValueError(
|
||||
f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
|
||||
f"smaller than max_model_len ({max_model_len}). "
|
||||
"This effectively limits the maximum sequence length to "
|
||||
"max_num_batched_tokens and makes vLLM reject longer "
|
||||
"sequences. Please increase max_num_batched_tokens or "
|
||||
"decrease max_model_len."
|
||||
)
|
||||
|
||||
if self.max_num_batched_tokens < self.max_num_seqs:
|
||||
raise ValueError(
|
||||
f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
|
||||
"be greater than or equal to max_num_seqs "
|
||||
f"({self.max_num_seqs})."
|
||||
)
|
||||
|
||||
if self.max_num_batched_tokens > self.max_num_seqs * max_model_len:
|
||||
logger.warning(
|
||||
"max_num_batched_tokens (%d) exceeds max_num_seqs "
|
||||
"* max_model_len (%d). This may lead to unexpected behavior.",
|
||||
self.max_num_batched_tokens,
|
||||
self.max_num_seqs * max_model_len,
|
||||
)
|
||||
|
||||
if self.max_num_partial_prefills > 1:
|
||||
if not self.enable_chunked_prefill:
|
||||
raise ValueError(
|
||||
"Chunked prefill must be enabled to set "
|
||||
"max_num_partial_prefills > 1."
|
||||
)
|
||||
|
||||
if self.long_prefill_token_threshold > max_model_len:
|
||||
raise ValueError(
|
||||
"long_prefill_token_threshold "
|
||||
f"({self.long_prefill_token_threshold}) cannot be greater "
|
||||
f"than the max_model_len ({max_model_len})."
|
||||
)
|
||||
|
||||
if self.max_long_partial_prefills > self.max_num_partial_prefills:
|
||||
raise ValueError(
|
||||
f"{self.max_long_partial_prefills=} must be less than or equal to "
|
||||
f"{self.max_num_partial_prefills=}."
|
||||
)
|
||||
|
||||
return self
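# [Editor's note] Illustrative sketch, not part of the original file: with
# enable_chunked_prefill=False, max_model_len=8192 and
# max_num_batched_tokens=2048, verify_max_model_len raises because longer
# sequences could never be scheduled in a single batch; enabling chunked
# prefill or raising max_num_batched_tokens resolves it.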
|
||||
789
vllm/config/speculative.py
Normal file
789
vllm/config/speculative.py
Normal file
@@ -0,0 +1,789 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import ast
|
||||
from typing import TYPE_CHECKING, Any, Literal, get_args
|
||||
|
||||
from pydantic import Field, SkipValidation, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from vllm.config import LoadConfig
|
||||
from vllm.config.model import ModelConfig
|
||||
from vllm.config.parallel import ParallelConfig
|
||||
from vllm.config.utils import config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.config import get_hf_text_config
|
||||
from vllm.utils.hashing import safe_hash
|
||||
from vllm.utils.import_utils import LazyLoader, has_arctic_inference
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
import vllm.model_executor.layers.quantization as me_quant
|
||||
else:
|
||||
PretrainedConfig = Any
|
||||
|
||||
me_quant = LazyLoader(
|
||||
"model_executor", globals(), "vllm.model_executor.layers.quantization"
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
MTPModelTypes = Literal[
|
||||
"deepseek_mtp",
|
||||
"mimo_mtp",
|
||||
"glm4_moe_mtp",
|
||||
"glm4_moe_lite_mtp",
|
||||
"glm_ocr_mtp",
|
||||
"ernie_mtp",
|
||||
"nemotron_h_mtp",
|
||||
"exaone_moe_mtp",
|
||||
"qwen3_next_mtp",
|
||||
"qwen3_5_mtp",
|
||||
"longcat_flash_mtp",
|
||||
"mtp",
|
||||
"pangu_ultra_moe_mtp",
|
||||
"step3p5_mtp",
|
||||
]
|
||||
EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes]
|
||||
SpeculativeMethod = Literal[
|
||||
"ngram",
|
||||
"medusa",
|
||||
"mlp_speculator",
|
||||
"draft_model",
|
||||
"suffix",
|
||||
EagleModelTypes,
|
||||
]
|
||||
|
||||
|
||||
@config
|
||||
class SpeculativeConfig:
|
||||
"""Configuration for speculative decoding."""
|
||||
|
||||
enforce_eager: bool | None = None
|
||||
"""Override the default enforce_eager from model_config"""
|
||||
# General speculative decoding control
|
||||
num_speculative_tokens: int = Field(default=None, gt=0)
|
||||
"""The number of speculative tokens, if provided. It will default to the
|
||||
number in the draft model config if present, otherwise, it is required."""
|
||||
model: str | None = None
|
||||
"""The name of the draft model, eagle head, or additional weights, if
|
||||
provided."""
|
||||
method: SpeculativeMethod | None = None
"""The name of the speculative method to use. If users provide and set the
`model` param, the speculative method type will be detected automatically
if possible; if the `model` param is not provided, the method name must be
provided.

If using the `ngram` method, the related configuration `prompt_lookup_max` and
`prompt_lookup_min` should be considered."""
|
||||
draft_tensor_parallel_size: int | None = Field(default=None, ge=1)
|
||||
"""The degree of the tensor parallelism for the draft model. Can only be 1
|
||||
or the same as the target model's tensor parallel size."""
|
||||
tensor_parallel_size: int | None = None
"""Users should pass "draft_tensor_parallel_size" instead. This parameter exists
only to warn users when they mistakenly provide the wrong argument."""
|
||||
|
||||
# Draft model configuration
|
||||
quantization: me_quant.QuantizationMethods | None = None
|
||||
"""Quantization method that was used to quantize the draft model weights.
|
||||
If `None`, we assume the model weights are not quantized. Note that it only
|
||||
takes effect when using the draft model-based speculative method."""
|
||||
max_model_len: int | None = Field(default=None, ge=1)
|
||||
"""The maximum model length of the draft model. Used when testing the
|
||||
ability to skip speculation for some sequences."""
|
||||
revision: str | None = None
|
||||
"""The specific model version to use for the draft model. It can be a
|
||||
branch name, a tag name, or a commit id. If unspecified, will use the
|
||||
default version."""
|
||||
code_revision: str | None = None
|
||||
"""The specific revision to use for the draft model code on Hugging Face
|
||||
Hub. It can be a branch name, a tag name, or a commit id. If unspecified,
|
||||
will use the default version."""
|
||||
|
||||
# Advanced control
|
||||
disable_padded_drafter_batch: bool = False
|
||||
"""Disable input padding for speculative decoding. If set to True,
|
||||
speculative input batches can contain sequences of different lengths,
|
||||
which may only be supported by certain attention backends. This currently
|
||||
only affects the EAGLE method of speculation."""
|
||||
use_local_argmax_reduction: bool = False
|
||||
"""Use vocab-parallel local argmax instead of all-gathering full logits
|
||||
for draft token generation. Reduces communication from O(vocab_size) to
|
||||
O(2 * tp_size) per token. Only applies to greedy draft selection in
|
||||
non-tree speculation."""
|
||||
|
||||
# Ngram proposer configuration
|
||||
prompt_lookup_max: int | None = Field(default=None, ge=1)
|
||||
"""Maximum size of ngram token window when using Ngram proposer, required
|
||||
when method is set to ngram."""
|
||||
prompt_lookup_min: int | None = Field(default=None, ge=1)
|
||||
"""Minimum size of ngram token window when using Ngram proposer, if
|
||||
provided. Defaults to 1."""
|
||||
|
||||
# Alternative drafting strategies
|
||||
speculative_token_tree: str | None = None
|
||||
"""Specifies the tree structure for speculative token generation.
|
||||
"""
|
||||
parallel_drafting: bool = False
"""Enable parallel drafting, where all speculative tokens are generated
in parallel rather than sequentially. This can improve performance but
requires the speculative model to be trained to support parallel drafting.
Only compatible with EAGLE and draft model methods."""
|
||||
|
||||
# required configuration params passed from engine
|
||||
target_model_config: SkipValidation[ModelConfig] = None # type: ignore
|
||||
"""The configuration of the target model."""
|
||||
target_parallel_config: SkipValidation[ParallelConfig] = None # type: ignore
|
||||
"""The parallel configuration for the target model."""
|
||||
|
||||
# params generated in the post-init stage
|
||||
draft_model_config: SkipValidation[ModelConfig] = None # type: ignore
"""The configuration of the draft model, initialized internally."""
draft_parallel_config: SkipValidation[ParallelConfig] = None # type: ignore
"""The parallel configuration for the draft model, initialized internally."""
|
||||
|
||||
# Suffix decoding configuration
|
||||
suffix_decoding_max_tree_depth: int = 24
|
||||
"""The maximum depth of the suffix decoding global and prompt trees. The
|
||||
tree depth limits the sum of the prefix match and speculation lengths."""
|
||||
|
||||
suffix_decoding_max_cached_requests: int = 10000
|
||||
"""The maximum number of requests to cache in the global suffix tree. If
|
||||
exceeded, will trigger eviction in FIFO order. If set to 0, the global
|
||||
suffix tree is disabled and past responses are not cached (prompt trees
|
||||
are still used)."""
|
||||
|
||||
suffix_decoding_max_spec_factor: float = 1.0
|
||||
"""The maximum spec factor for suffix decoding. The spec factor controls
|
||||
speculation lengths based on the prefix match length: max_spec_tokens =
|
||||
max_spec_factor * prefix_match_length."""
|
||||
|
||||
suffix_decoding_min_token_prob: float = 0.1
|
||||
"""The minimum token probability for suffix decoding. Will only speculate
|
||||
tokens with estimated probability (based on frequency counts) greater than
|
||||
or equal to this value."""
|
||||
|
||||
draft_load_config: LoadConfig | None = None
|
||||
"""Load config for the draft model. If not specified, will use the load
|
||||
config from the target model."""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
ensure that it is included in the factors list if
|
||||
it affects the computation graph.
|
||||
|
||||
Provide a hash that uniquely identifies all the configs
|
||||
that affect the structure of the computation
|
||||
graph from input ids/embeddings to the final hidden states,
|
||||
excluding anything before input ids/embeddings and after
|
||||
the final hidden states.
|
||||
"""
|
||||
factors: list[Any] = []
|
||||
# Eagle3 affects the computation graph because it returns intermediate
|
||||
# hidden states in addition to the final hidden state.
|
||||
factors.append(self.method == "eagle3")
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
@staticmethod
|
||||
def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
|
||||
initial_architecture = hf_config.architectures[0]
|
||||
if hf_config.model_type in ("deepseek_v3", "deepseek_v32", "glm_moe_dsa"):
|
||||
hf_config.model_type = "deepseek_mtp"
|
||||
if hf_config.model_type == "deepseek_mtp":
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||
hf_config.update(
|
||||
{"n_predict": n_predict, "architectures": ["DeepSeekMTPModel"]}
|
||||
)
|
||||
if hf_config.model_type in ("pangu_ultra_moe"):
|
||||
hf_config.model_type = "pangu_ultra_moe_mtp"
|
||||
if hf_config.model_type == "pangu_ultra_moe_mtp":
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||
hf_config.update(
|
||||
{"n_predict": n_predict, "architectures": ["OpenPanguMTPModel"]}
|
||||
)
|
||||
|
||||
if hf_config.architectures[0] == "MiMoForCausalLM":
|
||||
hf_config.model_type = "mimo_mtp"
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||
hf_config.update(
|
||||
{
|
||||
"num_hidden_layers": 0,
|
||||
"n_predict": n_predict,
|
||||
"architectures": ["MiMoMTPModel"],
|
||||
}
|
||||
)
|
||||
|
||||
if hf_config.architectures[0] == "Glm4MoeForCausalLM":
|
||||
hf_config.model_type = "glm4_moe_mtp"
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||
hf_config.update(
|
||||
{
|
||||
"n_predict": n_predict,
|
||||
"architectures": ["Glm4MoeMTPModel"],
|
||||
}
|
||||
)
|
||||
|
||||
if hf_config.architectures[0] == "Glm4MoeLiteForCausalLM":
|
||||
hf_config.model_type = "glm4_moe_lite_mtp"
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||
hf_config.update(
|
||||
{
|
||||
"num_hidden_layers": 0,
|
||||
"n_predict": n_predict,
|
||||
"architectures": ["Glm4MoeLiteMTPModel"],
|
||||
}
|
||||
)
|
||||
|
||||
if hf_config.architectures[0] == "GlmOcrForConditionalGeneration":
|
||||
hf_config.model_type = "glm_ocr_mtp"
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||
hf_config.update(
|
||||
{
|
||||
"num_hidden_layers": 0,
|
||||
"n_predict": n_predict,
|
||||
"architectures": ["GlmOcrMTPModel"],
|
||||
}
|
||||
)
|
||||
|
||||
if hf_config.model_type == "ernie4_5_moe":
|
||||
hf_config.model_type = "ernie_mtp"
|
||||
if hf_config.model_type == "ernie_mtp":
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||
hf_config.update(
|
||||
{"n_predict": n_predict, "architectures": ["ErnieMTPModel"]}
|
||||
)
|
||||
|
||||
if (
|
||||
hf_config.model_type == "nemotron_h"
|
||||
and hasattr(hf_config, "num_nextn_predict_layers")
|
||||
and hf_config.num_nextn_predict_layers > 0
|
||||
):
|
||||
# Check if this is an MTP variant
|
||||
hf_config.model_type = "nemotron_h_mtp"
|
||||
if hf_config.model_type == "nemotron_h_mtp":
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
|
||||
hf_config.update(
|
||||
{"n_predict": n_predict, "architectures": ["NemotronHMTPModel"]}
|
||||
)
|
||||
|
||||
if hf_config.model_type == "qwen3_next":
|
||||
hf_config.model_type = "qwen3_next_mtp"
|
||||
if hf_config.model_type == "qwen3_next_mtp":
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||
hf_config.update(
|
||||
{"n_predict": n_predict, "architectures": ["Qwen3NextMTP"]}
|
||||
)
|
||||
|
||||
if hf_config.model_type == "exaone_moe":
|
||||
hf_config.model_type = "exaone_moe_mtp"
|
||||
if hf_config.model_type == "exaone_moe_mtp":
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||
hf_config.update(
|
||||
{"n_predict": n_predict, "architectures": ["ExaoneMoeMTP"]}
|
||||
)
|
||||
|
||||
if hf_config.model_type in ("qwen3_5", "qwen3_5_moe"):
|
||||
is_moe = hf_config.model_type == "qwen3_5_moe"
|
||||
hf_config.model_type = "qwen3_5_mtp"
|
||||
n_predict = getattr(hf_config, "mtp_num_hidden_layers", None)
|
||||
hf_config.update(
|
||||
{
|
||||
"n_predict": n_predict,
|
||||
"architectures": ["Qwen3_5MoeMTP" if is_moe else "Qwen3_5MTP"],
|
||||
}
|
||||
)
|
||||
if hf_config.model_type == "longcat_flash":
|
||||
hf_config.model_type = "longcat_flash_mtp"
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
|
||||
hf_config.update(
|
||||
{"n_predict": n_predict, "architectures": ["LongCatFlashMTPModel"]}
|
||||
)
|
||||
|
||||
if hf_config.model_type == "step3p5":
|
||||
hf_config.model_type = "step3p5_mtp"
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
|
||||
hf_config.update({"n_predict": n_predict, "architectures": ["Step3p5MTP"]})
|
||||
|
||||
if initial_architecture == "MistralLarge3ForCausalLM":
|
||||
hf_config.update({"architectures": ["EagleMistralLarge3ForCausalLM"]})
|
||||
|
||||
return hf_config
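# [Editor's note] Illustrative sketch, not part of the original file: for a
# DeepSeek-V3 checkpoint, the override above rewrites the HF config so the
# draft model loads as an MTP head, roughly:
#   model_type:    "deepseek_v3" -> "deepseek_mtp"
#   architectures: [...]         -> ["DeepSeekMTPModel"]
#   n_predict      <- num_nextn_predict_layers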
|
||||
|
||||
def __post_init__(self):
|
||||
# Note: "method" is a new parameter that helps to extend the
|
||||
# configuration of non-model-based proposers, and the "model" parameter
|
||||
# will be used to set the draft model, eagle head, or additional weight
|
||||
# when needed. If users do not specify "method", the speculative method
|
||||
# will be detected automatically if possible. If the speculative method
|
||||
# can not be detected, it will be considered as the "draft_model" by
|
||||
# default.
|
||||
|
||||
# infer method from user args
|
||||
if self.method is None:
|
||||
if self.model in ("ngram", "[ngram]"):
|
||||
self.method = "ngram"
|
||||
else:
|
||||
self.method = "draft_model"
|
||||
|
||||
if self.method in get_args(MTPModelTypes) and self.method != "mtp":
|
||||
logger.warning(
|
||||
"method `%s` is deprecated and replaced with mtp.", self.method
|
||||
)
|
||||
self.method = "mtp"
|
||||
|
||||
if self.model is None and self.num_speculative_tokens is not None:
|
||||
if self.method == "mtp":
|
||||
if self.target_model_config is None:
|
||||
raise ValueError("target_model_config must be present for mtp")
|
||||
if self.target_model_config.hf_text_config.model_type == "deepseek_v32":
|
||||
# FIXME(luccafong): cudagraph with v32 MTP is not supported,
|
||||
# remove this when the issue is fixed.
|
||||
self.enforce_eager = True
|
||||
# use the draft model from the same model:
|
||||
self.model = self.target_model_config.model
|
||||
# Align the quantization of draft model for cases such as
|
||||
# --quantization fp8 with a bf16 checkpoint.
|
||||
if not self.quantization:
|
||||
self.quantization = self.target_model_config.quantization
|
||||
elif self.method in ("ngram", "[ngram]"):
|
||||
self.model = "ngram"
|
||||
elif self.method == "suffix":
|
||||
self.model = "suffix"
|
||||
else:
|
||||
raise ValueError(
|
||||
"num_speculative_tokens was provided but without speculative model."
|
||||
)
|
||||
|
||||
if self.method in ("ngram", "[ngram]"):
|
||||
# Unified to "ngram" internally
|
||||
self.method = "ngram"
|
||||
# Set default values if not provided
|
||||
if self.prompt_lookup_min is None and self.prompt_lookup_max is None:
|
||||
# TODO(woosuk): Tune these values. They are arbitrarily chosen.
|
||||
self.prompt_lookup_min = 5
|
||||
self.prompt_lookup_max = 5
|
||||
elif self.prompt_lookup_min is None:
|
||||
if self.prompt_lookup_max is None:
|
||||
raise ValueError(
|
||||
"Either prompt_lookup_max or prompt_lookup_min must be "
|
||||
"provided when using the ngram method."
|
||||
)
|
||||
self.prompt_lookup_min = self.prompt_lookup_max
|
||||
elif self.prompt_lookup_max is None:
|
||||
if self.prompt_lookup_min is None:
|
||||
raise ValueError(
|
||||
"Either prompt_lookup_max or prompt_lookup_min must be "
|
||||
"provided when using the ngram method."
|
||||
)
|
||||
self.prompt_lookup_max = self.prompt_lookup_min
|
||||
|
||||
# Validate values
|
||||
if self.prompt_lookup_min > self.prompt_lookup_max:
|
||||
raise ValueError(
|
||||
f"prompt_lookup_min={self.prompt_lookup_min} must "
|
||||
f"be <= prompt_lookup_max={self.prompt_lookup_max}"
|
||||
)
|
||||
|
||||
# TODO: currently we still need to extract vocab_size from the target
# model config; in the future we may refactor this out and set the
# draft-related config to None here.
|
||||
self.draft_model_config = self.target_model_config
|
||||
self.draft_parallel_config = self.target_parallel_config
|
||||
elif self.method == "suffix":
|
||||
self._validate_suffix_decoding()
|
||||
else:
|
||||
self.prompt_lookup_max = 0
|
||||
self.prompt_lookup_min = 0
|
||||
|
||||
if self.model is not None:
|
||||
self.draft_model_config = ModelConfig(
|
||||
model=self.model,
|
||||
runner="draft",
|
||||
tokenizer=self.target_model_config.tokenizer,
|
||||
tokenizer_mode=self.target_model_config.tokenizer_mode,
|
||||
trust_remote_code=self.target_model_config.trust_remote_code,
|
||||
allowed_local_media_path=self.target_model_config.allowed_local_media_path,
|
||||
allowed_media_domains=self.target_model_config.allowed_media_domains,
|
||||
dtype=self.target_model_config.dtype,
|
||||
seed=self.target_model_config.seed,
|
||||
revision=self.revision,
|
||||
code_revision=self.code_revision,
|
||||
tokenizer_revision=self.target_model_config.tokenizer_revision,
|
||||
spec_target_max_model_len=self.target_model_config.max_model_len,
|
||||
quantization=self.quantization,
|
||||
enforce_eager=self.target_model_config.enforce_eager,
|
||||
max_logprobs=self.target_model_config.max_logprobs,
|
||||
hf_overrides=SpeculativeConfig.hf_config_override,
|
||||
config_format=self.target_model_config.config_format,
|
||||
)
|
||||
|
||||
# Automatically detect the method
|
||||
if self.method in ("eagle", "eagle3"):
|
||||
pass
|
||||
# examples:
|
||||
# yuhuili/EAGLE-LLaMA3-Instruct-8B
|
||||
# yuhuili/EAGLE3-LLaMA3.1-Instruct-8B
|
||||
# AngelSlim/Qwen3-8B_eagle3
|
||||
elif "eagle-" in self.draft_model_config.model.lower():
|
||||
self.method = "eagle"
|
||||
elif "eagle3" in self.draft_model_config.model.lower():
|
||||
self.method = "eagle3"
|
||||
elif self.draft_model_config.hf_config.model_type == "medusa":
|
||||
self.method = "medusa"
|
||||
elif self.draft_model_config.hf_config.model_type == "mlp_speculator":
|
||||
self.method = "mlp_speculator"
|
||||
elif self.draft_model_config.hf_config.model_type in get_args(
|
||||
MTPModelTypes
|
||||
):
|
||||
self.method = "mtp"
|
||||
if self.num_speculative_tokens > 1:
|
||||
logger.warning(
|
||||
"Enabling num_speculative_tokens > 1 will run "
|
||||
"multiple times of forward on same MTP layer"
|
||||
",which may result in lower acceptance rate"
|
||||
)
|
||||
elif self.draft_model_config.hf_config.model_type in (
"longcat_flash_mtp",
):
|
||||
self.method = "longcat_flash_mtp"
|
||||
if self.num_speculative_tokens > 1:
|
||||
logger.warning(
|
||||
"LongCat MTP models only have "
|
||||
"one layer. Might need some code changes "
|
||||
"to support multiple layers."
|
||||
)
|
||||
elif self.method == "draft_model":
|
||||
pass
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Unsupported speculative method: '{self.method}'"
|
||||
)
|
||||
|
||||
# Replace hf_config for EAGLE draft_model
|
||||
if self.method in ("eagle", "eagle3"):
|
||||
from vllm.transformers_utils.configs import SpeculatorsConfig
|
||||
from vllm.transformers_utils.configs.eagle import EAGLEConfig
|
||||
|
||||
if isinstance(
|
||||
self.draft_model_config.hf_config,
|
||||
(EAGLEConfig, SpeculatorsConfig),
|
||||
):
|
||||
pass
|
||||
else:
|
||||
eagle_config = EAGLEConfig(
|
||||
self.draft_model_config.hf_config,
|
||||
method=self.method,
|
||||
model_type="eagle",
|
||||
)
|
||||
# EAGLEConfig primarily updates architectures, so update
|
||||
# all architectures-related fields in draft_model_config
|
||||
self.draft_model_config.hf_config = eagle_config
|
||||
self.draft_model_config.hf_text_config = get_hf_text_config(
|
||||
self.draft_model_config.hf_config
|
||||
)
|
||||
self.draft_model_config.model_arch_config = (
|
||||
self.draft_model_config.get_model_arch_config()
|
||||
)
|
||||
model_info, arch = (
|
||||
self.draft_model_config.registry.inspect_model_cls(
|
||||
self.draft_model_config.architectures,
|
||||
self.draft_model_config,
|
||||
)
|
||||
)
|
||||
self.draft_model_config._model_info = model_info
|
||||
self.draft_model_config._architecture = arch
|
||||
|
||||
if self.num_speculative_tokens is not None and hasattr(
|
||||
self.draft_model_config.hf_config, "num_lookahead_tokens"
|
||||
):
|
||||
self.draft_model_config.hf_config.num_lookahead_tokens = (
|
||||
self.num_speculative_tokens
|
||||
)
|
||||
|
||||
n_predict = getattr(
|
||||
self.draft_model_config.hf_config, "n_predict", None
|
||||
)
|
||||
if n_predict is not None:
|
||||
if self.num_speculative_tokens is None:
|
||||
# Default to max value defined in draft model config.
|
||||
self.num_speculative_tokens = n_predict
|
||||
elif (
|
||||
self.num_speculative_tokens > n_predict
|
||||
and self.num_speculative_tokens % n_predict != 0
|
||||
):
|
||||
# Ensure divisibility for MTP module reuse.
|
||||
raise ValueError(
|
||||
f"num_speculative_tokens:{self.num_speculative_tokens}"
|
||||
f" must be divisible by {n_predict=}"
|
||||
)
|
||||
|
||||
if self.speculative_token_tree is None:
|
||||
if self.num_speculative_tokens is None:
|
||||
raise ValueError(
|
||||
"A speculative model was provided, but neither "
|
||||
"`speculative_token_tree` nor `num_speculative_tokens` "
|
||||
"was provided"
|
||||
)
|
||||
|
||||
# Generate chain of tokens.
|
||||
self.speculative_token_tree = str(
|
||||
[(i + 1) * (0,) for i in range(self.num_speculative_tokens)]
|
||||
)
|
||||
else:
|
||||
# Sort the token tree breadth-first.
|
||||
tree_choices = ast.literal_eval(self.speculative_token_tree)
|
||||
self.speculative_token_tree = str(
|
||||
sorted(tree_choices, key=lambda t: (len(t), t))
|
||||
)
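# Worked example (illustrative): with num_speculative_tokens=3 the chain
# form above serializes to "[(0,), (0, 0), (0, 0, 0)]"; a user-provided
# tree such as "[(0, 0), (0,), (1,)]" is re-serialized breadth-first as
# "[(0,), (1,), (0, 0)]".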
|
||||
|
||||
self.draft_tensor_parallel_size = (
|
||||
SpeculativeConfig._verify_and_get_draft_tp(
|
||||
self.target_parallel_config,
|
||||
self.draft_tensor_parallel_size,
|
||||
self.draft_model_config.hf_config,
|
||||
)
|
||||
)
|
||||
|
||||
self.draft_model_config.max_model_len = (
|
||||
SpeculativeConfig._maybe_override_draft_max_model_len(
|
||||
self.max_model_len,
|
||||
self.draft_model_config.max_model_len,
|
||||
self.target_model_config.max_model_len,
|
||||
)
|
||||
)
|
||||
|
||||
self.draft_parallel_config = (
|
||||
SpeculativeConfig.create_draft_parallel_config(
|
||||
self.target_parallel_config, self.draft_tensor_parallel_size
|
||||
)
|
||||
)
|
||||
return self
|
||||
|
||||
def _validate_suffix_decoding(self):
|
||||
if not has_arctic_inference():
|
||||
raise ImportError(
|
||||
"Arctic Inference is required for suffix decoding. "
|
||||
"Install via `pip install arctic-inference==0.1.1`."
|
||||
)
|
||||
if self.num_speculative_tokens is None:
|
||||
# Suffix decoding decides the actual number of speculative tokens
|
||||
# dynamically and treats num_speculative_tokens as a maximum limit.
|
||||
self.num_speculative_tokens = self.suffix_decoding_max_tree_depth
|
||||
logger.warning(
|
||||
"Defaulted num_speculative_tokens to %s for suffix decoding.",
|
||||
self.num_speculative_tokens,
|
||||
)
|
||||
# Validate values
|
||||
if self.suffix_decoding_max_tree_depth < 1:
|
||||
raise ValueError(
|
||||
f"suffix_decoding_max_tree_depth="
|
||||
f"{self.suffix_decoding_max_tree_depth} must be >= 1"
|
||||
)
|
||||
if self.suffix_decoding_max_cached_requests < 0:
|
||||
raise ValueError(
|
||||
f"suffix_decoding_max_cached_requests="
|
||||
f"{self.suffix_decoding_max_cached_requests} must be >= 0"
|
||||
)
|
||||
if self.suffix_decoding_max_spec_factor < 0:
|
||||
raise ValueError(
|
||||
f"suffix_decoding_max_spec_factor="
|
||||
f"{self.suffix_decoding_max_spec_factor} must be >= 0"
|
||||
)
|
||||
if not 0 <= self.suffix_decoding_min_token_prob <= 1:
|
||||
raise ValueError(
|
||||
f"suffix_decoding_min_token_prob="
|
||||
f"{self.suffix_decoding_min_token_prob} must be in [0, 1]"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _maybe_override_draft_max_model_len(
|
||||
speculative_max_model_len: int | None,
|
||||
draft_max_model_len: int,
|
||||
target_max_model_len: int,
|
||||
) -> int:
|
||||
"""Determine the max sequence len for the draft model. This is usually
|
||||
the draft_max_model_len, but may be the target_max_model_len if it is
|
||||
less than the draft_max_model_len, or may be speculative_max_model_len
|
||||
if it is specified.
|
||||
|
||||
This is necessary so that sequences do not exceed the capacity of the
|
||||
draft model or the target model.
|
||||
|
||||
speculative_max_model_len is mainly used for testing that sequences can
|
||||
skip speculation.
|
||||
"""
|
||||
|
||||
if speculative_max_model_len is not None:
|
||||
if speculative_max_model_len > draft_max_model_len:
|
||||
raise ValueError(
|
||||
f"{speculative_max_model_len=} cannot be "
|
||||
f"larger than {draft_max_model_len=}"
|
||||
)
|
||||
|
||||
if speculative_max_model_len > target_max_model_len:
|
||||
raise ValueError(
|
||||
f"{speculative_max_model_len=} cannot be "
|
||||
f"larger than {target_max_model_len=}"
|
||||
)
|
||||
|
||||
return speculative_max_model_len
|
||||
|
||||
return min(
|
||||
draft_max_model_len,
|
||||
target_max_model_len,
|
||||
)
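# Example (illustrative): with draft_max_model_len=4096,
# target_max_model_len=8192 and speculative_max_model_len=None, the draft
# model is capped at 4096; passing speculative_max_model_len=16384 would
# raise, since it exceeds both limits.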
|
||||
|
||||
@staticmethod
|
||||
def _verify_and_get_draft_tp(
|
||||
target_parallel_config: ParallelConfig,
|
||||
speculative_draft_tensor_parallel_size: int | None,
|
||||
draft_hf_config: PretrainedConfig,
|
||||
) -> int:
|
||||
"""
|
||||
Verifies and adjusts the tensor parallel size for a draft model
|
||||
specified using speculative_draft_tensor_parallel_size.
|
||||
"""
|
||||
# If speculative_draft_tensor_parallel_size is unset, set it appropriately;
# otherwise, verify that it is set correctly.
|
||||
if speculative_draft_tensor_parallel_size is None:
|
||||
if draft_hf_config.model_type == "mlp_speculator":
|
||||
speculative_draft_tensor_parallel_size = 1
|
||||
if target_parallel_config.tensor_parallel_size > 1:
|
||||
logger.warning(
|
||||
"%s cannot currently be run with tp>1; "
|
||||
"setting speculative_draft_tensor_parallel_size=1",
|
||||
draft_hf_config.model_type,
|
||||
)
|
||||
else:
|
||||
speculative_draft_tensor_parallel_size = (
|
||||
target_parallel_config.tensor_parallel_size
|
||||
)
|
||||
elif speculative_draft_tensor_parallel_size not in (
|
||||
1,
|
||||
target_parallel_config.tensor_parallel_size,
|
||||
):
|
||||
raise ValueError(
|
||||
f"{speculative_draft_tensor_parallel_size=} cannot be "
|
||||
f"other value than 1 or target model tensor_parallel_size"
|
||||
)
|
||||
return speculative_draft_tensor_parallel_size
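# Example (illustrative): with target tensor_parallel_size=4, an unset
# value resolves to 4 (or to 1 for mlp_speculator drafts); requesting 2
# raises, since only 1 or the target's size is allowed.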
|
||||
|
||||
@staticmethod
|
||||
def create_draft_parallel_config(
|
||||
target_parallel_config: ParallelConfig,
|
||||
speculative_draft_tensor_parallel_size: int,
|
||||
) -> ParallelConfig:
|
||||
"""Create a parallel config for use by the draft worker.
|
||||
|
||||
This is mostly a copy of the target parallel config, except the tp_size.
|
||||
"""
|
||||
draft_parallel_config = ParallelConfig(
|
||||
pipeline_parallel_size=target_parallel_config.pipeline_parallel_size,
|
||||
tensor_parallel_size=speculative_draft_tensor_parallel_size,
|
||||
distributed_executor_backend=target_parallel_config.distributed_executor_backend,
|
||||
max_parallel_loading_workers=target_parallel_config.max_parallel_loading_workers,
|
||||
disable_custom_all_reduce=target_parallel_config.disable_custom_all_reduce,
|
||||
ray_workers_use_nsight=target_parallel_config.ray_workers_use_nsight,
|
||||
placement_group=target_parallel_config.placement_group,
|
||||
)
|
||||
|
||||
return draft_parallel_config
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _verify_args(self) -> Self:
|
||||
if self.tensor_parallel_size is not None:
|
||||
raise ValueError(
|
||||
"'tensor_parallel_size' is not a valid argument in the "
|
||||
"speculative_config. Please pass 'draft_tensor_parallel_size' instead."
|
||||
)
|
||||
|
||||
if self.num_speculative_tokens is None:
|
||||
raise ValueError(
|
||||
"num_speculative_tokens must be provided with "
|
||||
"speculative model unless the draft model config contains an "
|
||||
"n_predict parameter."
|
||||
)
|
||||
|
||||
if self.num_speculative_tokens <= 0:
|
||||
raise ValueError(
|
||||
"Expected num_speculative_tokens to be greater "
|
||||
f"than zero ({self.num_speculative_tokens})."
|
||||
)
|
||||
|
||||
if self.draft_model_config:
|
||||
self.draft_model_config.verify_with_parallel_config(
|
||||
self.draft_parallel_config
|
||||
)
|
||||
|
||||
eagle3_target_supported = [
|
||||
"llama",
|
||||
"qwen",
|
||||
"minicpm",
|
||||
"gpt_oss",
|
||||
"hunyuan_vl",
|
||||
"hunyuan_v1_dense",
|
||||
"afmoe",
|
||||
"nemotron_h",
|
||||
]
|
||||
if (
|
||||
self.method == "eagle3"
|
||||
and self.target_model_config
|
||||
and not any(
|
||||
supported_model in self.target_model_config.hf_text_config.model_type
|
||||
for supported_model in eagle3_target_supported
|
||||
)
|
||||
):
|
||||
raise ValueError(
|
||||
f"Eagle3 is only supported for {eagle3_target_supported} models. " # noqa: E501
|
||||
f"Got {self.target_model_config.hf_text_config.model_type=}"
|
||||
)
|
||||
self.verify_equal_vocab_size_if_draft_model()
|
||||
return self
|
||||
|
||||
def verify_equal_vocab_size_if_draft_model(self):
|
||||
if (
|
||||
self.method == "draft_model"
|
||||
and self.target_model_config is not None
|
||||
and self.draft_model_config is not None
|
||||
):
|
||||
target_vocab_size = self.target_model_config.get_vocab_size()
|
||||
draft_vocab_size = self.draft_model_config.get_vocab_size()
|
||||
if target_vocab_size != draft_vocab_size:
|
||||
raise ValueError(
|
||||
f"Target and draft model should have the same vocabulary size. "
|
||||
f"Target model vocab_size={target_vocab_size}. "
|
||||
f"Draft model vocab_size={draft_vocab_size}. "
|
||||
f"Using models with different tokenizers can cause out-of-bounds "
|
||||
f"errors during speculative decoding."
|
||||
)
|
||||
|
||||
@property
|
||||
def max_num_new_slots_for_drafting(self) -> int:
|
||||
"""
|
||||
Calculate the maximum number of new slots that might be added to the batch
|
||||
when drafting.
|
||||
"""
|
||||
slots_per_req = 0 # for serial non-draft-model methods, no change needed
|
||||
if self.parallel_drafting:
|
||||
# For parallel drafting, we need one new slot per 'masked' token
|
||||
slots_per_req = self.num_speculative_tokens - 1
|
||||
if self.uses_draft_model():
|
||||
# For draft model-based speculation, we need one new slot per request
|
||||
# Since we do not slice the draft tokens
|
||||
slots_per_req += 1
|
||||
return slots_per_req
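# Example (illustrative): with parallel_drafting=True,
# num_speculative_tokens=4 and a draft-model-based method, this yields
# (4 - 1) + 1 = 4 new slots per request; serial non-draft-model methods
# yield 0.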
|
||||
|
||||
def use_eagle(self) -> bool:
|
||||
return self.method in ("eagle", "eagle3", "mtp")
|
||||
|
||||
def uses_draft_model(self) -> bool:
|
||||
return self.method == "draft_model"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
method = self.method
|
||||
model = None if method in ("ngram", "suffix") else self.draft_model_config.model
|
||||
num_spec_tokens = self.num_speculative_tokens
|
||||
return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})"
|
||||
39
vllm/config/speech_to_text.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
from vllm.config.utils import config
|
||||
|
||||
|
||||
@config
|
||||
class SpeechToTextConfig:
|
||||
"""Configuration for speech-to-text models."""
|
||||
|
||||
sample_rate: float = 16_000
|
||||
"""Sample rate (Hz) to resample input audio to. Most speech models expect
|
||||
16kHz audio input. The input audio will be automatically resampled to this
|
||||
rate before processing."""
|
||||
|
||||
max_audio_clip_s: int | None = 30
|
||||
"""Maximum duration in seconds for a single audio clip without chunking.
|
||||
Audio longer than this will be split into smaller chunks if
|
||||
`allow_audio_chunking` evaluates to True, otherwise it will be rejected.
|
||||
`None` means audio duration can be unlimited and won't be chunked."""
|
||||
|
||||
overlap_chunk_second: int = 1
|
||||
"""Overlap duration in seconds between consecutive audio chunks when
|
||||
splitting long audio. This helps maintain context across chunk boundaries
|
||||
and improves transcription quality at split points."""
|
||||
|
||||
min_energy_split_window_size: int | None = 1600
|
||||
"""Window size in samples for finding low-energy (quiet) regions to split
|
||||
audio chunks. The algorithm looks for the quietest moment within this
|
||||
window to minimize cutting through speech. Default 1600 samples ≈ 100ms
|
||||
at 16kHz. If None, no chunking will be done."""
|
||||
|
||||
@property
|
||||
def allow_audio_chunking(self) -> bool:
|
||||
return (
|
||||
self.min_energy_split_window_size is not None
|
||||
and self.max_audio_clip_s is not None
|
||||
)
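# Example (illustrative): with the defaults above
# (min_energy_split_window_size=1600, max_audio_clip_s=30) chunking is
# allowed; setting either field to None disables it.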
|
||||
76
vllm/config/structured_outputs.py
Normal file
@@ -0,0 +1,76 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.utils.hashing import safe_hash
|
||||
|
||||
StructuredOutputsBackend = Literal[
|
||||
"auto", "xgrammar", "guidance", "outlines", "lm-format-enforcer"
|
||||
]
|
||||
|
||||
|
||||
@config
|
||||
class StructuredOutputsConfig:
|
||||
"""Dataclass which contains structured outputs config for the engine."""
|
||||
|
||||
backend: StructuredOutputsBackend = "auto"
|
||||
"""Which engine will be used for structured outputs (e.g. JSON schema,
|
||||
regex, etc) by default. With "auto", we will make opinionated choices
|
||||
based on request contents and what the backend libraries currently support,
|
||||
so the behavior is subject to change in each release."""
|
||||
disable_fallback: bool = False
|
||||
"""If `True`, vLLM will not fallback to a different backend on error."""
|
||||
disable_any_whitespace: bool = False
|
||||
"""If `True`, json output will always be compact without any whitespace.
|
||||
If `False`, the model may generate whitespace between JSON fields,
|
||||
which is still valid JSON. This is only supported for xgrammar
|
||||
and guidance backends."""
|
||||
disable_additional_properties: bool = False
|
||||
"""If `True`, the `guidance` backend will not use `additionalProperties`
|
||||
in the JSON schema. This is only supported for the `guidance` backend and
|
||||
is used to better align its behaviour with `outlines` and `xgrammar`."""
|
||||
reasoning_parser: str = ""
|
||||
"""Select the reasoning parser depending on the model that you're using.
|
||||
This is used to parse the reasoning content into OpenAI API format."""
|
||||
reasoning_parser_plugin: str = ""
|
||||
"""Path to a dynamically reasoning parser plugin that can be dynamically
|
||||
loaded and registered."""
|
||||
enable_in_reasoning: bool = False
|
||||
"""Whether to use structured input for reasoning."""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
ensure that it is included in the factors list if
|
||||
it affects the computation graph.
|
||||
|
||||
Provide a hash that uniquely identifies all the configs
|
||||
that affect the structure of the computation
|
||||
graph from input ids/embeddings to the final hidden states,
|
||||
excluding anything before input ids/embeddings and after
|
||||
the final hidden states.
|
||||
"""
|
||||
# no factors to consider.
|
||||
# this config will not affect the computation graph.
|
||||
factors: list[Any] = []
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_structured_output_config(self) -> Self:
|
||||
if self.disable_any_whitespace and self.backend not in ("xgrammar", "guidance"):
|
||||
raise ValueError(
|
||||
"disable_any_whitespace is only supported for "
|
||||
"xgrammar and guidance backends."
|
||||
)
|
||||
if self.disable_additional_properties and self.backend != "guidance":
|
||||
raise ValueError(
|
||||
"disable_additional_properties is only supported "
|
||||
"for the guidance backend."
|
||||
)
|
||||
return self
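# Illustrative sketch of the validation above (constructor call assumed):
#     StructuredOutputsConfig(backend="outlines", disable_any_whitespace=True)
# raises ValueError, because disable_any_whitespace requires the xgrammar
# or guidance backend.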
|
||||
447
vllm/config/utils.py
Normal file
@@ -0,0 +1,447 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Utility functions for vLLM config dataclasses."""
|
||||
|
||||
import ast
|
||||
import enum
|
||||
import hashlib
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import pathlib
|
||||
import textwrap
|
||||
from collections.abc import Callable, Mapping, Sequence, Set
|
||||
from dataclasses import MISSING, field, fields, is_dataclass
|
||||
from itertools import pairwise
|
||||
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast
|
||||
|
||||
import torch
|
||||
from pydantic import ConfigDict
|
||||
from pydantic.dataclasses import dataclass
|
||||
from pydantic.fields import Field as PydanticField
|
||||
from pydantic.fields import FieldInfo
|
||||
from typing_extensions import dataclass_transform, runtime_checkable
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _typeshed import DataclassInstance
|
||||
else:
|
||||
DataclassInstance = Any
|
||||
|
||||
ConfigType = type[DataclassInstance]
|
||||
ConfigT = TypeVar("ConfigT", bound=DataclassInstance)
|
||||
|
||||
|
||||
@dataclass_transform(field_specifiers=(PydanticField,))
|
||||
def config(
|
||||
cls: type[ConfigT] | None = None,
|
||||
*,
|
||||
config: ConfigDict | None = None,
|
||||
**kwargs: Any,
|
||||
) -> type[ConfigT] | Callable[[type[ConfigT]], type[ConfigT]]:
|
||||
"""Decorator to create a pydantic dataclass with default config. The default config
|
||||
for the dataclass forbids extra fields.
|
||||
|
||||
All config classes in vLLM should use this decorator.
|
||||
|
||||
Args:
|
||||
cls: The class to decorate
|
||||
config: The pydantic ConfigDict to use. If provided, it will be merged with
|
||||
the default config.
|
||||
**kwargs: Additional arguments to pass to pydantic.dataclass."""
|
||||
# Extra fields are forbidden by default
|
||||
merged_config = ConfigDict(extra="forbid")
|
||||
if config is not None:
|
||||
merged_config.update(config)
|
||||
|
||||
def decorator(cls):
|
||||
return dataclass(cls, config=merged_config, **kwargs)
|
||||
|
||||
# Called with arguments: @config(config=...)
|
||||
if cls is None:
|
||||
return decorator
|
||||
# Called without arguments: @config
|
||||
return decorator(cls)
|
||||
|
||||
|
||||
def get_field(cls: ConfigType, name: str) -> Any:
|
||||
"""Get the default factory field of a dataclass by name. Used for getting
|
||||
default factory fields in `EngineArgs`."""
|
||||
if not is_dataclass(cls):
|
||||
raise TypeError("The given class is not a dataclass.")
|
||||
try:
|
||||
named_field = next(f for f in fields(cls) if f.name == name)
|
||||
except StopIteration as e:
|
||||
raise ValueError(f"Field '{name}' not found in {cls.__name__}.") from e
|
||||
|
||||
# The arguments to copy to the new field
|
||||
default = named_field.default
|
||||
default_factory = named_field.default_factory
|
||||
init = named_field.init
|
||||
|
||||
# Handle pydantic.Field
|
||||
if isinstance(default, FieldInfo):
|
||||
if default.init is not None:
|
||||
init = default.init
|
||||
if default.default_factory is not None:
|
||||
default_factory = cast(Callable[[], Any], default.default_factory)
|
||||
default = MISSING
|
||||
else:
|
||||
default = default.default
|
||||
|
||||
if default is MISSING and default_factory is MISSING:
|
||||
logger.warning_once(
|
||||
"%s.%s has no default or default factory.", cls.__name__, name
|
||||
)
|
||||
return field(default=default, default_factory=default_factory, init=init)
|
||||
|
||||
|
||||
def is_init_field(cls: ConfigType, name: str) -> bool:
|
||||
return get_field(cls, name).init
|
||||
|
||||
|
||||
def replace(dataclass_instance: ConfigT, /, **kwargs) -> ConfigT:
|
||||
"""Like [`dataclasses.replace`](https://docs.python.org/3/library/dataclasses.html#dataclasses.replace),
|
||||
but compatible with Pydantic dataclasses which use `pydantic.fields.Field` instead
|
||||
of `dataclasses.field`"""
|
||||
cls = type(dataclass_instance)
|
||||
dataclass_dict = dataclass_instance.__dict__
|
||||
dataclass_dict = {k: v for k, v in dataclass_dict.items() if is_init_field(cls, k)}
|
||||
dataclass_dict.update(kwargs)
|
||||
return cls(**dataclass_dict)
|
||||
|
||||
|
||||
def getattr_iter(
|
||||
object: object,
|
||||
names: Sequence[str],
|
||||
default: Any | None = None,
|
||||
default_factory: Callable[[], Any] | None = None,
|
||||
warn: bool = False,
|
||||
) -> Any:
|
||||
"""
|
||||
A helper function that retrieves an attribute from an object which may
|
||||
have multiple possible names. This is useful when fetching attributes from
|
||||
arbitrary `transformers.PretrainedConfig` instances.
|
||||
|
||||
In the case where the first name in `names` is the preferred name, and
|
||||
any other names are deprecated aliases, setting `warn=True` will log a
|
||||
warning when a deprecated name is used.
|
||||
"""
|
||||
for i, name in enumerate(names):
|
||||
if hasattr(object, name):
|
||||
if warn and i > 0:
|
||||
logger.warning_once(
|
||||
"%s contains a deprecated attribute name '%s'. "
|
||||
"Please use the preferred attribute name '%s' instead.",
|
||||
type(object).__name__,
|
||||
name,
|
||||
names[0],
|
||||
)
|
||||
return getattr(object, name)
|
||||
return default_factory() if default_factory is not None else default
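# Example (illustrative): many HF configs expose the head count under either
# the preferred name or a legacy alias, so the preferred name is listed first:
#
#     getattr_iter(hf_config, ("num_attention_heads", "n_head"), default=None, warn=True)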
|
||||
|
||||
|
||||
def get_attr_docs(cls: type[Any]) -> dict[str, str]:
|
||||
"""
|
||||
Get any docstrings placed after attribute assignments in a class body.
|
||||
|
||||
https://davidism.com/mit-license/
|
||||
"""
|
||||
|
||||
cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0]
|
||||
|
||||
if not isinstance(cls_node, ast.ClassDef):
|
||||
raise TypeError("Given object was not a class.")
|
||||
|
||||
out = {}
|
||||
|
||||
# Consider each pair of nodes.
|
||||
for a, b in pairwise(cls_node.body):
|
||||
# Must be an assignment then a constant string.
|
||||
if (
|
||||
not isinstance(a, (ast.Assign, ast.AnnAssign))
|
||||
or not isinstance(b, ast.Expr)
|
||||
or not isinstance(b.value, ast.Constant)
|
||||
or not isinstance(b.value.value, str)
|
||||
):
|
||||
continue
|
||||
|
||||
doc = inspect.cleandoc(b.value.value)
|
||||
|
||||
# An assignment can have multiple targets (a = b = v), but an
|
||||
# annotated assignment only has one target.
|
||||
targets = a.targets if isinstance(a, ast.Assign) else [a.target]
|
||||
|
||||
for target in targets:
|
||||
# Must be assigning to a plain name.
|
||||
if not isinstance(target, ast.Name):
|
||||
continue
|
||||
|
||||
out[target.id] = doc
|
||||
|
||||
return out
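# Example (illustrative): for the hypothetical ExampleConfig sketched above,
# get_attr_docs(ExampleConfig) returns {"foo": "Number of foos."}.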
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class SupportsHash(Protocol):
|
||||
def compute_hash(self) -> str: ...
|
||||
|
||||
|
||||
class SupportsMetricsInfo(Protocol):
|
||||
def metrics_info(self) -> dict[str, str]: ...
|
||||
|
||||
|
||||
def update_config(config: ConfigT, overrides: dict[str, Any]) -> ConfigT:
|
||||
processed_overrides = {}
|
||||
for field_name, value in overrides.items():
|
||||
assert hasattr(config, field_name), (
|
||||
f"{type(config)} has no field `{field_name}`"
|
||||
)
|
||||
current_value = getattr(config, field_name)
|
||||
if is_dataclass(current_value) and not is_dataclass(value):
|
||||
assert isinstance(value, dict), (
|
||||
f"Overrides to {type(config)}.{field_name} must be a dict"
|
||||
f" or {type(current_value)}, but got {type(value)}"
|
||||
)
|
||||
value = update_config(
|
||||
current_value, # type: ignore[type-var]
|
||||
value,
|
||||
)
|
||||
processed_overrides[field_name] = value
|
||||
return replace(config, **processed_overrides)
|
||||
|
||||
|
||||
def normalize_value(x):
|
||||
"""Return a stable, JSON-serializable canonical form for hashing.
|
||||
Order: primitives, special types (Enum, callable, torch.dtype, Path), then
|
||||
generic containers (Mapping/Set/Sequence) with recursion.
|
||||
"""
|
||||
# Fast path
|
||||
if x is None or isinstance(x, (bool, int, float, str)):
|
||||
return x
|
||||
|
||||
# Enums: tag with FQN to avoid primitive collisions.
|
||||
# Ex: Enum(1) vs int(1) -> ("module.QualName", value).
|
||||
if isinstance(x, enum.Enum):
|
||||
enum_type = f"{x.__class__.__module__}.{x.__class__.__qualname__}"
|
||||
return (enum_type, normalize_value(x.value))
|
||||
|
||||
# Classes (types) are accepted and canonicalized by their fully-qualified
|
||||
# name (module.qualname) for a stable identifier.
|
||||
# Instances are only accepted if they expose uuid(); otherwise they are
|
||||
# rejected to avoid under-hashing object state.
|
||||
|
||||
# Callables: accept classes only; reject funcs/lambdas/methods.
|
||||
# Used by LogitsProcessor types and ModelConfig.hf_overrides.
|
||||
if isinstance(x, type):
|
||||
module = getattr(x, "__module__", "")
|
||||
qual = getattr(x, "__qualname__", getattr(x, "__name__", ""))
|
||||
return ".".join([p for p in (module, qual) if p]) or repr(x)
|
||||
|
||||
# Prefer stable uuid identifiers for objects that provide them, even if
|
||||
# they are callable instances (e.g., InductorPass wrappers).
|
||||
if hasattr(x, "uuid") and callable(getattr(x, "uuid", None)):
|
||||
return x.uuid()
|
||||
|
||||
if callable(x):
|
||||
raise TypeError("normalize_value: function or callable instance unsupported")
|
||||
|
||||
# Torch dtype: stringify (torch.float64 -> "torch.float64").
|
||||
# We rely on the string form here; dtype-bearing fields that need additional
|
||||
# disambiguation should encode that at the config layer.
|
||||
if isinstance(x, torch.dtype):
|
||||
return str(x)
|
||||
|
||||
# Bytes
|
||||
if isinstance(x, (bytes, bytearray)):
|
||||
return x.hex()
|
||||
|
||||
# Paths (canonicalize)
|
||||
if isinstance(x, pathlib.Path):
|
||||
try:
|
||||
return str(x.expanduser().resolve())
|
||||
except Exception:
|
||||
return str(x)
|
||||
|
||||
# Dataclasses: represent as (FQN, sorted(field,value) tuple) for stability.
|
||||
if is_dataclass(x):
|
||||
type_fqn = f"{x.__class__.__module__}.{x.__class__.__qualname__}"
|
||||
items = tuple(
|
||||
(f.name, normalize_value(getattr(x, f.name)))
|
||||
for f in sorted(fields(x), key=lambda f: f.name)
|
||||
)
|
||||
return (type_fqn, items)
|
||||
|
||||
# Containers (generic)
|
||||
if isinstance(x, Mapping):
|
||||
return tuple(sorted((str(k), normalize_value(v)) for k, v in x.items()))
|
||||
if isinstance(x, Set):
|
||||
return tuple(sorted(repr(normalize_value(v)) for v in x))
|
||||
if isinstance(x, Sequence) and not isinstance(x, (str, bytes, bytearray)):
|
||||
return tuple(normalize_value(v) for v in x)
|
||||
|
||||
# PretrainedConfig
|
||||
if hasattr(x, "to_json_string") and callable(x.to_json_string):
|
||||
return x.to_json_string()
|
||||
|
||||
# Unsupported type: e.g., modules, generators, open files, or objects
|
||||
# without a stable JSON/UUID representation. Hard-error to avoid
|
||||
# under-hashing.
|
||||
# If you hit this, either reshape your config to use supported primitives
|
||||
# and containers, or extend normalize_value to provide a stable encoding
|
||||
# (e.g., via uuid() or to_json_string()) for this type.
|
||||
raise TypeError(
|
||||
f"normalize_value: unsupported type '{type(x).__name__}'. "
|
||||
"Ensure config values use supported primitives/containers or add a "
|
||||
"stable representation for this type."
|
||||
)
|
||||
|
||||
|
||||
def get_hash_factors(config: ConfigT, ignored_factors: set[str]) -> dict[str, object]:
|
||||
"""Gets the factors used for hashing a config class.
|
||||
- Includes all dataclass fields not in `ignored_factors`.
|
||||
- Errors on non-normalizable values.
|
||||
"""
|
||||
factors: dict[str, object] = {}
|
||||
for dc_field in fields(config):
|
||||
factor = dc_field.name
|
||||
if factor in ignored_factors:
|
||||
continue
|
||||
value = getattr(config, factor, None)
|
||||
try:
|
||||
factors[factor] = normalize_value(value)
|
||||
except TypeError as e:
|
||||
raise TypeError(
|
||||
f"get_hash_factors: unsupported type for key '{factor}' "
|
||||
f"({type(value).__name__})"
|
||||
) from e
|
||||
return factors
|
||||
|
||||
|
||||
def hash_factors(items: dict[str, object]) -> str:
|
||||
"""Return a SHA-256 hex digest of the canonical items structure."""
|
||||
return hashlib.sha256(json.dumps(items, sort_keys=True).encode()).hexdigest()
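# Putting the two helpers together (illustrative sketch; "seed" is an
# assumed field name):
#
#     factors = get_hash_factors(some_config, ignored_factors={"seed"})
#     digest = hash_factors(factors)   # stable SHA-256 hex digest
#
# Two configs that differ only in ignored factors produce the same digest.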
|
||||
|
||||
|
||||
@dataclass
|
||||
class Range:
|
||||
"""
|
||||
A range of numbers.
|
||||
Inclusive of start, inclusive of end.
|
||||
"""
|
||||
|
||||
start: int
|
||||
end: int
|
||||
|
||||
def is_single_size(self) -> bool:
|
||||
return self.start == self.end
|
||||
|
||||
def __contains__(self, size: int) -> bool:
|
||||
# Inclusive of start, inclusive of end
|
||||
return self.start <= size <= self.end
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, Range):
|
||||
return False
|
||||
return self.start == other.start and self.end == other.end
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash((self.start, self.end))
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"({self.start}, {self.end})"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.__str__()
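# Example (illustrative): Range(1, 8) contains 1 through 8 inclusive, so
# `5 in Range(1, 8)` is True; Range(3, 3).is_single_size() is True and
# str(Range(1, 8)) == "(1, 8)".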
|
||||
|
||||
|
||||
def handle_deprecated(
|
||||
config: ConfigT,
|
||||
old_name: str,
|
||||
new_name_or_names: str | list[str],
|
||||
removal_version: str,
|
||||
) -> None:
|
||||
old_val = getattr(config, old_name)
|
||||
if old_val is None:
|
||||
return
|
||||
|
||||
if isinstance(new_name_or_names, str):
|
||||
new_names = [new_name_or_names]
|
||||
else:
|
||||
new_names = new_name_or_names
|
||||
|
||||
msg = (
|
||||
f"{old_name} is deprecated and will be removed in {removal_version}. "
|
||||
f"Use {', '.join(new_names)} instead."
|
||||
)
|
||||
logger.warning(msg)
|
||||
|
||||
for new_name in new_names:
|
||||
setattr(config, new_name, old_val)
|
||||
|
||||
|
||||
def get_from_deprecated_env_if_set(
|
||||
env_name: str,
|
||||
removal_version: str,
|
||||
field_name: str | None = None,
|
||||
) -> str | None:
|
||||
"""
|
||||
Get value from deprecated environment variable with warning.
|
||||
|
||||
Args:
|
||||
env_name: Name of the deprecated environment variable
|
||||
removal_version: Version when it will be removed
|
||||
field_name: Name of the field to suggest as alternative
|
||||
|
||||
Returns:
|
||||
The environment variable value if set, None otherwise
|
||||
"""
|
||||
if envs.is_set(env_name):
|
||||
value = os.environ.get(env_name)
|
||||
alt_msg = f" Please use {field_name} instead." if field_name else ""
|
||||
logger.warning_once(
|
||||
"Using %s environment variable is deprecated and will be removed in %s.%s",
|
||||
env_name,
|
||||
removal_version,
|
||||
alt_msg,
|
||||
)
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
def set_from_deprecated_env_if_set(
|
||||
config: ConfigT,
|
||||
env_name: str,
|
||||
removal_version: str,
|
||||
field_name: str,
|
||||
to_bool: bool = False,
|
||||
to_int: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Set object field from deprecated environment variable with warning.
|
||||
|
||||
Args:
|
||||
config: Config object to set the field on
|
||||
env_name: Name of the deprecated environment variable
|
||||
removal_version: Version when the env var will be removed
|
||||
field_name: Name of the field to set
|
||||
to_bool: Whether to convert the environment variable value to boolean
|
||||
to_int: Whether to convert the environment variable value to integer
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
if to_bool and to_int:
|
||||
raise ValueError("Cannot convert to both boolean and integer.")
|
||||
|
||||
env_value = get_from_deprecated_env_if_set(env_name, removal_version, field_name)
|
||||
if env_value is not None:
|
||||
field_value: str | bool | int = env_value
|
||||
if to_bool:
|
||||
field_value = env_value.lower() in ("1", "true")
|
||||
elif to_int:
|
||||
field_value = int(env_value)
|
||||
setattr(config, field_name, field_value)
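# Illustrative sketch (environment variable and field names assumed):
# if VLLM_OLD_FLAG is set to "1",
#
#     set_from_deprecated_env_if_set(cfg, "VLLM_OLD_FLAG", "v0.13.0", "new_flag", to_bool=True)
#
# logs a deprecation warning and sets cfg.new_flag = True.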
|
||||
1758
vllm/config/vllm.py
Normal file
File diff suppressed because it is too large
13
vllm/config/weight_transfer.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Literal
|
||||
|
||||
from vllm.config.utils import config
|
||||
|
||||
|
||||
@config
|
||||
class WeightTransferConfig:
|
||||
"""Configuration for weight transfer during RL training."""
|
||||
|
||||
backend: Literal["nccl"] = "nccl"
|
||||
"""The backend to use for weight transfer."""
|
||||
189
vllm/connections.py
Normal file
@@ -0,0 +1,189 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Mapping, MutableMapping
|
||||
from pathlib import Path
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
from urllib3.util import parse_url
|
||||
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
|
||||
class HTTPConnection:
|
||||
"""Helper class to send HTTP requests."""
|
||||
|
||||
def __init__(self, *, reuse_client: bool = True) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.reuse_client = reuse_client
|
||||
|
||||
self._sync_client: requests.Session | None = None
|
||||
self._async_client: aiohttp.ClientSession | None = None
|
||||
|
||||
def get_sync_client(self) -> requests.Session:
|
||||
if self._sync_client is None or not self.reuse_client:
|
||||
self._sync_client = requests.Session()
|
||||
|
||||
return self._sync_client
|
||||
|
||||
# NOTE: We intentionally use an async function even though it is not
|
||||
# required, so that the client is only accessible inside async event loop
|
||||
async def get_async_client(self) -> aiohttp.ClientSession:
|
||||
if self._async_client is None or not self.reuse_client:
|
||||
self._async_client = aiohttp.ClientSession(trust_env=True)
|
||||
|
||||
return self._async_client
|
||||
|
||||
def _validate_http_url(self, url: str):
|
||||
parsed_url = parse_url(url)
|
||||
|
||||
if parsed_url.scheme not in ("http", "https"):
|
||||
raise ValueError(
|
||||
"Invalid HTTP URL: A valid HTTP URL must have scheme 'http' or 'https'."
|
||||
)
|
||||
|
||||
def _headers(self, **extras: str) -> MutableMapping[str, str]:
|
||||
return {"User-Agent": f"vLLM/{VLLM_VERSION}", **extras}
|
||||
|
||||
def get_response(
|
||||
self,
|
||||
url: str,
|
||||
*,
|
||||
stream: bool = False,
|
||||
timeout: float | None = None,
|
||||
extra_headers: Mapping[str, str] | None = None,
|
||||
allow_redirects: bool = True,
|
||||
):
|
||||
self._validate_http_url(url)
|
||||
|
||||
client = self.get_sync_client()
|
||||
extra_headers = extra_headers or {}
|
||||
|
||||
return client.get(
|
||||
url,
|
||||
headers=self._headers(**extra_headers),
|
||||
stream=stream,
|
||||
timeout=timeout,
|
||||
allow_redirects=allow_redirects,
|
||||
)
|
||||
|
||||
async def get_async_response(
|
||||
self,
|
||||
url: str,
|
||||
*,
|
||||
timeout: float | None = None,
|
||||
extra_headers: Mapping[str, str] | None = None,
|
||||
allow_redirects: bool = True,
|
||||
):
|
||||
self._validate_http_url(url)
|
||||
|
||||
client = await self.get_async_client()
|
||||
extra_headers = extra_headers or {}
|
||||
|
||||
return client.get(
|
||||
url,
|
||||
headers=self._headers(**extra_headers),
|
||||
timeout=timeout,
|
||||
allow_redirects=allow_redirects,
|
||||
)
|
||||
|
||||
def get_bytes(
|
||||
self, url: str, *, timeout: float | None = None, allow_redirects: bool = True
|
||||
) -> bytes:
|
||||
with self.get_response(
|
||||
url, timeout=timeout, allow_redirects=allow_redirects
|
||||
) as r:
|
||||
r.raise_for_status()
|
||||
|
||||
return r.content
|
||||
|
||||
async def async_get_bytes(
|
||||
self,
|
||||
url: str,
|
||||
*,
|
||||
timeout: float | None = None,
|
||||
allow_redirects: bool = True,
|
||||
) -> bytes:
|
||||
async with await self.get_async_response(
|
||||
url, timeout=timeout, allow_redirects=allow_redirects
|
||||
) as r:
|
||||
r.raise_for_status()
|
||||
|
||||
return await r.read()
|
||||
|
||||
def get_text(self, url: str, *, timeout: float | None = None) -> str:
|
||||
with self.get_response(url, timeout=timeout) as r:
|
||||
r.raise_for_status()
|
||||
|
||||
return r.text
|
||||
|
||||
async def async_get_text(
|
||||
self,
|
||||
url: str,
|
||||
*,
|
||||
timeout: float | None = None,
|
||||
) -> str:
|
||||
async with await self.get_async_response(url, timeout=timeout) as r:
|
||||
r.raise_for_status()
|
||||
|
||||
return await r.text()
|
||||
|
||||
def get_json(self, url: str, *, timeout: float | None = None) -> str:
|
||||
with self.get_response(url, timeout=timeout) as r:
|
||||
r.raise_for_status()
|
||||
|
||||
return r.json()
|
||||
|
||||
async def async_get_json(
|
||||
self,
|
||||
url: str,
|
||||
*,
|
||||
timeout: float | None = None,
|
||||
) -> str:
|
||||
async with await self.get_async_response(url, timeout=timeout) as r:
|
||||
r.raise_for_status()
|
||||
|
||||
return await r.json()
|
||||
|
||||
def download_file(
|
||||
self,
|
||||
url: str,
|
||||
save_path: Path,
|
||||
*,
|
||||
timeout: float | None = None,
|
||||
chunk_size: int = 128,
|
||||
) -> Path:
|
||||
with self.get_response(url, timeout=timeout) as r:
|
||||
r.raise_for_status()
|
||||
|
||||
with save_path.open("wb") as f:
|
||||
for chunk in r.iter_content(chunk_size):
|
||||
f.write(chunk)
|
||||
|
||||
return save_path
|
||||
|
||||
async def async_download_file(
|
||||
self,
|
||||
url: str,
|
||||
save_path: Path,
|
||||
*,
|
||||
timeout: float | None = None,
|
||||
chunk_size: int = 128,
|
||||
) -> Path:
|
||||
async with await self.get_async_response(url, timeout=timeout) as r:
|
||||
r.raise_for_status()
|
||||
|
||||
with save_path.open("wb") as f:
|
||||
async for chunk in r.content.iter_chunked(chunk_size):
|
||||
f.write(chunk)
|
||||
|
||||
return save_path
|
||||
|
||||
|
||||
global_http_connection = HTTPConnection()
|
||||
"""
|
||||
The global [`HTTPConnection`][vllm.connections.HTTPConnection] instance used
|
||||
by vLLM.
|
||||
"""
|
||||
0
vllm/device_allocator/__init__.py
Normal file
301
vllm/device_allocator/cumem.py
Normal file
@@ -0,0 +1,301 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# cumem-based pytorch pluggable allocator to implement sleep mode.
|
||||
# other approaches tried but failed:
|
||||
# - cuda-python package binding
|
||||
# - custom libcuda driver ctypes wrapper
|
||||
# both of them failed because of cuda context mismatch.
|
||||
# not sure why, they are created from a different context.
|
||||
# the only successful approach is to call cuda driver API in C.
|
||||
import dataclasses
|
||||
import gc
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.platform_utils import is_pin_memory_available
|
||||
from vllm.utils.system_utils import find_loaded_library
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
cumem_available = False
|
||||
try:
|
||||
from vllm.cumem_allocator import (
|
||||
init_module,
|
||||
python_create_and_map,
|
||||
python_unmap_and_release,
|
||||
)
|
||||
from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
|
||||
|
||||
lib_name = find_loaded_library("cumem_allocator")
|
||||
libcudart = CudaRTLibrary()
|
||||
cumem_available = True
|
||||
except ModuleNotFoundError:
|
||||
# only cuda and rocm platforms support cumem allocator
|
||||
init_module = None
|
||||
python_create_and_map = None
|
||||
python_unmap_and_release = None
|
||||
CudaRTLibrary = None
|
||||
lib_name = None
|
||||
libcudart = None
|
||||
|
||||
# py_device, py_alignedSize, py_d_mem, py_p_memHandle
|
||||
HandleType = tuple[int, int, int, int]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class AllocationData:
|
||||
handle: HandleType
|
||||
tag: str
|
||||
cpu_backup_tensor: torch.Tensor | None = None
|
||||
|
||||
|
||||
def create_and_map(allocation_handle: HandleType) -> None:
|
||||
python_create_and_map(*allocation_handle)
|
||||
|
||||
|
||||
def unmap_and_release(allocation_handle: HandleType) -> None:
|
||||
python_unmap_and_release(*allocation_handle)
|
||||
|
||||
|
||||
def get_pluggable_allocator(
|
||||
python_malloc_fn: Callable[[int], int], python_free_func: Callable[[int, int], None]
|
||||
) -> torch.cuda.memory.CUDAPluggableAllocator:
|
||||
init_module(python_malloc_fn, python_free_func)
|
||||
new_alloc = torch.cuda.memory.CUDAPluggableAllocator(
|
||||
lib_name, "my_malloc", "my_free"
|
||||
)
|
||||
return new_alloc
|
||||
|
||||
|
||||
@contextmanager
|
||||
def use_memory_pool_with_allocator(
|
||||
python_malloc_fn: Callable[[int], int], python_free_func: Callable[[int, int], None]
|
||||
) -> None:
|
||||
new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func)
|
||||
mem_pool = torch.cuda.memory.MemPool(new_alloc._allocator)
|
||||
with torch.cuda.memory.use_mem_pool(mem_pool):
|
||||
yield mem_pool, new_alloc
|
||||
|
||||
|
||||
class CuMemAllocator:
|
||||
"""
|
||||
A singleton class that manages a memory pool for CUDA tensors.
|
||||
The memory in this pool can be offloaded or discarded when the
|
||||
allocator sleeps.
|
||||
|
||||
Inside the `use_memory_pool(tag)` context, all tensors created will
be allocated in the memory pool and will have the same tag as the
one passed to the context.
|
||||
|
||||
When we call `sleep`, all tensors with the specified tag will be
|
||||
offloaded to CPU memory, and the rest of the tensors will be discarded.
|
||||
When we call `wake_up`, all tensors that are previously offloaded
|
||||
will be loaded back to GPU memory, and the rest of the tensors will
|
||||
have empty memory.
|
||||
|
||||
Why it needs to be a singleton?
|
||||
When allocated tensors are garbage collected, PyTorch will call
|
||||
the free callback, which will call the `python_free_callback` method.
|
||||
The C-extension uses a global variable to store the function of an
|
||||
instance of this class. If we create multiple instances of this class,
|
||||
the global variable will be overwritten and the free callback will
|
||||
not work as expected.
|
||||
"""
|
||||
|
||||
instance: "CuMemAllocator" = None
|
||||
default_tag: str = "default"
|
||||
|
||||
@staticmethod
|
||||
def get_instance() -> "CuMemAllocator":
|
||||
"""
|
||||
CuMemAllocator is a singleton class.
|
||||
We cannot call the constructor directly.
|
||||
Call this method to get the instance.
|
||||
"""
|
||||
assert cumem_available, "cumem allocator is not available"
|
||||
if CuMemAllocator.instance is None:
|
||||
CuMemAllocator.instance = CuMemAllocator()
|
||||
return CuMemAllocator.instance
|
||||
|
||||
def __init__(self):
|
||||
conf = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "")
|
||||
assert "expandable_segments:True" not in conf, (
|
||||
"Expandable segments are not compatible with memory pool. "
|
||||
"Please track https://github.com/pytorch/pytorch/issues/147851 "
|
||||
"for the latest updates."
|
||||
)
|
||||
|
||||
self.pointer_to_data: dict[int, AllocationData] = {}
|
||||
self.current_tag: str = CuMemAllocator.default_tag
|
||||
self.allocator_and_pools: dict[str, Any] = {}
|
||||
# Creating strong references to the two callbacks here to prevent
|
||||
# these ephemeral bound-method objects being garbage collected.
|
||||
# See discussions in https://github.com/vllm-project/vllm/pull/22724
|
||||
self.python_malloc_callback = self._python_malloc_callback
|
||||
self.python_free_callback = self._python_free_callback
|
||||
|
||||
def _python_malloc_callback(self, allocation_handle: HandleType) -> None:
|
||||
"""
|
||||
Internal method to store the allocation data
|
||||
when memory is allocated in the memory pool."""
|
||||
py_d_mem = allocation_handle[2]
|
||||
self.pointer_to_data[py_d_mem] = AllocationData(
|
||||
allocation_handle, self.current_tag
|
||||
)
|
||||
logger.debug(
|
||||
"Allocated %s bytes for %s with address %s from cumem allocator",
|
||||
allocation_handle[1],
|
||||
self.current_tag,
|
||||
py_d_mem,
|
||||
)
|
||||
return
|
||||
|
||||
def _python_free_callback(self, ptr: int) -> HandleType:
|
||||
"""
|
||||
Internal method to look up the allocation data
|
||||
when memory is freed in the memory pool."""
|
||||
data = self.pointer_to_data.pop(ptr)
|
||||
if data.cpu_backup_tensor is not None:
|
||||
data.cpu_backup_tensor = None
|
||||
logger.debug(
|
||||
"Freed %s bytes for %s with address %s from cumem allocator",
|
||||
data.handle[1],
|
||||
data.tag,
|
||||
ptr,
|
||||
)
|
||||
return data.handle
|
||||
|
||||
def sleep(self, offload_tags: tuple[str, ...] | str | None = None) -> None:
|
||||
"""
|
||||
Put the allocator in sleep mode.
|
||||
All data in the memory allocation with the specified tag will be
|
||||
offloaded to CPU memory, and others will be discarded.
|
||||
|
||||
:param offload_tags: The tags of the memory allocation that will be
|
||||
offloaded. The rest of the memory allocation will be discarded.
|
||||
"""
|
||||
if offload_tags is None:
|
||||
# by default, allocated tensors are offloaded
|
||||
# when the allocator sleeps
|
||||
offload_tags = (CuMemAllocator.default_tag,)
|
||||
elif isinstance(offload_tags, str):
|
||||
offload_tags = (offload_tags,)
|
||||
|
||||
assert isinstance(offload_tags, tuple)
|
||||
|
||||
total_bytes = 0
|
||||
backup_bytes = 0
|
||||
|
||||
for ptr, data in self.pointer_to_data.items():
|
||||
handle = data.handle
|
||||
total_bytes += handle[1]
|
||||
if data.tag in offload_tags:
|
||||
backup_bytes += handle[1]
|
||||
size_in_bytes = handle[1]
|
||||
cpu_backup_tensor = torch.empty(
|
||||
size_in_bytes,
|
||||
dtype=torch.uint8,
|
||||
device="cpu",
|
||||
pin_memory=is_pin_memory_available(),
|
||||
)
|
||||
cpu_ptr = cpu_backup_tensor.data_ptr()
|
||||
libcudart.cudaMemcpy(cpu_ptr, ptr, size_in_bytes)
|
||||
data.cpu_backup_tensor = cpu_backup_tensor
|
||||
unmap_and_release(handle)
|
||||
|
||||
logger.info(
|
||||
"CuMemAllocator: sleep freed %.2f GiB memory in total, of which "
|
||||
"%.2f GiB is backed up in CPU and the rest %.2f GiB is discarded "
|
||||
"directly.",
|
||||
total_bytes / 1024**3,
|
||||
backup_bytes / 1024**3,
|
||||
(total_bytes - backup_bytes) / 1024**3,
|
||||
)
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def wake_up(self, tags: list[str] | None = None) -> None:
|
||||
"""
|
||||
Wake up the allocator from sleep mode.
|
||||
All data that is previously offloaded will be loaded back to GPU
|
||||
memory, and the rest of the data will have empty memory.
|
||||
|
||||
:param tags: The tags of the memory allocation that will be loaded
|
||||
back to GPU memory. If None, all memory allocation will be loaded
|
||||
back to GPU memory.
|
||||
"""
|
||||
for ptr, data in self.pointer_to_data.items():
|
||||
if tags is None or data.tag in tags:
|
||||
handle = data.handle
|
||||
create_and_map(handle)
|
||||
if data.cpu_backup_tensor is not None:
|
||||
cpu_backup_tensor = data.cpu_backup_tensor
|
||||
if cpu_backup_tensor is not None:
|
||||
size_in_bytes = (
|
||||
cpu_backup_tensor.numel() * cpu_backup_tensor.element_size()
|
||||
)
|
||||
cpu_ptr = cpu_backup_tensor.data_ptr()
|
||||
libcudart.cudaMemcpy(ptr, cpu_ptr, size_in_bytes)
|
||||
data.cpu_backup_tensor = None
|
||||
|
||||
@contextmanager
|
||||
def use_memory_pool(self, tag: str | None = None):
|
||||
"""
|
||||
A context manager to use the memory pool.
|
||||
All memory allocations created inside the context will be made
in the memory pool and will carry the specified tag.
|
||||
|
||||
:param tag: The tag of the memory allocation. If None, the default tag
|
||||
will be used.
|
||||
"""
|
||||
if tag is None:
|
||||
tag = CuMemAllocator.default_tag
|
||||
|
||||
assert isinstance(tag, str)
|
||||
|
||||
old_tag = self.current_tag
|
||||
self.current_tag = tag
|
||||
with use_memory_pool_with_allocator(
|
||||
self.python_malloc_callback, self.python_free_callback
|
||||
) as data:
|
||||
# We start to hit another bug in PyTorch 2.6, possibly because of a
# gc-related issue w.r.t. the allocator and the memory pool.
# To avoid the issue, we keep a reference to the data.
|
||||
# see https://github.com/pytorch/pytorch/issues/146431 .
|
||||
self.allocator_and_pools[tag] = data
|
||||
yield
|
||||
# Due to a PyTorch bug, calling torch.cuda.empty_cache() errors
# when using a pluggable allocator, see
# https://github.com/pytorch/pytorch/issues/145168 .
# If some memory is allocated and then freed, it will not be
# released, e.g. in online quantization, where the model is created
# in higher precision and then quantized to lower precision.
|
||||
# Find all unused allocations and manually release them.
|
||||
# TODO: we should expose `empty_cache` method in the memory pool.
|
||||
# TODO: ask for help from PyTorch team to expose this method.
|
||||
allocations = data[0].snapshot()
|
||||
for allocation in allocations:
|
||||
if allocation["allocated_size"] == 0:
|
||||
handle = self._python_free_callback(allocation["address"])
|
||||
unmap_and_release(handle)
|
||||
self.current_tag = old_tag
|
||||
|
||||
def get_current_usage(self) -> int:
|
||||
"""
|
||||
Get the total number of bytes allocated in the memory pool.
|
||||
"""
|
||||
sum_bytes: int = 0
|
||||
for ptr, data in self.pointer_to_data.items():
|
||||
handle = data.handle
|
||||
sum_bytes += handle[1]
|
||||
return sum_bytes
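# Typical sleep/wake-up flow (illustrative sketch):
#
#     allocator = CuMemAllocator.get_instance()
#     with allocator.use_memory_pool(tag="weights"):
#         weights = torch.empty(1 << 20, device="cuda")  # allocated in the pool
#     allocator.sleep(offload_tags=("weights",))  # offloaded to a CPU backup
#     allocator.wake_up(tags=["weights"])         # copied back to GPU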
|
||||
6
vllm/distributed/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from .communication_op import *
|
||||
from .parallel_state import *
|
||||
from .utils import *
|
||||
43
vllm/distributed/communication_op.py
Normal file
@@ -0,0 +1,43 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from .parallel_state import get_tp_group
|
||||
|
||||
|
||||
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
|
||||
"""All-reduce the input tensor across model parallel group."""
|
||||
return get_tp_group().all_reduce(input_)
|
||||
|
||||
|
||||
def tensor_model_parallel_all_gather(
|
||||
input_: torch.Tensor, dim: int = -1
|
||||
) -> torch.Tensor:
|
||||
"""All-gather the input tensor across model parallel group."""
|
||||
return get_tp_group().all_gather(input_, dim)
|
||||
|
||||
|
||||
def tensor_model_parallel_reduce_scatter(
|
||||
input_: torch.Tensor, dim: int = -1
|
||||
) -> torch.Tensor:
|
||||
"""Reduce-Scatter the input tensor across model parallel group."""
|
||||
return get_tp_group().reduce_scatter(input_, dim)
|
||||
|
||||
|
||||
def tensor_model_parallel_gather(
|
||||
input_: torch.Tensor, dst: int = 0, dim: int = -1
|
||||
) -> torch.Tensor | None:
|
||||
"""Gather the input tensor across model parallel group."""
|
||||
return get_tp_group().gather(input_, dst, dim)
|
||||
|
||||
|
||||
def broadcast_tensor_dict(
|
||||
tensor_dict: dict[Any, torch.Tensor | Any] | None = None, src: int = 0
|
||||
):
|
||||
if not torch.distributed.is_initialized():
|
||||
return tensor_dict
|
||||
return get_tp_group().broadcast_tensor_dict(tensor_dict, src)
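# --- Illustrative sketch (not part of the upstream diff) ---
# How a tensor-parallel layer might use the wrappers above, assuming the TP
# group is already initialized; `x`, `w_shard` and `partial_out` are
# placeholder names.
#
#     partial_out = torch.matmul(x, w_shard)          # each rank's partial sum
#     full_out = tensor_model_parallel_all_reduce(partial_out)
#     # for layers that need the concatenated activations instead:
#     gathered = tensor_model_parallel_all_gather(partial_out, dim=-1)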
|
||||
0
vllm/distributed/device_communicators/__init__.py
Normal file
696
vllm/distributed/device_communicators/all2all.py
Normal file
@@ -0,0 +1,696 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed import get_dp_group, get_ep_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.flashinfer import has_flashinfer_all2all
|
||||
from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx
|
||||
|
||||
from .base_device_communicator import All2AllManagerBase, Cache
|
||||
|
||||
if has_flashinfer_all2all():
|
||||
from flashinfer.comm import Mapping # type: ignore[import-not-found]
|
||||
from flashinfer.comm.mnnvl import MnnvlConfig # type: ignore[import-not-found]
|
||||
from flashinfer.comm.trtllm_alltoall import (
|
||||
MnnvlMoe, # type: ignore[import-not-found]
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class NaiveAll2AllManager(All2AllManagerBase):
|
||||
"""
|
||||
A naive implementation of all2all communication.
|
||||
It uses all-reduce under the hood, which is not
|
||||
efficient at all. The main purpose is for testing and
|
||||
debugging.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
|
||||
def naive_multicast(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cu_tokens_across_sp_cpu: torch.Tensor,
|
||||
is_sequence_parallel: bool,
|
||||
) -> torch.Tensor:
|
||||
assert len(x.shape) == 2
|
||||
buffer = torch.empty(
|
||||
(cu_tokens_across_sp_cpu[-1], x.size(1)), device=x.device, dtype=x.dtype
|
||||
)
|
||||
|
||||
rank = self.rank if is_sequence_parallel else self.dp_rank
|
||||
world_size = self.world_size if is_sequence_parallel else self.dp_world_size
|
||||
|
||||
start = 0 if rank == 0 else cu_tokens_across_sp_cpu[rank - 1]
|
||||
end = cu_tokens_across_sp_cpu[rank]
|
||||
buffer[start:end, :].copy_(x)
|
||||
for idx in range(world_size):
|
||||
start = 0 if idx == 0 else cu_tokens_across_sp_cpu[idx - 1]
|
||||
end = cu_tokens_across_sp_cpu[idx]
|
||||
get_ep_group().broadcast(buffer[start:end, :], idx)
|
||||
|
||||
return buffer
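# --- Illustrative worked example (comment only, not part of the upstream diff) ---
# With three ranks holding 2, 5 and 1 tokens, cu_tokens_across_sp_cpu is
# [2, 7, 8]. Rank 1 copies its 5 rows into buffer[2:7]; the loop then
# broadcasts buffer[0:2] from rank 0, buffer[2:7] from rank 1 and
# buffer[7:8] from rank 2, so every rank ends up with the full 8-row buffer.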
|
||||
|
||||
def dispatch_router_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if extra_tensors is not None:
|
||||
raise NotImplementedError(
|
||||
"extra_tensors is not supported for NaiveAll2AllManager"
|
||||
)
|
||||
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
|
||||
dp_metadata = get_forward_context().dp_metadata
|
||||
assert dp_metadata is not None
|
||||
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
|
||||
|
||||
hidden_states = self.naive_multicast(
|
||||
hidden_states, cu_tokens_across_sp_cpu, is_sequence_parallel
|
||||
)
|
||||
router_logits = self.naive_multicast(
|
||||
router_logits, cu_tokens_across_sp_cpu, is_sequence_parallel
|
||||
)
|
||||
|
||||
return hidden_states, router_logits
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
if extra_tensors is not None:
|
||||
raise NotImplementedError(
|
||||
"extra_tensors is not supported for NaiveAll2AllManager"
|
||||
)
|
||||
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
|
||||
dp_metadata = get_forward_context().dp_metadata
|
||||
assert dp_metadata is not None
|
||||
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
|
||||
|
||||
hidden_states = self.naive_multicast(
|
||||
hidden_states, cu_tokens_across_sp_cpu, is_sequence_parallel
|
||||
)
|
||||
topk_weights = self.naive_multicast(
|
||||
topk_weights, cu_tokens_across_sp_cpu, is_sequence_parallel
|
||||
)
|
||||
topk_ids = self.naive_multicast(
|
||||
topk_ids, cu_tokens_across_sp_cpu, is_sequence_parallel
|
||||
)
|
||||
return hidden_states, topk_weights, topk_ids
|
||||
|
||||
def combine(
|
||||
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
||||
) -> torch.Tensor:
|
||||
ep_rank = self.rank if is_sequence_parallel else self.dp_rank
|
||||
|
||||
dp_metadata = get_forward_context().dp_metadata
|
||||
assert dp_metadata is not None
|
||||
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
|
||||
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
|
||||
|
||||
start = 0 if ep_rank == 0 else cu_tokens_across_sp_cpu[ep_rank - 1]
|
||||
end = cu_tokens_across_sp_cpu[ep_rank]
|
||||
|
||||
all_hidden_states = get_ep_group().all_reduce(hidden_states)
|
||||
hidden_states = all_hidden_states[start:end, :]
|
||||
return hidden_states
|
||||
|
||||
def destroy(self):
|
||||
pass
|
||||
|
||||
|
||||
class AgRsAll2AllManager(All2AllManagerBase):
|
||||
"""
|
||||
An implementation of all2all communication based on
|
||||
all-gather (dispatch) and reduce-scatter (combine).
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
|
||||
def dispatch_router_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
"""
|
||||
Gather hidden_states and router_logits from all dp ranks.
|
||||
"""
|
||||
dp_metadata = get_forward_context().dp_metadata
|
||||
assert dp_metadata is not None
|
||||
sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
assert sizes is not None
|
||||
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
|
||||
assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
|
||||
|
||||
tensors_to_gather = [hidden_states, router_logits]
|
||||
if extra_tensors is not None:
|
||||
tensors_to_gather.extend(extra_tensors)
|
||||
|
||||
gathered_tensors = dist_group.all_gatherv(
|
||||
tensors_to_gather,
|
||||
dim=0,
|
||||
sizes=sizes,
|
||||
)
|
||||
|
||||
if extra_tensors is not None:
|
||||
return (gathered_tensors[0], gathered_tensors[1], gathered_tensors[2:])
|
||||
return gathered_tensors[0], gathered_tensors[1]
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
"""
|
||||
Gather hidden_states and router_logits from all dp ranks.
|
||||
"""
|
||||
dp_metadata = get_forward_context().dp_metadata
|
||||
assert dp_metadata is not None
|
||||
sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
assert sizes is not None
|
||||
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
|
||||
assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
|
||||
|
||||
tensors_to_gather = [hidden_states, topk_weights, topk_ids]
|
||||
if extra_tensors is not None:
|
||||
tensors_to_gather.extend(extra_tensors)
|
||||
|
||||
gathered_tensors = dist_group.all_gatherv(
|
||||
tensors_to_gather,
|
||||
dim=0,
|
||||
sizes=sizes,
|
||||
)
|
||||
|
||||
hidden_states = gathered_tensors[0]
|
||||
topk_weights = gathered_tensors[1]
|
||||
topk_ids = gathered_tensors[2]
|
||||
|
||||
if extra_tensors is None:
|
||||
return hidden_states, topk_weights, topk_ids
|
||||
|
||||
return hidden_states, topk_weights, topk_ids, gathered_tensors[3:]
|
||||
|
||||
def combine(
|
||||
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Reduce-scatter hidden_states across all dp ranks.
|
||||
"""
|
||||
dp_metadata = get_forward_context().dp_metadata
|
||||
assert dp_metadata is not None
|
||||
sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
assert sizes is not None
|
||||
|
||||
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
|
||||
hidden_states = dist_group.reduce_scatterv(hidden_states, dim=0, sizes=sizes)
|
||||
return hidden_states
|
||||
|
||||
def destroy(self):
|
||||
pass
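# --- Illustrative note (not part of the upstream diff) ---
# Typical round trip with this manager, assuming `mgr` is an initialized
# AgRsAll2AllManager and the forward context carries the per-rank token
# counts; the variable names are placeholders.
#
#     h, w, ids = mgr.dispatch(hidden_states, topk_weights, topk_ids)
#     expert_output = ...   # run the local experts on the gathered tokens
#     out = mgr.combine(expert_output)
#
# dispatch() all-gathers variable-length chunks (sizes come from dp_metadata)
# and combine() reduce-scatters the expert outputs back, so each rank keeps
# only its own tokens' rows, summed across ranks.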
|
||||
|
||||
|
||||
class PPLXAll2AllManager(All2AllManagerBase):
|
||||
"""
|
||||
All2All communication based on PPLX kernels.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
assert has_pplx(), (
|
||||
"pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md"
|
||||
" to install pplx_kernels."
|
||||
)
|
||||
super().__init__(cpu_group)
|
||||
|
||||
if self.internode:
|
||||
# inter-node communication needs nvshmem,
|
||||
# intra-node communication uses p2p mapping directly
|
||||
from pplx_kernels.nvshmem import ( # type: ignore[import-not-found]
|
||||
nvshmem_alloc_empty_unique_id,
|
||||
nvshmem_get_unique_id,
|
||||
nvshmem_init,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Initialize NVSHMEM for pplx_kernels: rank=%d, world size=%d",
|
||||
self.rank,
|
||||
self.world_size,
|
||||
)
|
||||
uid = (
|
||||
nvshmem_get_unique_id()
|
||||
if self.rank == 0
|
||||
else nvshmem_alloc_empty_unique_id()
|
||||
)
|
||||
dist.broadcast(
|
||||
uid,
|
||||
src=dist.get_process_group_ranks(self.cpu_group)[0],
|
||||
group=self.cpu_group,
|
||||
)
|
||||
logger.debug("PPLX NVSHMEM UID = %s", uid)
|
||||
nvshmem_init(uid, self.rank, self.world_size)
|
||||
|
||||
self.handle_cache = Cache()
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
import pplx_kernels as pplx # type: ignore[import-not-found]
|
||||
|
||||
return self.handle_cache.get_or_create(
|
||||
kwargs,
|
||||
pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode,
|
||||
)
|
||||
|
||||
def dispatch_router_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
def combine(
|
||||
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def destroy(self):
|
||||
with self.handle_cache._lock:
|
||||
for _, handle in self.handle_cache._cache.items():
|
||||
handle.destroy()
|
||||
|
||||
if self.internode:
|
||||
from pplx_kernels.nvshmem import (
|
||||
nvshmem_finalize, # type: ignore[import-not-found]
|
||||
)
|
||||
|
||||
logger.debug("PPLX NVSHMEM finalize")
|
||||
nvshmem_finalize()
|
||||
|
||||
|
||||
class DeepEPAll2AllManagerBase(All2AllManagerBase):
|
||||
"""
|
||||
All2All communication based on DeepEP High-Throughput kernels.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
assert has_deep_ep(), (
|
||||
"DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md"
|
||||
" to install DeepEP kernels."
|
||||
) # noqa
|
||||
super().__init__(cpu_group)
|
||||
self.handle_cache = Cache()
|
||||
|
||||
# This is the DeepEP default. Stick to it till we can establish
|
||||
# reasonable defaults based on profiling.
|
||||
self.num_sms = 20
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def dispatch_router_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
def combine(
|
||||
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def destroy(self):
|
||||
pass
|
||||
|
||||
|
||||
class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
"""
|
||||
All2All communication based on DeepEP High-Throughput kernels.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
|
||||
def _make_all2all_kwargs(self) -> dict[Any, Any]:
|
||||
# Defaults for internode and intranode are taken from DeepEP tests.
|
||||
num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
|
||||
num_rdma_bytes = None
|
||||
num_qps_per_rank = None
|
||||
|
||||
if self.internode and not envs.VLLM_DEEPEP_HIGH_THROUGHPUT_FORCE_INTRA_NODE:
|
||||
num_rdma_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
|
||||
num_qps_per_rank = self.num_sms // 2
|
||||
else:
|
||||
num_rdma_bytes = 0
|
||||
num_qps_per_rank = 1
|
||||
|
||||
assert num_rdma_bytes is not None
|
||||
assert num_qps_per_rank is not None
|
||||
return dict(
|
||||
group=self.cpu_group,
|
||||
num_nvl_bytes=num_nvl_bytes,
|
||||
num_rdma_bytes=num_rdma_bytes,
|
||||
low_latency_mode=False,
|
||||
num_qps_per_rank=num_qps_per_rank,
|
||||
)
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
assert len(kwargs) == 0, (
|
||||
"DeepEPHTAll2AllManager expects no arguments. All the required "
|
||||
"args are computed in the Manager itself."
|
||||
)
|
||||
|
||||
import deep_ep # type: ignore[import-not-found]
|
||||
|
||||
buffer_kwargs = self._make_all2all_kwargs()
|
||||
logger.debug("DeepEP all2all args %s", buffer_kwargs)
|
||||
handle: deep_ep.Buffer = self.handle_cache.get_or_create(
|
||||
buffer_kwargs, deep_ep.Buffer
|
||||
)
|
||||
return handle
|
||||
|
||||
def set_num_sms(self, num_sms: int):
|
||||
import deep_ep # type: ignore[import-not-found]
|
||||
|
||||
# Right now the buffers are sized for only what the kernels were
|
||||
# created with. So we can only reduce the number of SMs used
|
||||
# but not increase it.
|
||||
if num_sms > self.num_sms:
|
||||
num_sms = self.num_sms
|
||||
deep_ep.Buffer.set_num_sms(num_sms)
|
||||
|
||||
|
||||
class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
"""
|
||||
All2All communication based on DeepEP Low-Latency kernels.
|
||||
"""
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
|
||||
def _make_all2all_kwargs(
|
||||
self,
|
||||
max_num_tokens_per_dp_rank: int,
|
||||
token_hidden_size: int,
|
||||
num_ep_ranks: int,
|
||||
num_global_experts: int,
|
||||
num_local_experts: int,
|
||||
) -> dict[Any, Any]:
|
||||
"""
|
||||
max_num_tokens_per_dp_rank : the maximum number of tokens a DP rank
|
||||
can dispatch; all the ranks must hold the same value.
|
||||
token_hidden_size: the hidden dimension of each token.
|
||||
num_ep_ranks: the number of EP group ranks.
|
||||
num_global_experts: Number of experts in the model.
|
||||
num_local_experts: Number of experts in an EP rank.
|
||||
"""
|
||||
import deep_ep # type: ignore[import-not-found]
|
||||
|
||||
# Defaults for internode and intranode are taken from DeepEP tests.
|
||||
num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
|
||||
num_qps_per_rank = num_local_experts
|
||||
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
|
||||
num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank,
|
||||
hidden=token_hidden_size,
|
||||
num_ranks=num_ep_ranks,
|
||||
num_experts=num_global_experts,
|
||||
)
|
||||
|
||||
assert num_rdma_bytes is not None
|
||||
return dict(
|
||||
group=self.cpu_group,
|
||||
num_nvl_bytes=num_nvl_bytes,
|
||||
num_rdma_bytes=num_rdma_bytes,
|
||||
low_latency_mode=True,
|
||||
num_qps_per_rank=num_qps_per_rank,
|
||||
allow_nvlink_for_low_latency_mode=True,
|
||||
allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL,
|
||||
)
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
"""
|
||||
The kwargs for DeepEPLLAll2AllManager is dictated by
|
||||
_make_all2all_kwargs.
|
||||
"""
|
||||
import deep_ep # type: ignore[import-not-found]
|
||||
|
||||
buffer_kwargs = self._make_all2all_kwargs(**kwargs)
|
||||
logger.debug("DeepEP all2all args %s", buffer_kwargs)
|
||||
handle: deep_ep.Buffer = self.handle_cache.get_or_create(
|
||||
buffer_kwargs, deep_ep.Buffer
|
||||
)
|
||||
return handle
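# --- Illustrative note (not part of the upstream diff) ---
# get_handle() expects exactly the keyword arguments of _make_all2all_kwargs;
# the values below are placeholders:
#
#     handle = manager.get_handle(dict(
#         max_num_tokens_per_dp_rank=256,
#         token_hidden_size=4096,
#         num_ep_ranks=8,
#         num_global_experts=64,
#         num_local_experts=8,
#     ))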
|
||||
|
||||
# DeepEP LL uses RDMA so no SMs are used for communication
|
||||
def max_sms_used(self) -> int | None:
|
||||
return 0
|
||||
|
||||
|
||||
class FlashInferAllToAllManager(All2AllManagerBase):
|
||||
"""
|
||||
All2All communication based on flashinfer kernels.
|
||||
"""
|
||||
|
||||
# This type lint could be removed after all of the work in
|
||||
# https://github.com/vllm-project/vllm/issues/26533 done.
|
||||
rank: int
|
||||
world_size: int
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
assert has_flashinfer_all2all(), (
|
||||
"flashinfer all2all module not found. Please install/check flashinfer"
|
||||
) # noqa
|
||||
super().__init__(cpu_group)
|
||||
logger.debug(
|
||||
"Initialize for flashinfer All2All rank=%d, world size=%d",
|
||||
self.rank,
|
||||
self.world_size,
|
||||
)
|
||||
self.initialized = False
|
||||
self.alltoall_info = None
|
||||
|
||||
def initialize(
|
||||
self,
|
||||
world_size: int,
|
||||
rank: int,
|
||||
gpus_per_node: int,
|
||||
):
|
||||
"""Initialize workspace"""
|
||||
if self.initialized:
|
||||
return
|
||||
|
||||
self.cleanup()
|
||||
logger.debug("making map: rank=%d, world size=%d", rank, world_size)
|
||||
self.mapping = Mapping(
|
||||
world_size,
|
||||
rank,
|
||||
gpus_per_node,
|
||||
tp_size=world_size,
|
||||
)
|
||||
|
||||
from vllm.distributed.device_communicators.mnnvl_compat import (
|
||||
CustomCommunicator,
|
||||
)
|
||||
|
||||
dp_config = MnnvlConfig(
|
||||
comm_backend=CustomCommunicator(get_dp_group().cpu_group),
|
||||
fabric_page_size=1 << 29, # 512MB
|
||||
allocation_granularity=0, # Auto-detect
|
||||
)
|
||||
|
||||
self.workspace_tensor = MnnvlMoe.get_moe_workspaces(self.mapping, dp_config)
|
||||
self.prepare_workspace_tensor = MnnvlMoe.get_moe_prepare_workspace(
|
||||
self.mapping, dp_config
|
||||
)
|
||||
|
||||
self.world_size = world_size
|
||||
self.rank = rank
|
||||
self.gpus_per_node = gpus_per_node
|
||||
self.initialized = True
|
||||
|
||||
logger.info(
|
||||
"FlashInfer All2All initialized for rank %s, size %s", rank, world_size
|
||||
)
|
||||
|
||||
def ensure_alltoall_workspace_initialized(self):
|
||||
"""Ensure workspace is initialized"""
|
||||
if not has_flashinfer_all2all():
|
||||
return False
|
||||
|
||||
if self.world_size <= 1:
|
||||
return False
|
||||
|
||||
if not self.initialized:
|
||||
self.initialize(
|
||||
world_size=self.world_size,
|
||||
rank=self.rank,
|
||||
gpus_per_node=torch.cuda.device_count(),
|
||||
)
|
||||
return self.initialized
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
return self
|
||||
|
||||
def cleanup(self):
|
||||
"""Clean up workspace"""
|
||||
if (
|
||||
self.initialized
|
||||
and self.workspace_tensor is not None
|
||||
and self.prepare_workspace_tensor is not None
|
||||
):
|
||||
try:
|
||||
del self.workspace_tensor
|
||||
del self.prepare_workspace_tensor
|
||||
except Exception as e:
|
||||
logger.warning("Failed to cleanup FlashInfer workspace: %s", e)
|
||||
finally:
|
||||
self.workspace_tensor = None
|
||||
self.prepare_workspace_tensor = None
|
||||
self.mapping = None
|
||||
self.initialized = False
|
||||
|
||||
|
||||
class MoriAll2AllManager(All2AllManagerBase):
|
||||
def __init__(self, cpu_group):
|
||||
assert has_mori(), (
|
||||
"MoRI kernels not found. Please follow https://github.com/ROCm/mori/blob/main/README.md"
|
||||
" to install MoRI kernels."
|
||||
) # noqa
|
||||
import mori
|
||||
|
||||
super().__init__(cpu_group)
|
||||
self.handle_cache = Cache()
|
||||
|
||||
torch._C._distributed_c10d._register_process_group("mori", cpu_group)
|
||||
mori.shmem.shmem_torch_process_group_init("mori")
|
||||
|
||||
def _make_all2all_kwargs(
|
||||
self,
|
||||
rank: int,
|
||||
num_ep_ranks: int,
|
||||
input_dtype: torch.dtype,
|
||||
quant_dtype: torch.dtype,
|
||||
token_hidden_size: int,
|
||||
scale_dim: int,
|
||||
scale_type_size: int,
|
||||
max_num_tokens_per_dp_rank: int,
|
||||
num_local_experts: int,
|
||||
num_experts_per_token: int,
|
||||
):
|
||||
import mori # type: ignore[import-not-found]
|
||||
|
||||
from vllm.platforms.rocm import on_gfx942, on_gfx950
|
||||
|
||||
assert on_gfx942() or on_gfx950(), (
|
||||
"mori currently only support arch gfx942 and gfx950"
|
||||
)
|
||||
|
||||
if not self.internode:
|
||||
# single node
|
||||
kernel_type = mori.ops.EpDispatchCombineKernelType.IntraNode
|
||||
rdma_block_num = 0
|
||||
warp_num_per_block = 16
|
||||
block_num = 80
|
||||
else:
|
||||
# multi node
|
||||
kernel_type = mori.ops.EpDispatchCombineKernelType.InterNodeV1
|
||||
if on_gfx942():
|
||||
warp_num_per_block = 16
|
||||
block_num = 32
|
||||
rdma_block_num = 16
|
||||
elif on_gfx950():
|
||||
warp_num_per_block = 8
|
||||
block_num = 64
|
||||
rdma_block_num = 32
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"mori currently only support arch gfx942 and gfx950"
|
||||
)
|
||||
|
||||
return dict(
|
||||
rank=rank,
|
||||
world_size=num_ep_ranks,
|
||||
data_type=quant_dtype,
|
||||
hidden_dim=token_hidden_size,
|
||||
scale_dim=scale_dim,
|
||||
scale_type_size=scale_type_size,
|
||||
max_token_type_size=input_dtype.itemsize,
|
||||
max_num_inp_token_per_rank=max_num_tokens_per_dp_rank,
|
||||
num_experts_per_rank=num_local_experts,
|
||||
num_experts_per_token=num_experts_per_token,
|
||||
warp_num_per_block=warp_num_per_block,
|
||||
block_num=block_num,
|
||||
kernel_type=kernel_type,
|
||||
rdma_block_num=rdma_block_num,
|
||||
gpu_per_node=min(8, num_ep_ranks),
|
||||
)
|
||||
|
||||
def _make_handle(self, **kwargs):
|
||||
import mori # type: ignore[import-not-found]
|
||||
|
||||
mori_config = mori.ops.EpDispatchCombineConfig(**kwargs)
|
||||
handle = mori.ops.EpDispatchCombineOp(mori_config)
|
||||
return handle
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
import mori # type: ignore[import-not-found]
|
||||
|
||||
mori_kwargs = self._make_all2all_kwargs(**kwargs)
|
||||
logger.debug("MoRI all2all args %s", mori_kwargs)
|
||||
handle: mori.ops.EpDispatchCombineOp = self.handle_cache.get_or_create(
|
||||
mori_kwargs, self._make_handle
|
||||
)
|
||||
return handle
|
||||
344
vllm/distributed/device_communicators/all_reduce_utils.py
Normal file
@@ -0,0 +1,344 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import ctypes
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
from collections.abc import Sequence
|
||||
from itertools import product
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.utils.system_utils import update_environment_variables
|
||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
MiB = 1024 * 1024
|
||||
# Max size for each world size in case symmetric memory is available
|
||||
# For different SM architectures
|
||||
CUSTOM_ALL_REDUCE_MAX_SIZES = {
|
||||
"9.0": {
|
||||
2: 64 * MiB, # 64 MB
|
||||
4: 32 * MiB, # 32 MB
|
||||
6: MiB // 2, # 512 KB
|
||||
8: MiB // 4, # 256 KB
|
||||
},
|
||||
"10.0": {
|
||||
2: 2 * MiB, # 2 MB
|
||||
4: 2 * MiB, # 2 MB
|
||||
6: 1 * MiB, # 1 MB
|
||||
8: 1 * MiB, # 1 MB
|
||||
},
|
||||
}
|
||||
|
||||
SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
|
||||
"9.0": {
|
||||
2: 64 * MiB, # 64 MB
|
||||
4: 32 * MiB, # 32 MB
|
||||
6: 64 * MiB, # 64 MB
|
||||
8: 64 * MiB, # 64 MB
|
||||
},
|
||||
"10.0": {
|
||||
2: 8 * MiB, # 8 MB
|
||||
4: 32 * MiB, # 32 MB
|
||||
6: 128 * MiB, # 128 MB
|
||||
8: 128 * MiB, # 128 MB
|
||||
},
|
||||
}
|
||||
|
||||
NCCL_SYMM_MEM_ALL_REDUCE_CONFIG: dict[str, Any] = {
|
||||
"min_world_size": 4,
|
||||
"thresholds": {
|
||||
4: 2 * MiB, # 2 MB
|
||||
8: 1 * MiB, # 1 MB
|
||||
},
|
||||
"always_use_above_world_size": 8, # Always use symm mem for world_size > 8
|
||||
}
|
||||
|
||||
|
||||
def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor) -> bool:
|
||||
from vllm.distributed.device_communicators.pynccl_allocator import (
|
||||
is_symmetric_memory_enabled,
|
||||
)
|
||||
|
||||
if vllm_is_batch_invariant():
|
||||
return False
|
||||
|
||||
if not is_symmetric_memory_enabled():
|
||||
return False
|
||||
if world_size < NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["min_world_size"]:
|
||||
return False
|
||||
threshold = NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["thresholds"].get(world_size)
|
||||
if threshold is not None and input_tensor.nbytes >= threshold:
|
||||
return True
|
||||
return world_size > NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["always_use_above_world_size"]
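# --- Illustrative worked example (comment only, not part of the upstream diff) ---
# Assuming symmetric memory is enabled and batch-invariant mode is off:
#   - world_size=4, 4 MiB tensor  -> True  (>= the 2 MiB threshold for 4 ranks)
#   - world_size=4, 1 MiB tensor  -> False (below threshold, and 4 <= 8)
#   - world_size=16, any tensor   -> True  (16 > always_use_above_world_size)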
|
||||
|
||||
|
||||
def producer(
|
||||
batch_src: Sequence[int],
|
||||
producer_queue,
|
||||
consumer_queue,
|
||||
result_queue,
|
||||
cuda_visible_devices: str | None = None,
|
||||
):
|
||||
if cuda_visible_devices is not None:
|
||||
update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
|
||||
|
||||
lib = CudaRTLibrary()
|
||||
for i in batch_src:
|
||||
lib.cudaSetDevice(i)
|
||||
pointer = lib.cudaMalloc(1024)
|
||||
lib.cudaMemset(pointer, 1, 1024)
|
||||
lib.cudaDeviceSynchronize()
|
||||
handle = lib.cudaIpcGetMemHandle(pointer)
|
||||
producer_queue.put(handle)
|
||||
open_success = consumer_queue.get()
|
||||
if open_success:
|
||||
# use two queues to simulate barrier
|
||||
producer_queue.put(0)
|
||||
consumer_queue.get()
|
||||
# check if the memory is modified
|
||||
host_data = (ctypes.c_char * 1024)()
|
||||
lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore
|
||||
for i in range(1024):
|
||||
if ord(host_data[i]) != 2:
|
||||
open_success = False
|
||||
break
|
||||
result_queue.put(open_success)
|
||||
lib.cudaDeviceReset()
|
||||
|
||||
|
||||
def consumer(
|
||||
batch_tgt: Sequence[int],
|
||||
producer_queue,
|
||||
consumer_queue,
|
||||
result_queue,
|
||||
cuda_visible_devices: str | None = None,
|
||||
):
|
||||
if cuda_visible_devices is not None:
|
||||
update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
|
||||
|
||||
lib = CudaRTLibrary()
|
||||
for j in batch_tgt:
|
||||
lib.cudaSetDevice(j)
|
||||
handle = producer_queue.get()
|
||||
open_success = False
|
||||
try:
|
||||
pointer = lib.cudaIpcOpenMemHandle(handle) # type: ignore
|
||||
open_success = True
|
||||
except RuntimeError:
|
||||
# cannot error out here, because the producer process
|
||||
# is still waiting for the response.
|
||||
pass
|
||||
consumer_queue.put(open_success)
|
||||
if open_success:
|
||||
# modify the memory
|
||||
lib.cudaMemset(pointer, 2, 1024)
|
||||
lib.cudaDeviceSynchronize()
|
||||
# use two queues to simulate barrier
|
||||
producer_queue.get()
|
||||
consumer_queue.put(0)
|
||||
# check if the memory is modified
|
||||
host_data = (ctypes.c_char * 1024)()
|
||||
lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore
|
||||
for i in range(1024):
|
||||
if ord(host_data[i]) != 2:
|
||||
open_success = False
|
||||
break
|
||||
result_queue.put(open_success)
|
||||
lib.cudaDeviceReset()
|
||||
|
||||
|
||||
def can_actually_p2p(
|
||||
batch_src: Sequence[int],
|
||||
batch_tgt: Sequence[int],
|
||||
) -> Sequence[bool]:
|
||||
"""
|
||||
Usually, checking if P2P access is enabled can be done by
|
||||
`torch.cuda.can_device_access_peer(src, tgt)`. However, sometimes
|
||||
the driver might be broken, and `torch.cuda.can_device_access_peer(src, tgt)`
|
||||
returns `True` even if P2P access is not actually possible.
|
||||
See https://github.com/vllm-project/vllm/issues/2728 and
|
||||
https://forums.developer.nvidia.com/t/direct-gpu-gpu-communication-does-not-seem-to-work-properly/283264/10
|
||||
Therefore, we have to perform a real P2P access to check if it is actually
|
||||
possible.
|
||||
|
||||
Note on p2p and cuda IPC:
|
||||
Usually, one process uses one GPU:
|
||||
GPU src --> cuda context src --> tensor src --> process src
|
||||
|
||||
We need to combine p2p and cuda IPC, so that:
|
||||
GPU src --> cuda context src --> tensor src --> process src
|
||||
|shared|
|
||||
GPU tgt --> cuda context tgt --> tensor tgt --> process tgt
|
||||
That is to say, process src creates a tensor in GPU src, passes IPC handle to
|
||||
process tgt, and process tgt accesses the tensor in GPU tgt. Any operation on the
|
||||
tensor in process tgt will be reflected in the tensor in process src, because
|
||||
they are the same memory segment.
|
||||
It is important to note that process tgt accesses the tensor in GPU tgt, not
|
||||
GPU src. That's why we need p2p access.
|
||||
|
||||
The most time-consuming part is the process creation. To avoid creating
|
||||
processes for every pair of GPUs, we use batched testing. We create two
|
||||
processes for testing all pairs of GPUs in batch. The trick is to reset
|
||||
the device after each test (which is not available in PyTorch).
|
||||
""" # noqa
|
||||
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
|
||||
# pass the CUDA_VISIBLE_DEVICES to the child process
|
||||
# to make sure they see the same set of GPUs
|
||||
|
||||
# make sure the processes are spawned
|
||||
smp = mp.get_context("spawn")
|
||||
producer_queue = smp.Queue()
|
||||
consumer_queue = smp.Queue()
|
||||
result_queue = smp.Queue()
|
||||
p_src = smp.Process(
|
||||
target=producer,
|
||||
args=(
|
||||
batch_src,
|
||||
producer_queue,
|
||||
consumer_queue,
|
||||
result_queue,
|
||||
cuda_visible_devices,
|
||||
),
|
||||
)
|
||||
p_tgt = smp.Process(
|
||||
target=consumer,
|
||||
args=(
|
||||
batch_tgt,
|
||||
producer_queue,
|
||||
consumer_queue,
|
||||
result_queue,
|
||||
cuda_visible_devices,
|
||||
),
|
||||
)
|
||||
p_src.start()
|
||||
p_tgt.start()
|
||||
p_src.join()
|
||||
p_tgt.join()
|
||||
assert p_src.exitcode == 0 and p_tgt.exitcode == 0
|
||||
result: list[bool] = []
|
||||
for src, tgt in zip(batch_src, batch_tgt):
|
||||
a = result_queue.get()
|
||||
b = result_queue.get()
|
||||
if a != b:
|
||||
logger.warning(
|
||||
"Two processes do not agree on the P2P access"
|
||||
" status on %d -> %d, treat as disabled.",
|
||||
src,
|
||||
tgt,
|
||||
)
|
||||
result.append(False)
|
||||
else:
|
||||
result.append(a)
|
||||
return result
|
||||
|
||||
|
||||
# why do we need this cache?
|
||||
# we are testing peer-to-peer (p2p) access between GPUs, across processes.
|
||||
# if we test it every time, it will be very slow, because we need to create
|
||||
# N * N * 2 processes, where N is the world size. This is very slow.
|
||||
# to reduce the time, we use a cache file to store the p2p access status.
|
||||
# the cache file is generated by the master process if it does not exist.
|
||||
# then all the processes can read the cache file to check the p2p access status.
|
||||
# Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we
|
||||
# can have different cache files for different CUDA_VISIBLE_DEVICES settings,
|
||||
# e.g. used by different vllm engines. The device id in the cache file is a
|
||||
# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number
|
||||
# of visible devices in the vllm engine.
|
||||
_gpu_p2p_access_cache: dict[str, bool] | None = None
|
||||
|
||||
|
||||
def gpu_p2p_access_check(src: int, tgt: int) -> bool:
|
||||
"""Check if GPU src can access GPU tgt."""
|
||||
|
||||
# if the cache variable is already calculated,
|
||||
# read from the cache instead of checking it again
|
||||
global _gpu_p2p_access_cache
|
||||
if _gpu_p2p_access_cache is not None:
|
||||
return _gpu_p2p_access_cache[f"{src}->{tgt}"]
|
||||
|
||||
is_distributed = dist.is_initialized()
|
||||
|
||||
num_dev = cuda_device_count_stateless()
|
||||
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
|
||||
if cuda_visible_devices is None:
|
||||
cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
|
||||
|
||||
path = os.path.join(
|
||||
envs.VLLM_CACHE_ROOT, f"gpu_p2p_access_cache_for_{cuda_visible_devices}.json"
|
||||
)
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
from vllm.distributed.parallel_state import get_world_group
|
||||
|
||||
if (not is_distributed or get_world_group().local_rank == 0) and (
|
||||
not os.path.exists(path)
|
||||
):
|
||||
# only the local master process (with local_rank == 0) can
|
||||
# enter this block to calculate the cache
|
||||
logger.info("generating GPU P2P access cache in %s", path)
|
||||
cache: dict[str, bool] = {}
|
||||
ids = list(range(num_dev))
|
||||
# batch of all pairs of GPUs
|
||||
batch_src, batch_tgt = zip(*list(product(ids, ids)))
|
||||
# NOTE: we use `subprocess` rather than `multiprocessing` here
|
||||
# because the caller might not have `if __name__ == "__main__":`,
|
||||
# in that case we cannot use spawn method in multiprocessing.
|
||||
# However, `can_actually_p2p` requires spawn method.
|
||||
# The fix is, we use `subprocess` to call the function,
|
||||
# where we have `if __name__ == "__main__":` in this file.
|
||||
|
||||
# use a temporary file to store the result
|
||||
# we don't use the output of the subprocess directly,
|
||||
# because the subprocess might produce logging output
|
||||
with tempfile.NamedTemporaryFile() as output_file:
|
||||
input_bytes = pickle.dumps((batch_src, batch_tgt, output_file.name))
|
||||
returned = subprocess.run(
|
||||
[sys.executable, __file__], input=input_bytes, capture_output=True
|
||||
)
|
||||
# check if the subprocess is successful
|
||||
try:
|
||||
returned.check_returncode()
|
||||
except Exception as e:
|
||||
# wrap raised exception to provide more information
|
||||
raise RuntimeError(
|
||||
f"Error happened when batch testing "
|
||||
f"peer-to-peer access from {batch_src} to {batch_tgt}:\n"
|
||||
f"{returned.stderr.decode()}"
|
||||
) from e
|
||||
with open(output_file.name, "rb") as f:
|
||||
result = pickle.load(f)
|
||||
for _i, _j, r in zip(batch_src, batch_tgt, result):
|
||||
cache[f"{_i}->{_j}"] = r
|
||||
with open(path, "w") as f:
|
||||
json.dump(cache, f, indent=4)
|
||||
if is_distributed:
|
||||
get_world_group().barrier()
|
||||
logger.info("reading GPU P2P access cache from %s", path)
|
||||
with open(path) as f:
|
||||
cache = json.load(f)
|
||||
_gpu_p2p_access_cache = cache
|
||||
return _gpu_p2p_access_cache[f"{src}->{tgt}"]
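# --- Illustrative note (not part of the upstream diff) ---
# Typical call site (placeholder ranks): the first call on local rank 0 builds
# the JSON cache under VLLM_CACHE_ROOT, other ranks wait on the barrier and
# read it, and later calls are answered from the in-memory dict.
#
#     if gpu_p2p_access_check(src=0, tgt=1):
#         ...  # safe to enable a P2P-based path between local GPUs 0 and 1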
|
||||
|
||||
|
||||
__all__ = ["gpu_p2p_access_check"]
|
||||
|
||||
if __name__ == "__main__":
|
||||
batch_src, batch_tgt, output_file = pickle.loads(sys.stdin.buffer.read())
|
||||
result = can_actually_p2p(batch_src, batch_tgt)
|
||||
with open(output_file, "wb") as f:
|
||||
f.write(pickle.dumps(result))
|
||||
362
vllm/distributed/device_communicators/base_device_communicator.py
Normal file
@@ -0,0 +1,362 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import threading
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
import ixformer.distributed as ixfd
|
||||
import os
|
||||
|
||||
class Cache:
|
||||
def __init__(self):
|
||||
self._cache: WeakValueDictionary = WeakValueDictionary()
|
||||
self._lock = threading.RLock() # Reentrant lock for thread safety
|
||||
|
||||
def get_or_create(self, kwargs, func):
|
||||
# Create a hashable key from the kwargs
|
||||
key = tuple(sorted((k, v) for k, v in kwargs.items()))
|
||||
|
||||
with self._lock:
|
||||
instance = self._cache.get(key)
|
||||
if instance is None:
|
||||
instance = func(**kwargs)
|
||||
self._cache[key] = instance
|
||||
return instance
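# --- Illustrative sketch (not part of the upstream diff) ---
# How the cache above deduplicates handles created with identical kwargs.
# `_Handle` is a placeholder class; note that the kwargs values must be
# hashable because they become part of the cache key, and the cached entries
# are only weakly referenced, so they live as long as callers hold them.
#
#     class _Handle:  # stands in for an expensive communicator handle
#         def __init__(self, hidden_dim, world_size):
#             self.hidden_dim, self.world_size = hidden_dim, world_size
#
#     cache = Cache()
#     a = cache.get_or_create({"hidden_dim": 4096, "world_size": 8}, _Handle)
#     b = cache.get_or_create({"hidden_dim": 4096, "world_size": 8}, _Handle)
#     assert a is b  # same config -> same instance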
|
||||
|
||||
|
||||
class All2AllManagerBase:
|
||||
rank: int
|
||||
world_size: int
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
self.cpu_group = cpu_group
|
||||
|
||||
# compute some common properties
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_dp_group,
|
||||
get_tp_group,
|
||||
in_the_same_node_as,
|
||||
)
|
||||
|
||||
# all2all lives in ep group, which is merged from dp and tp group
|
||||
self.dp_group = get_dp_group()
|
||||
self.tp_group = get_tp_group()
|
||||
|
||||
# no self.ep_group since self.ep_group is still in construction
|
||||
# when we create this object
|
||||
self.dp_rank = self.dp_group.rank_in_group
|
||||
self.dp_world_size = self.dp_group.world_size
|
||||
self.rank = dist.get_rank(cpu_group)
|
||||
self.world_size = dist.get_world_size(cpu_group)
|
||||
|
||||
# all2all communication often has separate implementations for
|
||||
# intra-node and inter-node communication
|
||||
self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0))
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
# get a handle for the all2all communication,
|
||||
# based on the kwargs.
|
||||
# different layers can have different configs,
|
||||
# e.g. one layer has hidden size 1024, another has 2048.
|
||||
# usually the underlying implementation caches the handle
|
||||
# and reuses it for the same config.
|
||||
raise NotImplementedError
|
||||
|
||||
def dispatch_router_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
# Subclasses should either:
|
||||
# - implement handling for extra_tensors, or
|
||||
# - raise a clear error if extra_tensors is not supported.
|
||||
raise NotImplementedError
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
# Subclasses should either:
|
||||
# - implement handling for extra_tensors, or
|
||||
# - raise a clear error if extra_tensors is not supported.
|
||||
raise NotImplementedError
|
||||
|
||||
def set_num_sms(self, num_sms: int):
|
||||
pass
|
||||
|
||||
def max_sms_used(self) -> int | None:
|
||||
return None # None means it could use the whole GPU
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False):
|
||||
raise NotImplementedError
|
||||
|
||||
def destroy(self):
|
||||
pass
|
||||
|
||||
|
||||
class DeviceCommunicatorBase:
|
||||
"""
|
||||
Base class for device-specific communicator.
|
||||
It can use the `cpu_group` to initialize the communicator.
|
||||
If the device has PyTorch integration (PyTorch can recognize its
|
||||
communication backend), the `device_group` will also be given.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cpu_group: ProcessGroup,
|
||||
device: torch.device | None = None,
|
||||
device_group: ProcessGroup | None = None,
|
||||
unique_name: str = "",
|
||||
):
|
||||
self.device = device or torch.device("cpu")
|
||||
self.cpu_group = cpu_group
|
||||
self.device_group = device_group
|
||||
self.unique_name = unique_name
|
||||
self.rank = dist.get_rank(cpu_group)
|
||||
self.world_size = dist.get_world_size(cpu_group)
|
||||
self.ranks = dist.get_process_group_ranks(cpu_group)
|
||||
self.global_rank = dist.get_rank()
|
||||
self.global_world_size = dist.get_world_size()
|
||||
self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank)
|
||||
|
||||
use_ep = False
|
||||
all2all_backend = None
|
||||
from vllm.config import get_current_vllm_config_or_none
|
||||
|
||||
config = get_current_vllm_config_or_none()
|
||||
if config is not None:
|
||||
# as long as we use data parallel (coupled data parallel
|
||||
# where all data parallel ranks execute forward together),
|
||||
# we initialize the all2all manager used in expert parallel.
|
||||
use_ep = config.parallel_config.data_parallel_size > 1
|
||||
all2all_backend = config.parallel_config.all2all_backend
|
||||
|
||||
self.is_ep_communicator = "ep" in unique_name
|
||||
self.use_all2all = self.is_ep_communicator and use_ep
|
||||
self.all2all_backend = all2all_backend
|
||||
self.all2all_manager: All2AllManagerBase | None = None
|
||||
|
||||
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
|
||||
dist.all_reduce(input_, group=self.device_group)
|
||||
return input_
|
||||
|
||||
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
input_size = input_.size()
|
||||
# NOTE: we have to use concat-style all-gather here,
|
||||
# stack-style all-gather has compatibility issues with
|
||||
# torch.compile . see https://github.com/pytorch/pytorch/issues/138795
|
||||
output_size = (input_size[0] * self.world_size,) + input_size[1:]
|
||||
# Allocate output tensor.
|
||||
output_tensor = torch.empty(
|
||||
output_size, dtype=input_.dtype, device=input_.device
|
||||
)
|
||||
# All-gather.
|
||||
if self.use_vllm_comm:
|
||||
ixfd.all_gather_into_tensor(output_tensor,
|
||||
input_,
|
||||
group=self.device_group,
|
||||
async_op=True)
|
||||
else:
|
||||
torch.distributed.all_gather_into_tensor(output_tensor,
|
||||
input_,
|
||||
group=self.device_group)
|
||||
# Reshape
|
||||
output_tensor = output_tensor.reshape((self.world_size,) + input_size)
|
||||
output_tensor = output_tensor.movedim(0, dim)
|
||||
output_tensor = output_tensor.reshape(
|
||||
input_size[:dim]
|
||||
+ (self.world_size * input_size[dim],)
|
||||
+ input_size[dim + 1 :]
|
||||
)
|
||||
return output_tensor
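# --- Illustrative worked example (comment only, not part of the upstream diff) ---
# With world_size=4 and input_ of shape (8, 16) gathered along dim=-1 (i.e. 1):
# the concat-style gather produces (32, 16), which is reshaped to (4, 8, 16),
# moved to (8, 4, 16) and flattened to (8, 64), i.e. the usual
# concatenate-along-dim result, expressed in a torch.compile-friendly way.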
|
||||
|
||||
def all_gatherv(
|
||||
self,
|
||||
input_: torch.Tensor | list[torch.Tensor],
|
||||
dim: int = 0,
|
||||
sizes: list[int] | None = None,
|
||||
) -> torch.Tensor | list[torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
def reduce_scatter(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||
world_size = self.world_size
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
assert -input_.dim() <= dim < input_.dim(), (
|
||||
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
|
||||
)
|
||||
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
|
||||
# Note: This will produce an incorrect answer if we don't make
|
||||
# the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
|
||||
input_tensor = input_.movedim(0, dim).contiguous()
|
||||
|
||||
assert input_tensor.shape[0] % world_size == 0
|
||||
chunk_size = input_tensor.shape[0] // world_size
|
||||
output_shape = (chunk_size,) + input_tensor.shape[1:]
|
||||
|
||||
output_tensor = torch.empty(
|
||||
output_shape, dtype=input_tensor.dtype, device=input_tensor.device
|
||||
)
|
||||
|
||||
# Perform reduce-scatter operation
|
||||
torch.distributed.reduce_scatter_tensor(
|
||||
output_tensor, input_tensor, group=self.device_group
|
||||
)
|
||||
|
||||
# Reshape before returning
|
||||
return output_tensor.movedim(0, dim).contiguous()
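# --- Illustrative worked example (comment only, not part of the upstream diff) ---
# With world_size=4 and input_ of shape (8, 64), reduce-scatter along dim=-1:
# the tensor is viewed as (64, 8), split into four (16, 8) chunks that are
# summed element-wise across ranks, and each rank keeps one chunk, returned
# with shape (8, 16).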
|
||||
|
||||
def reduce_scatterv(
|
||||
self, input_: torch.Tensor, dim: int = -1, sizes: list[int] | None = None
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def gather(
|
||||
self, input_: torch.Tensor, dst: int = 0, dim: int = -1
|
||||
) -> torch.Tensor | None:
|
||||
"""
|
||||
NOTE: We assume that the input tensor is on the same device across
|
||||
all the ranks.
|
||||
NOTE: `dst` is the local rank of the destination rank.
|
||||
"""
|
||||
world_size = self.world_size
|
||||
assert -input_.dim() <= dim < input_.dim(), (
|
||||
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
|
||||
)
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
|
||||
# Allocate output tensor.
|
||||
if self.rank_in_group == dst:
|
||||
gather_list = [torch.empty_like(input_) for _ in range(world_size)]
|
||||
else:
|
||||
gather_list = None
|
||||
# Gather.
|
||||
if self.use_vllm_comm:
|
||||
ixfd.gather(input_,
|
||||
gather_list,
|
||||
dst=self.ranks[dst],
|
||||
group=self.device_group,
|
||||
async_op=True)
|
||||
else:
|
||||
torch.distributed.gather(input_,
|
||||
gather_list,
|
||||
dst=self.ranks[dst],
|
||||
group=self.device_group)
|
||||
if self.rank_in_group == dst:
|
||||
output_tensor = torch.cat(gather_list, dim=dim)
|
||||
else:
|
||||
output_tensor = None
|
||||
return output_tensor
|
||||
|
||||
def send(self, tensor: torch.Tensor, dst: int | None = None) -> None:
|
||||
"""Sends a tensor to the destination rank in a blocking way"""
|
||||
"""NOTE: `dst` is the local rank of the destination rank."""
|
||||
if dst is None:
|
||||
dst = (self.rank_in_group + 1) % self.world_size
|
||||
torch.distributed.send(tensor, self.ranks[dst], self.device_group)
|
||||
|
||||
def recv(
|
||||
self, size: torch.Size, dtype: torch.dtype, src: int | None = None
|
||||
) -> torch.Tensor:
|
||||
"""Receives a tensor from the source rank."""
|
||||
"""NOTE: `src` is the local rank of the source rank."""
|
||||
if src is None:
|
||||
src = (self.rank_in_group - 1) % self.world_size
|
||||
|
||||
tensor = torch.empty(size, dtype=dtype, device=self.device)
|
||||
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
|
||||
return tensor
|
||||
|
||||
def destroy(self):
|
||||
pass
|
||||
|
||||
def prepare_communication_buffer_for_model(self, model: torch.nn.Module) -> None:
|
||||
"""
|
||||
Prepare the communication buffer for the model.
|
||||
"""
|
||||
if not self.is_ep_communicator:
|
||||
return
|
||||
|
||||
moe_modules = [
|
||||
module
|
||||
for module in model.modules()
|
||||
# TODO(bnell): Should use isinstance but can't. Maybe search for
|
||||
# presence of quant_method.maybe_init_modular_kernel?
|
||||
if (
|
||||
module.__class__.__name__ == "FusedMoE"
|
||||
or module.__class__.__name__ == "SharedFusedMoE"
|
||||
)
|
||||
]
|
||||
for module in moe_modules:
|
||||
module.maybe_init_modular_kernel()
|
||||
|
||||
def dispatch_router_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
"""
|
||||
Dispatch the hidden states and router logits to the appropriate device.
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
if extra_tensors is not None:
|
||||
return hidden_states, router_logits, extra_tensors
|
||||
return hidden_states, router_logits
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
"""
|
||||
Dispatch the hidden states and topk weights/ids to the appropriate device.
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
if extra_tensors is not None:
|
||||
return hidden_states, topk_weights, topk_ids, extra_tensors
|
||||
return hidden_states, topk_weights, topk_ids
|
||||
|
||||
def combine(
|
||||
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Combine the hidden states and router logits from the appropriate device.
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
return hidden_states
|
||||
Some files were not shown because too many files have changed in this diff