init src 0.9.2
This commit is contained in:
106
vllm/adapter_commons/models.py
Normal file
106
vllm/adapter_commons/models.py
Normal file
@@ -0,0 +1,106 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Optional, TypeVar
|
||||
|
||||
from torch import nn
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import LRUCache
|
||||
|
||||
# Module-level logger obtained from vLLM's init_logger, named after this module.
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class AdapterModel(ABC):
    """Abstract base class for an adapter carrying an optional identifier.

    Concrete subclasses must implement ``from_local_checkpoint`` as an
    alternate constructor.
    """

    def __init__(self, model_id=None):
        """Store the adapter identifier.

        Args:
            model_id: identifier kept on the instance as ``self.id``;
                defaults to ``None``.
        """
        self.id = model_id

    # The method receives ``cls``, so it must be a classmethod. Without the
    # decorator, ``AdapterModel.from_local_checkpoint(path)`` binds ``path``
    # to ``cls`` and fails with a confusing missing-argument TypeError.
    # ``classmethod`` has to be the outermost decorator so the abstract flag
    # set by ``abstractmethod`` is still visible to ``ABCMeta``.
    @classmethod
    @abstractmethod
    def from_local_checkpoint(cls, model_dir, model_id=None, **kwargs):
        """Alternate constructor that builds an adapter from ``model_dir``.

        Args:
            model_dir: location of the local checkpoint to load from.
            model_id: identifier for the resulting adapter.
            **kwargs: extra, subclass-specific options.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        # Common initialization code
        # Load weights or embeddings from local checkpoint
        raise NotImplementedError("Subclasses must implement this method.")
|
||||
|
||||
|
||||
# Generic type variable; used below as the value type of AdapterLRUCache.
T = TypeVar('T')
|
||||
|
||||
|
||||
class AdapterLRUCache(LRUCache[int, T]):
    """``LRUCache`` keyed by integer adapter id that runs a callback on removal."""

    def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]):
        """Initialise the underlying cache and remember the callback.

        Args:
            capacity: forwarded unchanged to ``LRUCache.__init__``.
            deactivate_fn: callable invoked with the adapter id from
                ``_on_remove``; its return value is ignored.
        """
        super().__init__(capacity)
        self.deactivate_fn = deactivate_fn

    def _on_remove(self, key: int, value: Optional[T]):
        """Log the removal, call ``deactivate_fn(key)``, then defer to the parent.

        NOTE(review): presumably invoked by ``LRUCache`` whenever an entry
        leaves the cache -- confirm against ``vllm.utils.LRUCache``.

        Returns:
            Whatever the parent class's ``_on_remove`` returns.
        """
        logger.debug("Removing adapter int id: %d", key)
        self.deactivate_fn(key)
        parent_result = super()._on_remove(key, value)
        return parent_result
|
||||
|
||||
|
||||
class AdapterModelManager(ABC):
    """Abstract interface for managing adapters attached to a model.

    The base class only stores bookkeeping state and implements ``__len__``;
    every other operation is abstract and must be supplied by a subclass.
    """

    def __init__(self, model: nn.Module):
        """Create an AdapterModelManager and adapter for a given model.

        Args:
            model: the model to be adapted.
        """
        self.model: nn.Module = model
        # Adapters known to this manager, keyed by integer adapter id.
        self._registered_adapters: dict[int, Any] = {}
        # A dict (values always None) rather than a set, for compatibility
        # with LRUCache.
        self._active_adapters: dict[int, None] = {}
        self.adapter_type = "Adapter"
        self._last_mapping = None

    def __len__(self) -> int:
        """Return the number of registered adapters."""
        registered = self._registered_adapters
        return len(registered)

    @property
    @abstractmethod
    def adapter_slots(self) -> int:
        """Abstract read-only integer property; subclasses define its value."""
        raise NotImplementedError

    @property
    @abstractmethod
    def capacity(self) -> int:
        """Abstract read-only integer property; subclasses define its value."""
        raise NotImplementedError

    @abstractmethod
    def activate_adapter(self, adapter_id: int) -> bool:
        """Subclass hook taking an adapter id and returning a bool.

        The meaning of the returned flag is defined by the implementation.
        """
        raise NotImplementedError

    @abstractmethod
    def deactivate_adapter(self, adapter_id: int) -> bool:
        """Subclass hook taking an adapter id and returning a bool.

        The meaning of the returned flag is defined by the implementation.
        """
        raise NotImplementedError

    @abstractmethod
    def add_adapter(self, adapter: Any) -> bool:
        """Subclass hook taking an adapter object and returning a bool.

        The meaning of the returned flag is defined by the implementation.
        """
        raise NotImplementedError

    @abstractmethod
    def set_adapter_mapping(self, mapping: Any) -> None:
        """Subclass hook taking a mapping object; returns nothing."""
        raise NotImplementedError

    @abstractmethod
    def remove_adapter(self, adapter_id: int) -> bool:
        """Subclass hook taking an adapter id and returning a bool.

        The meaning of the returned flag is defined by the implementation.
        """
        raise NotImplementedError

    @abstractmethod
    def remove_all_adapters(self) -> None:
        """Subclass hook taking no arguments; returns nothing."""
        raise NotImplementedError

    @abstractmethod
    def get_adapter(self, adapter_id: int) -> Optional[Any]:
        """Subclass hook taking an adapter id; may return ``None``."""
        raise NotImplementedError

    @abstractmethod
    def list_adapters(self) -> dict[int, Any]:
        """Subclass hook returning a dict keyed by integer id."""
        raise NotImplementedError

    @abstractmethod
    def pin_adapter(self, adapter_id: int) -> bool:
        """Subclass hook taking an adapter id and returning a bool.

        The meaning of the returned flag is defined by the implementation.
        """
        raise NotImplementedError
|
||||
Reference in New Issue
Block a user