forked from EngineX-Hygon/enginex-hygon-vllm
init src 0.9.2
vllm/prompt_adapter/__init__.py (new file, 0 lines)
vllm/prompt_adapter/layers.py (new file, 83 lines)
@@ -0,0 +1,83 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from typing import Optional

import torch
from torch import nn

from vllm.adapter_commons.layers import AdapterMapping
from vllm.config import PromptAdapterConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)


@dataclass
class PromptAdapterMapping(AdapterMapping):
    pass


class VocabParallelEmbeddingWithPromptAdapter(nn.Module):

    def __init__(self, base_layer: VocabParallelEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.emb_layer = self.base_layer
        if 'LoRA' in base_layer.__class__.__name__:
            self.emb_layer = self.base_layer.base_layer

    def create_prompt_adapter_weights(
            self, prompt_adapter_config: PromptAdapterConfig):
        self.embeddings_tensors = torch.zeros(
            (
                prompt_adapter_config.max_prompt_adapters,
                prompt_adapter_config.max_prompt_adapter_token,
                self.emb_layer.embedding_dim,
            ),
            dtype=self.emb_layer.weight.dtype,
            device=self.emb_layer.weight.device,
        )
        self.adapter_lengths = torch.zeros(
            prompt_adapter_config.max_prompt_adapters,
            dtype=torch.long,
            device=self.emb_layer.weight.device)

        self.indices_gpu: torch.Tensor
        self.embedding_indices_gpu: torch.Tensor

    def reset_prompt_adapter(self, index: int):
        self.embeddings_tensors[index] = 0

    def set_prompt_adapter(
        self,
        index: int,
        adapter_model: Optional[torch.Tensor],
    ):
        self.reset_prompt_adapter(index)
        if adapter_model is not None:
            length = adapter_model.shape[0]
            self.embeddings_tensors[index, :length] = adapter_model
            self.adapter_lengths[index] = length

    def set_mapping(
        self,
        prompt_indices: torch.Tensor,
        prompt_embedding_indices: torch.Tensor,
    ):
        self.indices_gpu = prompt_indices.to(
            device=self.emb_layer.weight.device)
        self.embedding_indices_gpu = prompt_embedding_indices.to(
            device=self.emb_layer.weight.device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden_states = self.base_layer(x)
        if self.embedding_indices_gpu.ndim > 1:
            valid_mask = self.indices_gpu != -1
            gathered_embeddings = self.embeddings_tensors[
                self.embedding_indices_gpu[:, 0],
                self.embedding_indices_gpu[:, 1]]

            # Update hidden states
            hidden_states[valid_mask] = gathered_embeddings
        return hidden_states
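
Example (illustrative, not part of the commit): a self-contained sketch of the gather-and-mask step in forward() above, with dummy tensors standing in for the layer's state.

import torch

# 2 adapter slots, up to 4 virtual tokens each, hidden size 8.
embeddings_tensors = torch.randn(2, 4, 8)
hidden_states = torch.zeros(6, 8)  # base embeddings for 6 token positions

# Per-position adapter slot; -1 means "no prompt adapter here".
indices_gpu = torch.tensor([0, 0, -1, -1, 1, 1])
# (adapter slot, virtual-token offset) for each adapter-covered position.
embedding_indices_gpu = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]])

valid_mask = indices_gpu != -1
gathered = embeddings_tensors[embedding_indices_gpu[:, 0],
                              embedding_indices_gpu[:, 1]]
hidden_states[valid_mask] = gathered  # splice adapter embeddings in place
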
vllm/prompt_adapter/models.py (new file, 358 lines)
@@ -0,0 +1,358 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import logging
import math
from typing import Any, Callable, Dict, List, Optional, Tuple, Type

import torch
from torch import nn

from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel,
                                         AdapterModelManager)
from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter,
                                        get_adapter, list_adapters,
                                        remove_adapter, set_adapter_mapping)
from vllm.config import PromptAdapterConfig
from vllm.prompt_adapter.layers import (
    VocabParallelEmbeddingWithPromptAdapter)  # yapf: disable
from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.utils import load_peft_weights

logger = logging.getLogger(__name__)

_GLOBAL_PROMPT_ADAPTER_ID = 0


def get_prompt_adapter_id():
    global _GLOBAL_PROMPT_ADAPTER_ID
    _GLOBAL_PROMPT_ADAPTER_ID += 1
    return _GLOBAL_PROMPT_ADAPTER_ID


def convert_to_embedding_indices(indices):
    embedding_indices = []
    count = 0

    for value in indices:
        if value == -1:
            count = 0
        else:
            embedding_indices.append([value, count])
            count += 1

    return torch.tensor(embedding_indices)


def convert_mapping(
    mapping: PromptAdapterMapping,
    prompt_adapter_index_to_id: List[Optional[int]],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Converts PromptAdapterMapping to index tensors.

    Args:
        mapping: PromptAdapterMapping mapping rows in a
            batch to PromptAdapter ids.
        prompt_adapter_index_to_id: List mapping PromptAdapter
            ids to PromptAdapter indices.

    Returns:
        pa_indices: Tensor of shape [batch_size] mapping batch rows to
            PromptAdapter indices.
        pa_embedding_mapping: Tensor of (adapter index, virtual-token
            offset) pairs for the rows covered by a PromptAdapter.
    """
    id_to_index = {
        id_: idx
        for idx, id_ in enumerate(prompt_adapter_index_to_id)
        if id_ is not None
    }
    pa_indices = [
        id_to_index.get(id_, -1) if id_ > 0 else -1
        for id_ in mapping.index_mapping
    ]

    pa_embedding_mapping = convert_to_embedding_indices(pa_indices)
    pa_indices = torch.tensor(pa_indices)
    return pa_indices, pa_embedding_mapping
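
Example (illustrative, not part of the commit): given per-row adapter slot indices, the counter resets at every -1, producing (slot, virtual-token offset) pairs.

from vllm.prompt_adapter.models import convert_to_embedding_indices

print(convert_to_embedding_indices([0, 0, -1, 1, 1, 1]))
# tensor([[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]])
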

class PromptAdapterModel(AdapterModel):

    def __init__(self,
                 prompt_adapter_id=None,
                 num_virtual_tokens=None,
                 prompt_embedding=None) -> None:
        self.id = prompt_adapter_id
        self.prompt_embedding = prompt_embedding
        self.num_virtual_tokens = num_virtual_tokens

    @classmethod
    def from_local_checkpoint(
        cls,
        adapter_model_path: str,
        prompt_adapter_id: int,
        num_virtual_tokens: int,
        config: PromptAdapterConfig,
        device: str = "cuda",
    ) -> "PromptAdapterModel":
        if num_virtual_tokens > config.max_prompt_adapter_token:
            raise ValueError(
                f'num_virtual_tokens ({num_virtual_tokens}) should be <= '
                f'max_prompt_adapter_token '
                f'({config.max_prompt_adapter_token})')

        adapters_weights = load_peft_weights(adapter_model_path, device)
        prompt_embedding = adapters_weights["prompt_embeddings"].to(
            config.prompt_adapter_dtype)

        return cls(prompt_adapter_id, num_virtual_tokens, prompt_embedding)
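
Example (illustrative, not part of the commit): loading a PEFT prompt-tuning checkpoint from disk, assuming PromptAdapterConfig accepts these fields as constructor arguments; the path is a placeholder.

import torch

from vllm.config import PromptAdapterConfig
from vllm.prompt_adapter.models import PromptAdapterModel

config = PromptAdapterConfig(max_prompt_adapters=4,
                             max_prompt_adapter_token=32,
                             prompt_adapter_dtype=torch.float16)
adapter = PromptAdapterModel.from_local_checkpoint(
    "/path/to/peft_checkpoint",  # dir containing adapter_model.safetensors
    prompt_adapter_id=1,
    num_virtual_tokens=16,
    config=config,
    device="cuda",
)
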

class PromptAdapterModelManager(AdapterModelManager):
    """A manager that manages multiple Prompt Adapter models."""

    def __init__(
        self,
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        prompt_adapter_config: PromptAdapterConfig,
    ):
        """Create a PromptAdapterModelManager for a given model.

        Args:
            model: the model to be adapted.
            max_num_seqs: the maximum number of sequences the model can run
                in a single batch.
            max_num_batched_tokens: the maximum number of tokens the model
                can run in a single batch.
            prompt_adapter_config: the PromptAdapter config.
        """
        self.model: nn.Module = model
        # Dict instead of a Set for compatibility with LRUCache.
        self.prompt_adapter_index_to_id: List[
            Optional[int]] = [None] * self.prompt_adapter_slots
        self.max_num_seqs = max_num_seqs
        self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
        self.prompt_adapter_config = prompt_adapter_config
        self.model.prompt_adapter_manager = self
        self.adapter_type = 'PromptAdapter'

        self.base_indices = torch.tensor([-1])
        self.base_embedding_indices = torch.tensor([])

        self.modules: Dict[str, nn.Module] = {}
        self._create_prompt_adapter_modules()
        self._last_mapping: Optional[PromptAdapterMapping] = None

    @property
    def prompt_adapter_slots(self) -> int:
        return self.prompt_adapter_config.max_prompt_adapters

    @property
    def adapter_slots(self) -> int:
        return self.prompt_adapter_slots

    @property
    def capacity(self) -> int:
        return self.prompt_adapter_config.max_cpu_prompt_adapters

    def activate_adapter(
        self,
        prompt_adapter_id: int,
    ) -> bool:
        """Move a PromptAdapter into a GPU buffer
        to be used in the forward pass."""
        if prompt_adapter_id in self._active_adapters:
            return False
        first_free_slot = next(
            ((i, prompt_adapter_id) for i, prompt_adapter_id in enumerate(
                self.prompt_adapter_index_to_id) if prompt_adapter_id is None),
            None)
        if first_free_slot is None:
            raise ValueError("No free prompt_adapter slots")
        index, _ = first_free_slot
        self._active_adapters[prompt_adapter_id] = None
        prompt_adapter_model = self._registered_adapters[prompt_adapter_id]
        logger.debug("Activating prompt_adapter. int id: %d, slot index: %d",
                     prompt_adapter_model.id, index)
        self.prompt_adapter_index_to_id[index] = prompt_adapter_model.id
        for _, v in self.modules.items():
            v.set_prompt_adapter(index, prompt_adapter_model.prompt_embedding)
        return True

    def _deactivate_adapter(self, prompt_adapter_id: int):
        try:
            index = self.prompt_adapter_index_to_id.index(prompt_adapter_id)
            self.prompt_adapter_index_to_id[index] = None
            for _, v in self.modules.items():
                v.reset_prompt_adapter(index)
        except ValueError:
            pass

    def _add_adapter(self, prompt_adapter: PromptAdapterModel):
        self._registered_adapters[prompt_adapter.id] = prompt_adapter

    def _set_adapter_mapping(self, mapping: PromptAdapterMapping) -> None:
        base_indices, base_embedding_indices = convert_mapping(
            mapping, self.prompt_adapter_index_to_id)
        for v in self.modules.values():
            v.set_mapping(base_indices, base_embedding_indices)

    def _create_prompt_adapter_modules(self):
        for module_name, module in self.model.named_modules(
                remove_duplicate=False):
            if "VocabParallel" in module.__class__.__name__:
                new_module = VocabParallelEmbeddingWithPromptAdapter(module)
                new_module.create_prompt_adapter_weights(
                    self.prompt_adapter_config)
                replaced_module = self.replace_submodule(
                    self.model, module_name, new_module)
                self.register_module(module.__class__.__name__,
                                     replaced_module)
                replaced_module.set_mapping(self.base_indices,
                                            self.base_embedding_indices)
                break

    def replace_submodule(self, model: nn.Module, module_name: str,
                          new_module: nn.Module) -> nn.Module:
        """Replace a submodule in a model with a new module."""
        parent = model.get_submodule(".".join(module_name.split(".")[:-1]))
        target_name = module_name.split(".")[-1]
        setattr(parent, target_name, new_module)
        return new_module

    def register_module(self, module_name: str, module: nn.Module):
        self.modules[module_name] = module

    def pin_adapter(self, prompt_adapter_id: int) -> bool:
        """Pin a PromptAdapterModel in the manager cache."""
        raise NotImplementedError(
            "Pinning is not supported in PromptAdapterModelManager. "
            "Use LRUCachePromptAdapterModelManager for pinning"
        )  # type: ignore

    def remove_all_adapters(self):
        """Remove all PromptAdapterModels from the manager."""
        self._registered_adapters.clear()
        self.prompt_adapter_index_to_id = [None] * self.prompt_adapter_slots
        self._active_adapters.clear()

    def deactivate_adapter(self, adapter_id: int) -> bool:
        return deactivate_adapter(adapter_id, self._active_adapters,
                                  self._deactivate_adapter)

    def add_adapter(self, adapter: PromptAdapterModel) -> bool:
        return add_adapter(adapter, self._registered_adapters, self.capacity,
                           self._add_adapter)

    def set_adapter_mapping(self, mapping: PromptAdapterMapping) -> None:
        self._last_mapping = set_adapter_mapping(mapping, self._last_mapping,
                                                 self._set_adapter_mapping)

    def remove_adapter(self, adapter_id: int) -> bool:
        return remove_adapter(adapter_id, self._registered_adapters,
                              self.deactivate_adapter)

    def list_adapters(self) -> Dict[int, Any]:
        return list_adapters(self._registered_adapters)

    def get_adapter(self, adapter_id: int) -> Optional[Any]:
        return get_adapter(adapter_id, self._registered_adapters)
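
Example (illustrative, not part of the commit): a hedged sketch of the manager lifecycle, reusing `config` and `adapter` from the previous example; `model` is assumed to contain a VocabParallelEmbedding submodule, which _create_prompt_adapter_modules wraps, and the (index_mapping, prompt_mapping) fields are assumed from AdapterMapping.

from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.models import PromptAdapterModelManager

manager = PromptAdapterModelManager(
    model,  # assumed nn.Module with a VocabParallelEmbedding submodule
    max_num_seqs=8,
    max_num_batched_tokens=2048,
    prompt_adapter_config=config,
)
manager.add_adapter(adapter)          # register (CPU-side bookkeeping)
manager.activate_adapter(adapter.id)  # copy embeddings into a GPU slot
# Map 16 batch rows to this adapter; rows with id 0 get no adapter.
mapping = PromptAdapterMapping(index_mapping=(adapter.id,) * 16,
                               prompt_mapping=())
manager.set_adapter_mapping(mapping)
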

class PromptAdapterLRUCache(AdapterLRUCache[PromptAdapterModel]):

    def __init__(self, capacity: int,
                 deactivate_prompt_adapter_fn: Callable[[int], bool]):
        super().__init__(capacity, deactivate_prompt_adapter_fn)


class LRUCachePromptAdapterModelManager(PromptAdapterModelManager):
    """A model manager that manages multiple prompt_adapters with LRU cache."""

    def __init__(
        self,
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        prompt_adapter_config: PromptAdapterConfig,
    ):
        self.prompt_adapter_config = prompt_adapter_config
        super().__init__(model, max_num_seqs, max_num_batched_tokens,
                         prompt_adapter_config)
        self._registered_adapters = PromptAdapterLRUCache(
            self.capacity, self.deactivate_adapter)
        self._active_adapters = PromptAdapterLRUCache(
            self.prompt_adapter_slots, self._deactivate_adapter)

    def list_adapters(self) -> Dict[int, PromptAdapterModel]:
        """List all registered PromptAdapterModels."""
        return dict(self._registered_adapters.cache)

    def add_adapter(self, prompt_adapter: PromptAdapterModel) -> bool:
        """Add a PromptAdapterModel to the manager."""
        if prompt_adapter.id not in self._registered_adapters:
            self._add_adapter(prompt_adapter)
            was_added = True
        else:
            # We always touch to update the LRU cache order
            self._registered_adapters.touch(prompt_adapter.id)
            was_added = False
        return was_added

    def activate_adapter(
        self,
        prompt_adapter_id: int,
    ) -> bool:
        if prompt_adapter_id not in self._active_adapters and len(
                self._active_adapters) >= self.prompt_adapter_slots:
            self._active_adapters.remove_oldest()
        result = super().activate_adapter(prompt_adapter_id)
        # We always touch to update the LRU cache order
        self._active_adapters.touch(prompt_adapter_id)
        return result

    def remove_oldest_adapter(self) -> bool:
        if len(self._registered_adapters) > 0:
            self._registered_adapters.remove_oldest()
            return True
        return False

    def pin_adapter(self, prompt_adapter_id: int) -> bool:
        """Pin a PromptAdapterModel in the manager cache."""
        self._pin_prompt_adapter_in_cpu_cache(prompt_adapter_id)
        self._pin_prompt_adapter_in_gpu_cache(prompt_adapter_id)
        return True

    def _pin_prompt_adapter_in_cpu_cache(self, prompt_adapter_id: int):
        try:
            self._registered_adapters.pin(prompt_adapter_id)
        except ValueError as err:
            raise ValueError(
                "Pinning failed. "
                f"Prompt Adapter {prompt_adapter_id} is not registered."
            ) from err

    def _pin_prompt_adapter_in_gpu_cache(self, prompt_adapter_id: int):
        if prompt_adapter_id not in self._active_adapters:
            # move adapter to gpu if not already active
            self.activate_adapter(prompt_adapter_id)
        self._active_adapters.pin(prompt_adapter_id)


def create_prompt_adapter_manager(
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        prompt_adapter_config: PromptAdapterConfig,
        prompt_adapter_manager_cls: Type[
            PromptAdapterModelManager] = PromptAdapterModelManager,
        **kwargs) -> PromptAdapterModelManager:
    """Create a PromptAdapterModelManager for a given model."""
    prompt_adapter_manager = prompt_adapter_manager_cls(
        model=model,
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        prompt_adapter_config=prompt_adapter_config,
        **kwargs)
    return prompt_adapter_manager
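
Example (illustrative, not part of the commit): the factory defaults to the non-evicting manager; passing the LRU class makes capacity overflows evict instead of raise. `model` and `config` are placeholders as above.

from vllm.prompt_adapter.models import (LRUCachePromptAdapterModelManager,
                                        create_prompt_adapter_manager)

manager = create_prompt_adapter_manager(
    model,
    max_num_seqs=8,
    max_num_batched_tokens=2048,
    prompt_adapter_config=config,
    prompt_adapter_manager_cls=LRUCachePromptAdapterModelManager,
)
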
vllm/prompt_adapter/request.py (new file, 37 lines)
@@ -0,0 +1,37 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import msgspec

from vllm.adapter_commons.request import AdapterRequest


class PromptAdapterRequest(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True,  # type: ignore[call-arg]
        frozen=True):  # type: ignore[call-arg]
    """
    Request for a Prompt adapter.
    """
    __metaclass__ = AdapterRequest

    prompt_adapter_name: str
    prompt_adapter_id: int
    prompt_adapter_local_path: str
    prompt_adapter_num_virtual_tokens: int

    def __hash__(self):
        return super().__hash__()

    @property
    def adapter_id(self):
        return self.prompt_adapter_id

    @property
    def name(self):
        return self.prompt_adapter_name

    @property
    def local_path(self):
        return self.prompt_adapter_local_path
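
Example (illustrative, not part of the commit): requests are frozen msgspec structs, so they are immutable and hashable; field values below are placeholders.

from vllm.prompt_adapter.request import PromptAdapterRequest

request = PromptAdapterRequest(
    prompt_adapter_name="my_prompt_tune",
    prompt_adapter_id=1,
    prompt_adapter_local_path="/path/to/peft_checkpoint",
    prompt_adapter_num_virtual_tokens=16,
)
assert request.adapter_id == 1  # generic alias over prompt_adapter_id
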
vllm/prompt_adapter/utils.py (new file, 98 lines)
@@ -0,0 +1,98 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Code borrowed from: https://github.com/huggingface/peft/blob/v0.12.0/src/peft/utils/save_and_load.py#L420

import os
from typing import Optional

import torch
from huggingface_hub import file_exists, hf_hub_download
from huggingface_hub.utils import EntryNotFoundError
from safetensors.torch import load_file as safe_load_file

from vllm.platforms import current_platform

WEIGHTS_NAME = "adapter_model.bin"
SAFETENSORS_WEIGHTS_NAME = "adapter_model.safetensors"


# Get the current device name based on the available devices
def infer_device() -> str:
    if current_platform.is_cuda_alike():
        return "cuda"
    return "cpu"


def load_peft_weights(model_id: str,
                      device: Optional[str] = None,
                      **hf_hub_download_kwargs) -> dict:
    r"""
    A helper method to load the PEFT weights from the HuggingFace Hub or
    locally.

    Args:
        model_id (`str`):
            The local path to the adapter weights or the name of the adapter
            to load from the HuggingFace Hub.
        device (`str`):
            The device to load the weights onto.
        hf_hub_download_kwargs (`dict`):
            Additional arguments to pass to the `hf_hub_download` method when
            loading from the HuggingFace Hub.
    """
    path = (os.path.join(model_id, hf_hub_download_kwargs["subfolder"])
            if hf_hub_download_kwargs.get("subfolder") is not None
            else model_id)

    if device is None:
        device = infer_device()

    if os.path.exists(os.path.join(path, SAFETENSORS_WEIGHTS_NAME)):
        filename = os.path.join(path, SAFETENSORS_WEIGHTS_NAME)
        use_safetensors = True
    elif os.path.exists(os.path.join(path, WEIGHTS_NAME)):
        filename = os.path.join(path, WEIGHTS_NAME)
        use_safetensors = False
    else:
        token = hf_hub_download_kwargs.get("token")
        if token is None:
            token = hf_hub_download_kwargs.get("use_auth_token")

        hub_filename = (os.path.join(hf_hub_download_kwargs["subfolder"],
                                     SAFETENSORS_WEIGHTS_NAME)
                        if hf_hub_download_kwargs.get("subfolder") is not None
                        else SAFETENSORS_WEIGHTS_NAME)
        has_remote_safetensors_file = file_exists(
            repo_id=model_id,
            filename=hub_filename,
            revision=hf_hub_download_kwargs.get("revision"),
            repo_type=hf_hub_download_kwargs.get("repo_type"),
            token=token,
        )
        use_safetensors = has_remote_safetensors_file

        if has_remote_safetensors_file:
            # Priority 1: load safetensors weights
            filename = hf_hub_download(
                model_id,
                SAFETENSORS_WEIGHTS_NAME,
                **hf_hub_download_kwargs,
            )
        else:
            try:
                filename = hf_hub_download(model_id, WEIGHTS_NAME,
                                           **hf_hub_download_kwargs)
            except EntryNotFoundError:
                raise ValueError(  # noqa: B904
                    f"Can't find weights for {model_id} locally or on the "
                    f"Hugging Face Hub. Please check that the file "
                    f"{WEIGHTS_NAME} or {SAFETENSORS_WEIGHTS_NAME} is "
                    f"present at {model_id}.")

    if use_safetensors:
        adapters_weights = safe_load_file(filename, device=device)
    else:
        adapters_weights = torch.load(filename,
                                      map_location=torch.device(device),
                                      weights_only=True)

    return adapters_weights
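
Example (illustrative, not part of the commit): for a local PEFT prompt-tuning checkpoint, the returned dict carries the virtual-token embeddings under the "prompt_embeddings" key that from_local_checkpoint reads above; the path is a placeholder.

from vllm.prompt_adapter.utils import load_peft_weights

weights = load_peft_weights("/path/to/peft_checkpoint", device="cpu")
embedding = weights["prompt_embeddings"]
print(embedding.shape)  # [num_virtual_tokens, hidden_size]
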
vllm/prompt_adapter/worker_manager.py (new file, 179 lines)
@@ -0,0 +1,179 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import logging
from typing import Any, Optional, Set, Type

import torch

from vllm.adapter_commons.utils import (add_adapter_worker,
                                        apply_adapters_worker,
                                        list_adapters_worker,
                                        set_active_adapters_worker)
from vllm.adapter_commons.worker_manager import AbstractWorkerManager
from vllm.config import PromptAdapterConfig
from vllm.prompt_adapter.models import (LRUCachePromptAdapterModelManager,
                                        PromptAdapterModel,
                                        PromptAdapterModelManager,
                                        create_prompt_adapter_manager)
from vllm.prompt_adapter.request import PromptAdapterRequest

logger = logging.getLogger(__name__)


class WorkerPromptAdapterManager(AbstractWorkerManager):
    """WorkerPromptAdapterManager that manages
    prompt_adapter models on the worker side.

    On every request, the requested prompt_adapters will be
    loaded (unless they are already loaded),
    and every other prompt_adapter will be unloaded."""

    _manager_cls: Type[PromptAdapterModelManager] = PromptAdapterModelManager

    def __init__(
        self,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        device: torch.device,
        prompt_adapter_config: PromptAdapterConfig,
        prompt_adapter_model_cls: Type[PromptAdapterModel] = PromptAdapterModel
    ):
        self._adapter_manager: PromptAdapterModelManager
        self.max_num_seqs = max_num_seqs
        self.max_num_batched_tokens = max_num_batched_tokens
        self._prompt_adapter_model_cls = prompt_adapter_model_cls
        self.prompt_adapter_config = prompt_adapter_config
        super().__init__(device)

    @property
    def is_enabled(self) -> bool:
        return True

    def create_prompt_adapter_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        prompt_adapter_manager = create_prompt_adapter_manager(
            model,
            max_num_seqs=self.max_num_seqs,
            max_num_batched_tokens=self.max_num_batched_tokens,
            prompt_adapter_config=self.prompt_adapter_config,
            prompt_adapter_manager_cls=self._manager_cls,
        )
        self._adapter_manager = prompt_adapter_manager
        return prompt_adapter_manager.model

    def _load_adapter(
        self, prompt_adapter_request: PromptAdapterRequest
    ) -> PromptAdapterModel:
        try:
            prompt_adapter = (
                self._prompt_adapter_model_cls.from_local_checkpoint(
                    prompt_adapter_request.prompt_adapter_local_path,
                    prompt_adapter_id=prompt_adapter_request.prompt_adapter_id,
                    num_virtual_tokens=prompt_adapter_request.
                    prompt_adapter_num_virtual_tokens,
                    config=self.prompt_adapter_config,
                    device=str(self.device),
                ))
        except Exception as e:
            raise RuntimeError(
                f"Loading prompt_adapter "
                f"{prompt_adapter_request.prompt_adapter_local_path}"
                f" failed") from e
        return prompt_adapter

    def add_dummy_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        return True

    def pin_adapter(self, adapter_id: int) -> bool:
        return self._adapter_manager.pin_adapter(adapter_id)

    def set_active_adapters(self, requests: Set[Any],
                            mapping: Optional[Any]) -> None:
        set_active_adapters_worker(requests, mapping, self._apply_adapters,
                                   self._adapter_manager.set_adapter_mapping)

    def add_adapter(self, adapter_request: Any) -> bool:
        return add_adapter_worker(adapter_request, self.list_adapters,
                                  self._load_adapter,
                                  self._adapter_manager.add_adapter,
                                  self._adapter_manager.activate_adapter)

    def _apply_adapters(self, adapter_requests: Set[Any]) -> None:
        apply_adapters_worker(adapter_requests, self.list_adapters,
                              self._adapter_manager.adapter_slots,
                              self.remove_adapter, self.add_adapter)

    def remove_adapter(self, adapter_id: int) -> bool:
        return self._adapter_manager.remove_adapter(adapter_id)

    def remove_all_adapters(self):
        self._adapter_manager.remove_all_adapters()

    def list_adapters(self) -> Set[int]:
        return list_adapters_worker(self._adapter_manager.list_adapters)
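
Example (illustrative, not part of the commit): a hedged sketch of the per-batch worker flow; in vLLM the engine drives these calls, and `model`, `config`, `request`, and `mapping` are the placeholders from the earlier sketches.

import torch

worker_manager = WorkerPromptAdapterManager(
    max_num_seqs=8,
    max_num_batched_tokens=2048,
    device=torch.device("cuda"),
    prompt_adapter_config=config,
)
model = worker_manager.create_prompt_adapter_manager(model)  # wraps embedding
worker_manager.set_active_adapters({request}, mapping)  # load + map for batch
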

class LRUCacheWorkerPromptAdapterManager(WorkerPromptAdapterManager):
    """WorkerPromptAdapterManager that manages
    prompt_adapter models on the worker side.

    Uses an LRU cache. On every request, the requested
    prompt_adapters will be loaded (unless they are already loaded)
    and the least recently used prompt_adapters will
    be unloaded if the cache is above capacity."""

    _prompt_adapter_manager_cls: Type[
        LRUCachePromptAdapterModelManager] = LRUCachePromptAdapterModelManager

    def create_prompt_adapter_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        prompt_adapter_manager = create_prompt_adapter_manager(
            model,
            max_num_seqs=self.max_num_seqs,
            max_num_batched_tokens=self.max_num_batched_tokens,
            prompt_adapter_config=self.prompt_adapter_config,
            prompt_adapter_manager_cls=self._prompt_adapter_manager_cls)
        self._adapter_manager: LRUCachePromptAdapterModelManager = (
            prompt_adapter_manager)
        return prompt_adapter_manager.model

    def _apply_adapters(
            self, prompt_adapter_requests: Set[PromptAdapterRequest]) -> None:
        prompt_adapters_map = {
            prompt_adapter_request.prompt_adapter_id: prompt_adapter_request
            for prompt_adapter_request in prompt_adapter_requests
            if prompt_adapter_request
        }
        if (len(prompt_adapters_map) >
                self._adapter_manager.prompt_adapter_slots):
            raise RuntimeError(
                f"Number of requested prompt_adapters "
                f"({len(prompt_adapters_map)}) is greater "
                "than the number of GPU prompt_adapter slots "
                f"({self._adapter_manager.prompt_adapter_slots}).")
        for prompt_adapter in prompt_adapters_map.values():
            self.add_adapter(prompt_adapter)

    def add_adapter(self,
                    prompt_adapter_request: PromptAdapterRequest) -> bool:
        if (prompt_adapter_request.prompt_adapter_id
                not in self.list_adapters()):
            # Remove before we load the new prompt_adapter to save memory
            if len(self._adapter_manager) + 1 > self._adapter_manager.capacity:
                self._adapter_manager.remove_oldest_adapter()
            prompt_adapter = self._load_adapter(prompt_adapter_request)
            loaded = self._adapter_manager.add_adapter(prompt_adapter)
        else:
            # If the prompt_adapter is already loaded, just touch it to
            # update its position in the caches
            loaded = self._adapter_manager.get_adapter(
                prompt_adapter_request.prompt_adapter_id) is not None
        self._adapter_manager.activate_adapter(
            prompt_adapter_request.prompt_adapter_id)
        return loaded
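
Example (illustrative, not part of the commit): the LRU worker manager is a drop-in replacement for WorkerPromptAdapterManager when eviction is preferred over hard failure at capacity; same placeholder arguments as above.

worker_manager = LRUCacheWorkerPromptAdapterManager(
    max_num_seqs=8,
    max_num_batched_tokens=2048,
    device=torch.device("cuda"),
    prompt_adapter_config=config,
)
model = worker_manager.create_prompt_adapter_manager(model)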