commit 5d2e7edf78
2025-08-13 19:46:19 +08:00
1232 changed files with 361215 additions and 0 deletions

adapter_commons/layers.py (Normal file, 16 lines)
@@ -0,0 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass


@dataclass
class AdapterMapping:
    # Per every token in input_ids:
    index_mapping: tuple[int, ...]
    # Per sampled token:
    prompt_mapping: tuple[int, ...]

    def __post_init__(self):
        self.index_mapping = tuple(self.index_mapping)
        self.prompt_mapping = tuple(self.prompt_mapping)
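
Since __post_init__ re-coerces both fields, callers may pass any sequence, not just tuples. A small illustrative sketch (the slot indices below are made up):

# Three input tokens mapped to adapter slots 1, 1, 2; two sampled tokens
# mapped to slots 1 and 2.
mapping = AdapterMapping(index_mapping=[1, 1, 2], prompt_mapping=[1, 2])
assert mapping.index_mapping == (1, 1, 2)  # lists are coerced to tuples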

adapter_commons/models.py (Normal file, 106 lines)
@@ -0,0 +1,106 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from abc import ABC, abstractmethod
from typing import Any, Callable, Optional, TypeVar

from torch import nn

from vllm.logger import init_logger
from vllm.utils import LRUCache

logger = init_logger(__name__)


class AdapterModel(ABC):

    def __init__(self, model_id=None):
        self.id = model_id

    @classmethod
    @abstractmethod
    def from_local_checkpoint(cls, model_dir, model_id=None, **kwargs):
        # Common initialization code.
        # Load weights or embeddings from a local checkpoint.
        raise NotImplementedError("Subclasses must implement this method.")


T = TypeVar('T')


class AdapterLRUCache(LRUCache[int, T]):

    def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]):
        super().__init__(capacity)
        self.deactivate_fn = deactivate_fn

    def _on_remove(self, key: int, value: Optional[T]):
        logger.debug("Removing adapter int id: %d", key)
        self.deactivate_fn(key)
        return super()._on_remove(key, value)


class AdapterModelManager(ABC):

    def __init__(
        self,
        model: nn.Module,
    ):
        """Create an AdapterModelManager and adapter for a given model.

        Args:
            model: the model to be adapted.
        """
        self.model: nn.Module = model
        self._registered_adapters: dict[int, Any] = {}
        # Dict instead of a Set for compatibility with LRUCache.
        self._active_adapters: dict[int, None] = {}
        self.adapter_type = 'Adapter'
        self._last_mapping = None

    def __len__(self) -> int:
        return len(self._registered_adapters)

    @property
    @abstractmethod
    def adapter_slots(self) -> int:
        raise NotImplementedError

    @property
    @abstractmethod
    def capacity(self) -> int:
        raise NotImplementedError

    @abstractmethod
    def activate_adapter(self, adapter_id: int) -> bool:
        raise NotImplementedError

    @abstractmethod
    def deactivate_adapter(self, adapter_id: int) -> bool:
        raise NotImplementedError

    @abstractmethod
    def add_adapter(self, adapter: Any) -> bool:
        raise NotImplementedError

    @abstractmethod
    def set_adapter_mapping(self, mapping: Any) -> None:
        raise NotImplementedError

    @abstractmethod
    def remove_adapter(self, adapter_id: int) -> bool:
        raise NotImplementedError

    @abstractmethod
    def remove_all_adapters(self) -> None:
        raise NotImplementedError

    @abstractmethod
    def get_adapter(self, adapter_id: int) -> Optional[Any]:
        raise NotImplementedError

    @abstractmethod
    def list_adapters(self) -> dict[int, Any]:
        raise NotImplementedError

    @abstractmethod
    def pin_adapter(self, adapter_id: int) -> bool:
        raise NotImplementedError
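
The cache above ties eviction to deactivation: when an entry falls out of the LRU window, _on_remove first invokes the supplied callback and then defers to the base class. A minimal sketch of that behavior, assuming vllm.utils.LRUCache exposes a put(key, value) method (the ids and labels below are made up):

def _deactivate(adapter_id: int) -> bool:
    # Stand-in for a real deactivation hook, e.g. freeing a GPU slot.
    print(f"deactivating adapter {adapter_id}")
    return True

cache: AdapterLRUCache[str] = AdapterLRUCache(capacity=2,
                                              deactivate_fn=_deactivate)
cache.put(1, "adapter-1")
cache.put(2, "adapter-2")
cache.put(3, "adapter-3")  # over capacity: the least recently used entry (1)
                           # is removed, logging and deactivating adapter 1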

adapter_commons/request.py (Normal file, 26 lines)
@@ -0,0 +1,26 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from abc import ABC, abstractmethod


class AdapterRequest(ABC):
    """
    Base class for adapter requests.
    """

    @property
    @abstractmethod
    def adapter_id(self) -> int:
        raise NotImplementedError

    def __post_init__(self) -> None:
        if self.adapter_id < 1:
            raise ValueError(f"id must be > 0, got {self.adapter_id}")

    def __eq__(self, value: object) -> bool:
        return isinstance(
            value, self.__class__) and self.adapter_id == value.adapter_id

    def __hash__(self) -> int:
        return hash(self.adapter_id)
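
Concrete request types are expected to behave like dataclasses, which is why the base class defines __post_init__. A hypothetical subclass shows how the validation and identity semantics are inherited; note eq=False, which stops the dataclass machinery from overriding the adapter_id-based __eq__/__hash__ above:

from dataclasses import dataclass


@dataclass(eq=False)
class DummyAdapterRequest(AdapterRequest):  # hypothetical example type
    name: str
    dummy_id: int

    @property
    def adapter_id(self) -> int:
        return self.dummy_id


assert DummyAdapterRequest("a", 1) == DummyAdapterRequest("b", 1)  # id only
try:
    DummyAdapterRequest("bad", 0)  # inherited __post_init__ rejects id < 1
except ValueError as err:
    print(err)  # id must be > 0, got 0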

adapter_commons/utils.py (Normal file, 93 lines)
@@ -0,0 +1,93 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Any, Callable, Optional


## model functions
def deactivate_adapter(adapter_id: int, active_adapters: dict[int, None],
                       deactivate_func: Callable) -> bool:
    if adapter_id in active_adapters:
        deactivate_func(adapter_id)
        active_adapters.pop(adapter_id)
        return True
    return False


def add_adapter(adapter: Any, registered_adapters: dict[int, Any],
                capacity: int, add_func: Callable) -> bool:
    if adapter.id not in registered_adapters:
        if len(registered_adapters) >= capacity:
            raise RuntimeError('No free adapter slots.')
        add_func(adapter)
        registered_adapters[adapter.id] = adapter
        return True
    return False


def set_adapter_mapping(mapping: Any, last_mapping: Any,
                        set_mapping_func: Callable) -> Any:
    if last_mapping != mapping:
        set_mapping_func(mapping)
        return mapping
    return last_mapping


def remove_adapter(adapter_id: int, registered_adapters: dict[int, Any],
                   deactivate_func: Callable) -> bool:
    deactivate_func(adapter_id)
    return bool(registered_adapters.pop(adapter_id, None))


def list_adapters(registered_adapters: dict[int, Any]) -> dict[int, Any]:
    return dict(registered_adapters)


def get_adapter(adapter_id: int,
                registered_adapters: dict[int, Any]) -> Optional[Any]:
    return registered_adapters.get(adapter_id)


## worker functions
def set_active_adapters_worker(requests: set[Any], mapping: Optional[Any],
                               apply_adapters_func,
                               set_adapter_mapping_func) -> None:
    apply_adapters_func(requests)
    set_adapter_mapping_func(mapping)


def add_adapter_worker(adapter_request: Any, list_adapters_func,
                       load_adapter_func, add_adapter_func,
                       activate_adapter_func) -> bool:
    if adapter_request.adapter_id in list_adapters_func():
        return False
    loaded_adapter = load_adapter_func(adapter_request)
    loaded = add_adapter_func(loaded_adapter)
    activate_adapter_func(loaded_adapter.id)
    return loaded


def apply_adapters_worker(adapter_requests: set[Any], list_adapters_func,
                          adapter_slots: int, remove_adapter_func,
                          add_adapter_func) -> None:
    models_that_exist = list_adapters_func()
    models_map = {
        adapter_request.adapter_id: adapter_request
        for adapter_request in adapter_requests if adapter_request
    }
    if len(models_map) > adapter_slots:
        raise RuntimeError(
            f"Number of requested models ({len(models_map)}) is greater "
            f"than the number of GPU model slots "
            f"({adapter_slots}).")
    new_models = set(models_map)
    models_to_add = new_models - models_that_exist
    models_to_remove = models_that_exist - new_models
    for adapter_id in models_to_remove:
        remove_adapter_func(adapter_id)
    for adapter_id in models_to_add:
        add_adapter_func(models_map[adapter_id])


def list_adapters_worker(adapter_manager_list_adapters_func) -> set[int]:
    return set(adapter_manager_list_adapters_func())
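
The reconciliation in apply_adapters_worker is easiest to trace with stub callables; everything below is illustrative. Adapters that are no longer requested are removed, ones already present are left alone, and new ones are added:

active = {1, 2}  # adapter ids currently loaded on this worker


class StubRequest:  # hypothetical stand-in for a real adapter request
    def __init__(self, adapter_id: int):
        self.adapter_id = adapter_id


apply_adapters_worker(
    adapter_requests={StubRequest(2), StubRequest(3)},
    list_adapters_func=lambda: set(active),
    adapter_slots=4,
    remove_adapter_func=lambda aid: active.discard(aid),      # removes 1
    add_adapter_func=lambda req: active.add(req.adapter_id),  # adds 3
)
assert active == {2, 3}  # adapter 2 was untouched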

adapter_commons/worker_manager.py (Normal file, 39 lines)
@@ -0,0 +1,39 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from abc import ABC, abstractmethod
from typing import Any, Optional

import torch


class AbstractWorkerManager(ABC):

    def __init__(self, device: torch.device):
        self.device = device

    @property
    @abstractmethod
    def is_enabled(self) -> bool:
        raise NotImplementedError

    @abstractmethod
    def set_active_adapters(self, requests: set[Any],
                            mapping: Optional[Any]) -> None:
        raise NotImplementedError

    @abstractmethod
    def add_adapter(self, adapter_request: Any) -> bool:
        raise NotImplementedError

    @abstractmethod
    def remove_adapter(self, adapter_id: int) -> bool:
        raise NotImplementedError

    @abstractmethod
    def remove_all_adapters(self) -> None:
        raise NotImplementedError

    @abstractmethod
    def list_adapters(self) -> set[int]:
        raise NotImplementedError
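
A minimal sketch of a concrete manager wiring these hooks together; the class below is hypothetical and only tracks adapter ids in memory, standing in for an implementation that would actually move weights onto self.device:

class InMemoryWorkerManager(AbstractWorkerManager):  # hypothetical example
    """Tracks adapter ids without loading anything onto the device."""

    def __init__(self, device: torch.device):
        super().__init__(device)
        self._ids: set[int] = set()

    @property
    def is_enabled(self) -> bool:
        return True

    def set_active_adapters(self, requests: set[Any],
                            mapping: Optional[Any]) -> None:
        self._ids = {r.adapter_id for r in requests if r}

    def add_adapter(self, adapter_request: Any) -> bool:
        if adapter_request.adapter_id in self._ids:
            return False
        self._ids.add(adapter_request.adapter_id)
        return True

    def remove_adapter(self, adapter_id: int) -> bool:
        present = adapter_id in self._ids
        self._ids.discard(adapter_id)
        return present

    def remove_all_adapters(self) -> None:
        self._ids.clear()

    def list_adapters(self) -> set[int]:
        return set(self._ids)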