update
This commit is contained in:
82
vllm/plugins/__init__.py
Normal file
82
vllm/plugins/__init__.py
Normal file
@@ -0,0 +1,82 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import logging
from collections.abc import Callable
from typing import Any

import vllm.envs as envs

logger = logging.getLogger(__name__)

# Default plugins group will be loaded in all processes (process0, engine core
# process and worker processes).
DEFAULT_PLUGINS_GROUP = "vllm.general_plugins"
# IO processor plugins group will be loaded in process0 only.
IO_PROCESSOR_PLUGINS_GROUP = "vllm.io_processor_plugins"
# Platform plugins group will be loaded in all processes when
# `vllm.platforms.current_platform` is called and the value is not initialized.
PLATFORM_PLUGINS_GROUP = "vllm.platform_plugins"
# Stat logger plugins group will be loaded in process0 only when serving vLLM
# in async mode.
STAT_LOGGER_PLUGINS_GROUP = "vllm.stat_logger_plugins"

# Process-level guard: make sure one process only loads plugins once.
# Mutated by load_general_plugins() below.
plugins_loaded = False
||||
def load_plugins_by_group(group: str) -> dict[str, Callable[[], Any]]:
    """Load plugins registered under the given entry point group."""
    from importlib.metadata import entry_points

    allowed = envs.VLLM_PLUGINS

    discovered = entry_points(group=group)
    if not discovered:
        logger.debug("No plugins for group %s found.", group)
        return {}

    # The default group is loaded in every process, so keep its chatter at
    # DEBUG; other groups are opt-in and worth surfacing at INFO.
    emit = logger.debug if group == DEFAULT_PLUGINS_GROUP else logger.info

    emit("Available plugins for group %s:", group)
    for entry in discovered:
        emit("- %s -> %s", entry.name, entry.value)

    if allowed is None:
        emit(
            "All plugins in this group will be loaded. "
            "Set `VLLM_PLUGINS` to control which plugins to load."
        )

    loaded: dict[str, Callable[[], Any]] = {}
    for entry in discovered:
        # Skip anything not on the allow-list (when one is configured).
        if allowed is not None and entry.name not in allowed:
            continue
        if allowed is not None:
            emit("Loading plugin %s", entry.name)

        try:
            loaded[entry.name] = entry.load()
        except Exception:
            # A broken plugin must not take down the engine; log and move on.
            logger.exception("Failed to load plugin %s", entry.name)

    return loaded
|
||||
|
||||
|
||||
def load_general_plugins():
    """Discover and execute every plugin in the default plugins group.

    WARNING: plugins can be loaded multiple times in different processes
    (process0, engine core, workers). They should be designed so that
    repeated loading causes no issues.
    """
    global plugins_loaded
    if not plugins_loaded:
        plugins_loaded = True
        # General plugins register themselves by side effect; simply invoke
        # each loaded entry-point function.
        for entry_func in load_plugins_by_group(group=DEFAULT_PLUGINS_GROUP).values():
            entry_func()
|
||||
68
vllm/plugins/io_processors/__init__.py
Normal file
68
vllm/plugins/io_processors/__init__.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import logging
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.plugins import IO_PROCESSOR_PLUGINS_GROUP, load_plugins_by_group
|
||||
from vllm.plugins.io_processors.interface import IOProcessor
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_io_processor(
    vllm_config: VllmConfig, plugin_from_init: str | None = None
) -> IOProcessor | None:
    """Instantiate the IO Processor plugin requested for the current model.

    Input/Output processors are loaded as plugins under the
    'vllm.io_processor_plugins' group. Similar to platform plugins, these
    plugins register a function that returns the qualified class name of
    the processor to install.

    Args:
        vllm_config: Engine configuration; also passed to the processor's
            constructor.
        plugin_from_init: Plugin name explicitly requested at init time. If
            not given, the model's `hf_config` custom field
            `io_processor_plugin` is consulted instead.

    Returns:
        The instantiated IOProcessor, or None when no plugin is requested.

    Raises:
        ValueError: when a plugin is requested but no loadable plugin with
            that name is installed.
    """
    if plugin_from_init:
        model_plugin = plugin_from_init
    else:
        # A plugin can be specified via the model config, using a custom
        # field in the model's hf_config.
        hf_config = vllm_config.model_config.hf_config.to_dict()
        model_plugin = hf_config.get("io_processor_plugin")

    if model_plugin is None:
        logger.debug("No IOProcessor plugins requested by the model")
        return None

    logger.debug("IOProcessor plugin to be loaded %s", model_plugin)

    # Load all installed plugins in the group.
    multimodal_data_processor_plugins = load_plugins_by_group(
        IO_PROCESSOR_PLUGINS_GROUP
    )

    loadable_plugins = {}
    for name, func in multimodal_data_processor_plugins.items():
        try:
            # Explicit check instead of `assert`: asserts are stripped under
            # `python -O`, which would silently skip this validation.
            if not callable(func):
                raise TypeError(f"Plugin entry point {name} is not callable")
            processor_cls_qualname = func()
            if processor_cls_qualname is not None:
                loadable_plugins[name] = processor_cls_qualname
        except Exception:
            # Best effort: a broken plugin is logged and skipped.
            logger.warning("Failed to load plugin %s.", name, exc_info=True)

    if not loadable_plugins:
        raise ValueError(
            f"No IOProcessor plugins installed but one is required ({model_plugin})."
        )

    if model_plugin not in loadable_plugins:
        raise ValueError(
            f"The model requires the '{model_plugin}' IO Processor plugin "
            "but it is not installed. "
            f"Available plugins: {list(loadable_plugins.keys())}"
        )

    activated_plugin_cls = loadable_plugins[model_plugin]

    return resolve_obj_by_qualname(activated_plugin_cls)(vllm_config)
|
||||
123
vllm/plugins/io_processors/interface.py
Normal file
123
vllm/plugins/io_processors/interface.py
Normal file
@@ -0,0 +1,123 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncGenerator, Sequence
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.inputs.data import PromptType
|
||||
from vllm.outputs import PoolingRequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
IOProcessorInput = TypeVar("IOProcessorInput")
|
||||
IOProcessorOutput = TypeVar("IOProcessorOutput")
|
||||
|
||||
|
||||
class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
    """Abstract interface for pre/post-processing of engine I/O.

    A concrete IO Processor plugin converts plugin-specific request data
    into engine prompts (`pre_process`) and engine pooling outputs back
    into plugin-specific responses (`post_process`). The async variants
    delegate to their sync counterparts by default.
    """

    def __init__(self, vllm_config: VllmConfig):
        super().__init__()

        # Full engine configuration, kept for subclasses that need model or
        # runtime details while processing.
        self.vllm_config = vllm_config

    def parse_data(self, data: object) -> IOProcessorInput:
        """Parse raw request data into the plugin's input type.

        Subclasses must override this. For backwards compatibility, a
        legacy `parse_request` method is honored with a deprecation
        warning (to be removed in v0.19).
        """
        if callable(parse_request := getattr(self, "parse_request", None)):
            warnings.warn(
                "`parse_request` has been renamed to `parse_data`. "
                "Please update your IO Processor Plugin to use the new name. "
                "The old name will be removed in v0.19.",
                DeprecationWarning,
                stacklevel=2,
            )

            return parse_request(data)  # type: ignore

        raise NotImplementedError

    def merge_sampling_params(
        self,
        params: SamplingParams | None = None,
    ) -> SamplingParams:
        """Return the sampling params to use, defaulting when none given.

        Honors the legacy `validate_or_generate_params` method with a
        deprecation warning (to be removed in v0.19).
        """
        if callable(
            validate_or_generate_params := getattr(
                self, "validate_or_generate_params", None
            )
        ):
            # Fixed missing space after the period so the warning does not
            # render as "...merge_pooling_params`.Please update...".
            warnings.warn(
                "`validate_or_generate_params` has been split into "
                "`merge_sampling_params` and `merge_pooling_params`. "
                "Please update your IO Processor Plugin to use the new methods. "
                "The old name will be removed in v0.19.",
                DeprecationWarning,
                stacklevel=2,
            )

            return validate_or_generate_params(params)  # type: ignore

        return params or SamplingParams()

    def merge_pooling_params(
        self,
        params: PoolingParams | None = None,
    ) -> PoolingParams:
        """Return the pooling params to use, defaulting when none given.

        Honors the legacy `validate_or_generate_params` method with a
        deprecation warning (to be removed in v0.19).
        """
        if callable(
            validate_or_generate_params := getattr(
                self, "validate_or_generate_params", None
            )
        ):
            # Fixed missing space after the period (same message as in
            # merge_sampling_params).
            warnings.warn(
                "`validate_or_generate_params` has been split into "
                "`merge_sampling_params` and `merge_pooling_params`. "
                "Please update your IO Processor Plugin to use the new methods. "
                "The old name will be removed in v0.19.",
                DeprecationWarning,
                stacklevel=2,
            )

            return validate_or_generate_params(params)  # type: ignore

        return params or PoolingParams(task="plugin")

    @abstractmethod
    def pre_process(
        self,
        prompt: IOProcessorInput,
        request_id: str | None = None,
        **kwargs,
    ) -> PromptType | Sequence[PromptType]:
        """Convert plugin input into one or more engine prompts."""
        raise NotImplementedError

    async def pre_process_async(
        self,
        prompt: IOProcessorInput,
        request_id: str | None = None,
        **kwargs,
    ) -> PromptType | Sequence[PromptType]:
        """Async wrapper; by default delegates to `pre_process`."""
        return self.pre_process(prompt, request_id, **kwargs)

    @abstractmethod
    def post_process(
        self,
        model_output: Sequence[PoolingRequestOutput],
        request_id: str | None = None,
        **kwargs,
    ) -> IOProcessorOutput:
        """Convert engine pooling outputs into the plugin's output type."""
        raise NotImplementedError

    async def post_process_async(
        self,
        model_output: AsyncGenerator[tuple[int, PoolingRequestOutput]],
        request_id: str | None = None,
        **kwargs,
    ) -> IOProcessorOutput:
        """Drain an async stream of (index, output) pairs and post-process.

        We cannot guarantee outputs are returned in the same order they were
        fed to vLLM, so sort them by their original index before delegating
        to `post_process`.
        """
        sorted_output = sorted(
            [(i, item) async for i, item in model_output], key=lambda output: output[0]
        )
        collected_output = [output[1] for output in sorted_output]
        return self.post_process(collected_output, request_id=request_id, **kwargs)
|
||||
0
vllm/plugins/lora_resolvers/__init__.py
Normal file
0
vllm/plugins/lora_resolvers/__init__.py
Normal file
62
vllm/plugins/lora_resolvers/filesystem_resolver.py
Normal file
62
vllm/plugins/lora_resolvers/filesystem_resolver.py
Normal file
@@ -0,0 +1,62 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import json
|
||||
import os
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
|
||||
|
||||
|
||||
class FilesystemResolver(LoRAResolver):
    """LoRA resolver that serves adapters from a local cache directory."""

    def __init__(self, lora_cache_dir: str):
        # Root directory whose subdirectories are candidate adapters.
        self.lora_cache_dir = lora_cache_dir

    async def resolve_lora(
        self, base_model_name: str, lora_name: str
    ) -> LoRARequest | None:
        """Look up `lora_name` as a subdirectory of the cache dir."""
        candidate_path = os.path.join(self.lora_cache_dir, lora_name)
        return await self._get_lora_req_from_path(
            lora_name, candidate_path, base_model_name
        )

    async def _get_lora_req_from_path(
        self, lora_name: str, lora_path: str, base_model_name: str
    ) -> LoRARequest | None:
        """Builds a LoraRequest pointing to the lora path if it's a valid
        LoRA adapter and has a matching base_model_name.
        """
        if not os.path.exists(lora_path):
            return None

        adapter_config_path = os.path.join(lora_path, "adapter_config.json")
        if not os.path.exists(adapter_config_path):
            return None

        with open(adapter_config_path) as file:
            adapter_config = json.load(file)

        if (
            adapter_config["peft_type"] != "LORA"
            or adapter_config["base_model_name_or_path"] != base_model_name
        ):
            return None

        return LoRARequest(
            lora_name=lora_name,
            lora_int_id=abs(hash(lora_name)),
            lora_path=lora_path,
        )
|
||||
|
||||
|
||||
def register_filesystem_resolver():
    """Register the filesystem LoRA Resolver with vLLM.

    Reads ``VLLM_LORA_RESOLVER_CACHE_DIR``; when set, it must point at an
    existing directory containing LoRA adapter subdirectories.

    Raises:
        ValueError: if the env var is set but is not an existing directory.
    """
    lora_cache_dir = envs.VLLM_LORA_RESOLVER_CACHE_DIR
    if not lora_cache_dir:
        return

    # os.path.isdir already returns False for nonexistent paths, so a
    # separate os.path.exists check is redundant.
    if not os.path.isdir(lora_cache_dir):
        # Use implicit string concatenation instead of a backslash line
        # continuation, which embedded the source indentation whitespace
        # into the error message.
        raise ValueError(
            "VLLM_LORA_RESOLVER_CACHE_DIR must be set to a valid directory "
            "for Filesystem Resolver plugin to function"
        )

    fs_resolver = FilesystemResolver(lora_cache_dir)
    LoRAResolverRegistry.register_resolver("Filesystem Resolver", fs_resolver)
|
||||
143
vllm/plugins/lora_resolvers/hf_hub_resolver.py
Normal file
143
vllm/plugins/lora_resolvers/hf_hub_resolver.py
Normal file
@@ -0,0 +1,143 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from huggingface_hub import HfApi, snapshot_download
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.lora.resolver import LoRAResolverRegistry
|
||||
from vllm.plugins.lora_resolvers.filesystem_resolver import FilesystemResolver
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class HfHubResolver(FilesystemResolver):
    """LoRA resolver that downloads adapters from allow-listed HF Hub repos.

    Reuses ``FilesystemResolver._get_lora_req_from_path`` to validate the
    locally cached snapshot.

    NOTE(review): ``FilesystemResolver.__init__`` is deliberately not
    called, so ``self.lora_cache_dir`` is never set; this is fine because
    ``resolve_lora`` is fully overridden here — confirm if that changes.
    """

    def __init__(self, repo_list: list[str]):
        # Loud warning at construction time: this resolver performs remote
        # downloads and is explicitly not intended for production use.
        logger.warning(
            "LoRA is allowing resolution from the following repositories on"
            " HF Hub: %s please note that allowing remote downloads"
            " is not secure, and that this plugin is not intended for use in"
            " production environments.",
            repo_list,
        )

        # Allow-listed repositories, e.g. ["org/repo", ...].
        self.repo_list: list[str] = repo_list
        # Cache: repo name -> set of subdirs containing adapter_config.json
        # ("." denotes the repo root). Populated lazily in resolve_lora.
        self.adapter_dirs: dict[str, set[str]] = {}

    async def resolve_lora(
        self, base_model_name: str, lora_name: str
    ) -> LoRARequest | None:
        """Resolves potential LoRA requests in a remote repo on HF Hub.

        This is effectively the same behavior as the filesystem resolver, but
        with a snapshot_download on dirs containing an adapter config prior
        to inspecting the cached dir to build a potential LoRA request.
        """
        # If a LoRA name begins with the repository name, it's disambiguated
        maybe_repo = await self._resolve_repo(lora_name)

        # If we haven't inspected this repo before, save available adapter dirs
        if maybe_repo is not None and maybe_repo not in self.adapter_dirs:
            self.adapter_dirs[maybe_repo] = await self._get_adapter_dirs(maybe_repo)

        maybe_subpath = await self._resolve_repo_subpath(lora_name, maybe_repo)

        if maybe_repo is None or maybe_subpath is None:
            return None

        # Download (or reuse the local HF cache for) only the adapter's
        # subpath; run in a thread so the event loop is not blocked.
        repo_path = await asyncio.to_thread(
            snapshot_download,
            repo_id=maybe_repo,
            allow_patterns=f"{maybe_subpath}/*" if maybe_subpath != "." else "*",
        )

        lora_path = os.path.join(repo_path, maybe_subpath)
        maybe_lora_request = await self._get_lora_req_from_path(
            lora_name, lora_path, base_model_name
        )
        return maybe_lora_request

    async def _resolve_repo(self, lora_name: str) -> str | None:
        """Given a fully qualified path to a LoRA with respect to its HF Hub
        repo, match the right repo to potentially download from if one exists.

        Args:
            lora_name: Path to LoRA in HF Hub, e.g., <org>/<repo>/<subpath>,
                match on <org>/<repo> (if it contains an adapter directly) or
                <org>/<repo>/ if it may have one in subdirs.
        """
        for potential_repo in self.repo_list:
            # Accept an exact repo match, or a prefix match only when the
            # next character is "/" (so "org/repo2" does not match "org/repo").
            if lora_name.startswith(potential_repo) and (
                len(lora_name) == len(potential_repo)
                or lora_name[len(potential_repo)] == "/"
            ):
                return potential_repo
        return None

    async def _resolve_repo_subpath(
        self, lora_name: str, maybe_repo: str | None
    ) -> str | None:
        """Given the fully qualified path of the LoRA with respect to the HF
        Repo, get the subpath to download from assuming it's actually got an
        adapter in it.

        Args:
            lora_name: Path to LoRA in HF Hub, e.g., <org>/<repo>/<subpath>
            maybe_repo: Path to the repo to match against if one exists.
        """
        if maybe_repo is None:
            return None
        repo_len = len(maybe_repo)
        if lora_name == maybe_repo or (
            len(lora_name) == repo_len + 1 and lora_name[-1] == "/"
        ):
            # Resolves to the root of the directory
            adapter_dir = "."
        else:
            # It's a subpath; removing trailing slashes if there are any
            adapter_dir = lora_name[repo_len + 1 :].rstrip("/")

        # Only download if the directory actually contains an adapter.
        # NOTE: self.adapter_dirs[maybe_repo] must already be populated
        # (resolve_lora does this before calling here).
        is_adapter = adapter_dir in self.adapter_dirs[maybe_repo]
        return adapter_dir if is_adapter else None

    async def _get_adapter_dirs(self, repo_name: str) -> set[str]:
        """Gets the subpaths within a HF repo that contain an adapter config.

        Args:
            repo_name: Name of the HF hub repo to inspect.
        """
        # list_repo_files is a blocking network call; offload to a thread.
        repo_files = await asyncio.to_thread(HfApi().list_repo_files, repo_id=repo_name)
        adapter_dirs = {
            os.path.dirname(name)
            for name in repo_files
            if name.endswith("adapter_config.json")
        }
        # os.path.dirname("adapter_config.json") is "", so a root-level
        # adapter is recorded under "." explicitly.
        if "adapter_config.json" in repo_files:
            adapter_dirs.add(".")
        return adapter_dirs
|
||||
|
||||
|
||||
def register_hf_hub_resolver():
    """Register the HF Hub LoRA Resolver with vLLM.

    The resolver is only created when ``VLLM_LORA_RESOLVER_HF_REPO_LIST`` is
    set AND ``lora_hf_hub_resolver`` is explicitly listed in ``VLLM_PLUGINS``:
    because it allows remote downloads, it must be opted into directly.
    """
    hf_repo_list = envs.VLLM_LORA_RESOLVER_HF_REPO_LIST
    is_enabled = (
        envs.VLLM_PLUGINS is not None and "lora_hf_hub_resolver" in envs.VLLM_PLUGINS
    )
    if not hf_repo_list:
        return

    if not is_enabled:
        # Fixed a doubled space ("use it  because") in the original message.
        logger.warning(
            "It appears that VLLM_LORA_RESOLVER_HF_REPO_LIST is set, but "
            "lora_hf_hub_resolver is not enabled in VLLM_PLUGINS; you must"
            " enable this resolver directly in VLLM_PLUGINS to use it "
            "because it allows remote downloads."
        )
    else:
        # The env var is a comma-separated allow-list of repos.
        hf_hub_resolver = HfHubResolver(hf_repo_list.split(","))
        LoRAResolverRegistry.register_resolver("Hf Hub Resolver", hf_hub_resolver)
|
||||
Reference in New Issue
Block a user