This commit is contained in:
root
2026-04-09 11:23:47 +08:00
parent 8082d5f4b2
commit 72387e4fa8
1885 changed files with 611521 additions and 1 deletion

82
vllm/plugins/__init__.py Normal file
View File

@@ -0,0 +1,82 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
from collections.abc import Callable
from typing import Any
import vllm.envs as envs
logger = logging.getLogger(__name__)

# Default plugins group will be loaded in all processes (process0, the engine
# core process, and the worker processes).
DEFAULT_PLUGINS_GROUP = "vllm.general_plugins"
# IO processor plugins group will be loaded in process0 only.
IO_PROCESSOR_PLUGINS_GROUP = "vllm.io_processor_plugins"
# Platform plugins group will be loaded in all processes when
# `vllm.platforms.current_platform` is called and the value is not yet
# initialized.
PLATFORM_PLUGINS_GROUP = "vllm.platform_plugins"
# Stat logger plugins group will be loaded in process0 only when serving vLLM
# in async mode.
STAT_LOGGER_PLUGINS_GROUP = "vllm.stat_logger_plugins"

# Guard so one process only loads plugins once; set by load_general_plugins().
plugins_loaded = False
def load_plugins_by_group(group: str) -> dict[str, Callable[[], Any]]:
    """Load plugins registered under the given entry point group.

    Discovers all entry points in *group*, logs what was found, and returns
    a mapping from plugin name to the loaded entry-point object. When the
    `VLLM_PLUGINS` environment variable is set, only the plugin names it
    lists are loaded. A plugin whose ``load()`` raises is logged via
    ``logger.exception`` and skipped rather than aborting the whole group.
    """
    from importlib.metadata import entry_points

    allowed_plugins = envs.VLLM_PLUGINS
    discovered_plugins = entry_points(group=group)
    if len(discovered_plugins) == 0:
        logger.debug("No plugins for group %s found.", group)
        return {}

    # The default group would be announced on every start-up, so it logs at
    # DEBUG; all other groups log at INFO.
    log_level = (
        logger.debug if group == DEFAULT_PLUGINS_GROUP else logger.info
    )
    log_level("Available plugins for group %s:", group)
    for entry in discovered_plugins:
        log_level("- %s -> %s", entry.name, entry.value)
    if allowed_plugins is None:
        log_level(
            "All plugins in this group will be loaded. "
            "Set `VLLM_PLUGINS` to control which plugins to load."
        )

    loaded: dict[str, Callable[[], Any]] = {}
    for entry in discovered_plugins:
        # Skip anything not explicitly allowed when a filter is configured.
        if allowed_plugins is not None and entry.name not in allowed_plugins:
            continue
        if allowed_plugins is not None:
            log_level("Loading plugin %s", entry.name)
        try:
            loaded[entry.name] = entry.load()
        except Exception:
            logger.exception("Failed to load plugin %s", entry.name)
    return loaded
def load_general_plugins():
    """Load and activate all plugins in the default (general) group.

    WARNING: plugins can be loaded multiple times in different processes.
    They should be designed so that repeated loading causes no issues.
    This function itself is idempotent within a process via the
    module-level ``plugins_loaded`` flag.
    """
    global plugins_loaded
    if plugins_loaded:
        return
    plugins_loaded = True
    # General plugins are activated simply by invoking each loaded entry.
    general_plugins = load_plugins_by_group(group=DEFAULT_PLUGINS_GROUP)
    for activate in general_plugins.values():
        activate()

View File

@@ -0,0 +1,68 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
from vllm.config import VllmConfig
from vllm.plugins import IO_PROCESSOR_PLUGINS_GROUP, load_plugins_by_group
from vllm.plugins.io_processors.interface import IOProcessor
from vllm.utils.import_utils import resolve_obj_by_qualname
logger = logging.getLogger(__name__)
def get_io_processor(
    vllm_config: VllmConfig, plugin_from_init: str | None = None
) -> IOProcessor | None:
    """Resolve and instantiate the IO Processor requested for a model.

    Input/Output processors are loaded as plugins under the
    'vllm.io_processor_plugins' group. Similar to platform plugins, these
    plugins register a function that returns the qualified class name for
    the processor to install.

    Args:
        vllm_config: Engine configuration; the model's hf_config may name
            a plugin via the custom ``io_processor_plugin`` field.
        plugin_from_init: Explicit plugin name that overrides the model
            config when provided.

    Returns:
        The instantiated ``IOProcessor``, or ``None`` when no plugin is
        requested.

    Raises:
        ValueError: If a plugin is required but none are installed, or the
            requested plugin is not among the loadable ones.
    """
    if plugin_from_init:
        model_plugin = plugin_from_init
    else:
        # A plugin can be specified via the model config.
        # Retrieve the model specific plugin if available.
        # This is using a custom field in the hf_config for the model.
        hf_config = vllm_config.model_config.hf_config.to_dict()
        model_plugin = hf_config.get("io_processor_plugin")

    if model_plugin is None:
        logger.debug("No IOProcessor plugins requested by the model")
        return None

    logger.debug("IOProcessor plugin to be loaded %s", model_plugin)

    # Load all installed plugins in the group.
    multimodal_data_processor_plugins = load_plugins_by_group(
        IO_PROCESSOR_PLUGINS_GROUP
    )

    # Keep only plugins whose registration function returns a class name;
    # a failing plugin is logged and skipped.
    loadable_plugins = {}
    for name, func in multimodal_data_processor_plugins.items():
        try:
            assert callable(func)
            processor_cls_qualname = func()
            if processor_cls_qualname is not None:
                loadable_plugins[name] = processor_cls_qualname
        except Exception:
            logger.warning("Failed to load plugin %s.", name, exc_info=True)

    if not loadable_plugins:
        raise ValueError(
            f"No IOProcessor plugins installed but one is required ({model_plugin})."
        )

    if model_plugin not in loadable_plugins:
        raise ValueError(
            f"The model requires the '{model_plugin}' IO Processor plugin "
            "but it is not installed. "
            f"Available plugins: {list(loadable_plugins.keys())}"
        )

    activated_plugin_cls = loadable_plugins[model_plugin]
    return resolve_obj_by_qualname(activated_plugin_cls)(vllm_config)

View File

@@ -0,0 +1,123 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import warnings
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Sequence
from typing import Generic, TypeVar
from vllm.config import VllmConfig
from vllm.inputs.data import PromptType
from vllm.outputs import PoolingRequestOutput
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
IOProcessorInput = TypeVar("IOProcessorInput")
IOProcessorOutput = TypeVar("IOProcessorOutput")
class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
    """Abstract interface for pre/post-processing of engine I/O.

    A plugin subclasses this to convert plugin-specific request data into
    engine prompts (``pre_process``) and engine pooling outputs back into
    the plugin-specific output type (``post_process``). The async variants
    delegate to the sync implementations by default.
    """

    def __init__(self, vllm_config: VllmConfig):
        super().__init__()
        # Kept so subclasses can consult the engine configuration.
        self.vllm_config = vllm_config

    def parse_data(self, data: object) -> IOProcessorInput:
        """Parse raw request data into the plugin's input type.

        Falls back to the legacy ``parse_request`` method when a subclass
        still defines one (deprecated; removal planned for v0.19).
        """
        if callable(parse_request := getattr(self, "parse_request", None)):
            warnings.warn(
                "`parse_request` has been renamed to `parse_data`. "
                "Please update your IO Processor Plugin to use the new name. "
                "The old name will be removed in v0.19.",
                DeprecationWarning,
                stacklevel=2,
            )
            return parse_request(data)  # type: ignore
        raise NotImplementedError

    def merge_sampling_params(
        self,
        params: SamplingParams | None = None,
    ) -> SamplingParams:
        """Return the sampling params to use, defaulting when none given.

        Falls back to the legacy ``validate_or_generate_params`` method
        when a subclass still defines one (deprecated; removal in v0.19).
        """
        if callable(
            validate_or_generate_params := getattr(
                self, "validate_or_generate_params", None
            )
        ):
            warnings.warn(
                # Fixed: the concatenated message was missing a space
                # between "...`merge_pooling_params`." and "Please".
                "`validate_or_generate_params` has been split into "
                "`merge_sampling_params` and `merge_pooling_params`. "
                "Please update your IO Processor Plugin to use the new methods. "
                "The old name will be removed in v0.19.",
                DeprecationWarning,
                stacklevel=2,
            )
            return validate_or_generate_params(params)  # type: ignore
        return params or SamplingParams()

    def merge_pooling_params(
        self,
        params: PoolingParams | None = None,
    ) -> PoolingParams:
        """Return the pooling params to use, defaulting when none given.

        Falls back to the legacy ``validate_or_generate_params`` method
        when a subclass still defines one (deprecated; removal in v0.19).
        """
        if callable(
            validate_or_generate_params := getattr(
                self, "validate_or_generate_params", None
            )
        ):
            warnings.warn(
                # Fixed: the concatenated message was missing a space
                # between "...`merge_pooling_params`." and "Please".
                "`validate_or_generate_params` has been split into "
                "`merge_sampling_params` and `merge_pooling_params`. "
                "Please update your IO Processor Plugin to use the new methods. "
                "The old name will be removed in v0.19.",
                DeprecationWarning,
                stacklevel=2,
            )
            return validate_or_generate_params(params)  # type: ignore
        return params or PoolingParams(task="plugin")

    @abstractmethod
    def pre_process(
        self,
        prompt: IOProcessorInput,
        request_id: str | None = None,
        **kwargs,
    ) -> PromptType | Sequence[PromptType]:
        """Convert a plugin input into one or more engine prompts."""
        raise NotImplementedError

    async def pre_process_async(
        self,
        prompt: IOProcessorInput,
        request_id: str | None = None,
        **kwargs,
    ) -> PromptType | Sequence[PromptType]:
        """Async variant of ``pre_process``; delegates to it by default."""
        return self.pre_process(prompt, request_id, **kwargs)

    @abstractmethod
    def post_process(
        self,
        model_output: Sequence[PoolingRequestOutput],
        request_id: str | None = None,
        **kwargs,
    ) -> IOProcessorOutput:
        """Convert engine pooling outputs into the plugin output type."""
        raise NotImplementedError

    async def post_process_async(
        self,
        model_output: AsyncGenerator[tuple[int, PoolingRequestOutput]],
        request_id: str | None = None,
        **kwargs,
    ) -> IOProcessorOutput:
        """Async variant of ``post_process`` fed an (index, output) stream."""
        # We cannot guarantee outputs are returned in the same order they were
        # fed to vLLM.
        # Let's sort them by id before post_processing
        sorted_output = sorted(
            [(i, item) async for i, item in model_output], key=lambda output: output[0]
        )
        collected_output = [output[1] for output in sorted_output]
        return self.post_process(collected_output, request_id=request_id, **kwargs)

View File

View File

@@ -0,0 +1,62 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import os
import vllm.envs as envs
from vllm.lora.request import LoRARequest
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
class FilesystemResolver(LoRAResolver):
    """LoRA resolver that serves adapters from a local cache directory."""

    def __init__(self, lora_cache_dir: str):
        # Directory whose immediate subdirectories are candidate adapters.
        self.lora_cache_dir = lora_cache_dir

    async def resolve_lora(
        self, base_model_name: str, lora_name: str
    ) -> LoRARequest | None:
        """Resolve *lora_name* to a LoRARequest if a matching adapter exists.

        The adapter is expected at ``<lora_cache_dir>/<lora_name>``.
        """
        lora_path = os.path.join(self.lora_cache_dir, lora_name)
        return await self._get_lora_req_from_path(
            lora_name, lora_path, base_model_name
        )

    async def _get_lora_req_from_path(
        self, lora_name: str, lora_path: str, base_model_name: str
    ) -> LoRARequest | None:
        """Builds a LoraRequest pointing to the lora path if it's a valid
        LoRA adapter and has a matching base_model_name.

        Returns None when the path or its adapter_config.json is missing,
        or when the config does not describe a LoRA for the base model.
        """
        adapter_config_path = os.path.join(lora_path, "adapter_config.json")
        if os.path.exists(lora_path) and os.path.exists(adapter_config_path):
            with open(adapter_config_path) as file:
                adapter_config = json.load(file)
            # Use .get() so a malformed adapter config (missing keys) is
            # treated as "not a match" instead of raising KeyError.
            if (
                adapter_config.get("peft_type") == "LORA"
                and adapter_config.get("base_model_name_or_path") == base_model_name
            ):
                return LoRARequest(
                    lora_name=lora_name,
                    # NOTE: hash() varies across processes unless
                    # PYTHONHASHSEED is fixed; the id is per-run only.
                    lora_int_id=abs(hash(lora_name)),
                    lora_path=lora_path,
                )
        return None
def register_filesystem_resolver():
    """Register the filesystem LoRA Resolver with vLLM.

    No-op unless ``VLLM_LORA_RESOLVER_CACHE_DIR`` is set.

    Raises:
        ValueError: If the configured cache dir is not an existing directory.
    """
    lora_cache_dir = envs.VLLM_LORA_RESOLVER_CACHE_DIR
    if lora_cache_dir:
        # isdir() already implies existence, so one check suffices. The
        # original message also embedded a run of indentation spaces via a
        # backslash continuation inside the string literal; use implicit
        # concatenation instead.
        if not os.path.isdir(lora_cache_dir):
            raise ValueError(
                "VLLM_LORA_RESOLVER_CACHE_DIR must be set to a valid "
                "directory for Filesystem Resolver plugin to function"
            )
        fs_resolver = FilesystemResolver(lora_cache_dir)
        LoRAResolverRegistry.register_resolver("Filesystem Resolver", fs_resolver)

View File

@@ -0,0 +1,143 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import os
from huggingface_hub import HfApi, snapshot_download
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.lora.resolver import LoRAResolverRegistry
from vllm.plugins.lora_resolvers.filesystem_resolver import FilesystemResolver
logger = init_logger(__name__)
class HfHubResolver(FilesystemResolver):
    """LoRA resolver that downloads adapters from allow-listed HF Hub repos.

    Extends FilesystemResolver: after snapshot-downloading the relevant
    repo subpath, validation of the cached directory is delegated to the
    inherited ``_get_lora_req_from_path``.
    """

    def __init__(self, repo_list: list[str]):
        # Emitted unconditionally: remote downloads are inherently unsafe.
        logger.warning(
            "LoRA is allowing resolution from the following repositories on"
            " HF Hub: %s please note that allowing remote downloads"
            " is not secure, and that this plugin is not intended for use in"
            " production environments.",
            repo_list,
        )
        # NOTE(review): FilesystemResolver.__init__ is deliberately not
        # called — this resolver has no fixed cache dir; confirm intended.
        # Allow-list of "<org>/<repo>" names resolution may download from.
        self.repo_list: list[str] = repo_list
        # Lazy cache: repo name -> subdirs containing adapter_config.json
        # ("." denotes the repo root).
        self.adapter_dirs: dict[str, set[str]] = {}

    async def resolve_lora(
        self, base_model_name: str, lora_name: str
    ) -> LoRARequest | None:
        """Resolves potential LoRA requests in a remote repo on HF Hub.
        This is effectively the same behavior as the filesystem resolver, but
        with a snapshot_download on dirs containing an adapter config prior
        to inspecting the cached dir to build a potential LoRA
        request.
        """
        # If a LoRA name begins with the repository name, it's disambiguated
        maybe_repo = await self._resolve_repo(lora_name)
        # If we haven't inspected this repo before, save available adapter dirs
        if maybe_repo is not None and maybe_repo not in self.adapter_dirs:
            self.adapter_dirs[maybe_repo] = await self._get_adapter_dirs(maybe_repo)
        maybe_subpath = await self._resolve_repo_subpath(lora_name, maybe_repo)
        if maybe_repo is None or maybe_subpath is None:
            return None
        # snapshot_download is blocking, so run it off the event loop; only
        # the matched subpath ("*" for the repo root) is fetched.
        repo_path = await asyncio.to_thread(
            snapshot_download,
            repo_id=maybe_repo,
            allow_patterns=f"{maybe_subpath}/*" if maybe_subpath != "." else "*",
        )
        lora_path = os.path.join(repo_path, maybe_subpath)
        # Inherited from FilesystemResolver: validates adapter_config.json
        # and the base model name before building the request.
        maybe_lora_request = await self._get_lora_req_from_path(
            lora_name, lora_path, base_model_name
        )
        return maybe_lora_request

    async def _resolve_repo(self, lora_name: str) -> str | None:
        """Given a fully qualified path to a LoRA with respect to its HF Hub
        repo, match the right repo to potentially download from if one exists.
        Args:
            lora_name: Path to LoRA in HF Hub, e.g., <org>/<repo>/<subpath>,
                match on <org>/<repo> (if it contains an adapter directly) or
                <org>/<repo>/ if it may have one in subdirs.
        """
        for potential_repo in self.repo_list:
            # Require an exact repo match or a "/" immediately after it, so
            # e.g. "org/repo2/x" does not match the allow-listed "org/repo".
            if lora_name.startswith(potential_repo) and (
                len(lora_name) == len(potential_repo)
                or lora_name[len(potential_repo)] == "/"
            ):
                return potential_repo
        return None

    async def _resolve_repo_subpath(
        self, lora_name: str, maybe_repo: str | None
    ) -> str | None:
        """Given the fully qualified path of the LoRA with respect to the HF
        Repo, get the subpath to download from assuming it's actually got an
        adapter in it.
        Args:
            lora_name: Path to LoRA in HF Hub, e.g., <org>/<repo>/<subpath>
            maybe_repo: Path to the repo to match against if one exists.
        """
        if maybe_repo is None:
            return None
        repo_len = len(maybe_repo)
        if lora_name == maybe_repo or (
            len(lora_name) == repo_len + 1 and lora_name[-1] == "/"
        ):
            # Resolves to the root of the directory
            adapter_dir = "."
        else:
            # It's a subpath; removing trailing slashes if there are any
            adapter_dir = lora_name[repo_len + 1 :].rstrip("/")
        # Only download if the directory actually contains an adapter
        is_adapter = adapter_dir in self.adapter_dirs[maybe_repo]
        return adapter_dir if is_adapter else None

    async def _get_adapter_dirs(self, repo_name: str) -> set[str]:
        """Gets the subpaths within a HF repo that contain an adapter config.
        Args:
            repo_name: Name of the HF hub repo to inspect.
        """
        # Blocking Hub API call; offload to a thread like snapshot_download.
        repo_files = await asyncio.to_thread(HfApi().list_repo_files, repo_id=repo_name)
        adapter_dirs = {
            os.path.dirname(name)
            for name in repo_files
            if name.endswith("adapter_config.json")
        }
        # A root-level config has dirname "" above; record it as ".".
        if "adapter_config.json" in repo_files:
            adapter_dirs.add(".")
        return adapter_dirs
def register_hf_hub_resolver():
    """Register the Hf hub LoRA Resolver with vLLM.

    No-op unless ``VLLM_LORA_RESOLVER_HF_REPO_LIST`` is set. Because this
    resolver performs remote downloads, it additionally requires an
    explicit opt-in via ``VLLM_PLUGINS`` containing
    ``lora_hf_hub_resolver``; otherwise only a warning is emitted.
    """
    hf_repo_list = envs.VLLM_LORA_RESOLVER_HF_REPO_LIST
    if not hf_repo_list:
        return
    is_enabled = (
        envs.VLLM_PLUGINS is not None and "lora_hf_hub_resolver" in envs.VLLM_PLUGINS
    )
    if not is_enabled:
        # Fixed: the concatenated message contained a double space
        # ("to use it  because").
        logger.warning(
            "It appears that VLLM_LORA_RESOLVER_HF_REPO_LIST is set, but "
            "lora_hf_hub_resolver is not enabled in VLLM_PLUGINS; you must "
            "enable this resolver directly in VLLM_PLUGINS to use it "
            "because it allows remote downloads."
        )
        return
    hf_hub_resolver = HfHubResolver(hf_repo_list.split(","))
    LoRAResolverRegistry.register_resolver("Hf Hub Resolver", hf_hub_resolver)