[gpt-oss] Add gpt-oss bf16 support
vllm/prompt_adapter/worker_manager.py (new file, +179 lines)
@@ -0,0 +1,179 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import logging
from typing import Any, Optional, Set, Type

import torch

from vllm.adapter_commons.utils import (add_adapter_worker,
                                        apply_adapters_worker,
                                        list_adapters_worker,
                                        set_active_adapters_worker)
from vllm.adapter_commons.worker_manager import AbstractWorkerManager
from vllm.config import PromptAdapterConfig
from vllm.prompt_adapter.models import (LRUCachePromptAdapterModelManager,
                                        PromptAdapterModel,
                                        PromptAdapterModelManager,
                                        create_prompt_adapter_manager)
from vllm.prompt_adapter.request import PromptAdapterRequest

logger = logging.getLogger(__name__)

class WorkerPromptAdapterManager(AbstractWorkerManager):
    """WorkerPromptAdapterManager that manages
    prompt_adapter models on the worker side.

    Every request, the requested prompt_adapters will be
    loaded (unless they are already loaded),
    and every other prompt_adapter will be unloaded."""

    _manager_cls: Type[PromptAdapterModelManager] = PromptAdapterModelManager

    def __init__(
        self,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        device: torch.device,
        prompt_adapter_config: PromptAdapterConfig,
        prompt_adapter_model_cls: Type[PromptAdapterModel] = PromptAdapterModel
    ):
        self._adapter_manager: PromptAdapterModelManager
        self.max_num_seqs = max_num_seqs
        self.max_num_batched_tokens = max_num_batched_tokens
        self._prompt_adapter_model_cls = prompt_adapter_model_cls
        self.prompt_adapter_config = prompt_adapter_config
        super().__init__(device)

    @property
    def is_enabled(self) -> bool:
        return True

    def create_prompt_adapter_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        prompt_adapter_manager = create_prompt_adapter_manager(
            model,
            max_num_seqs=self.max_num_seqs,
            max_num_batched_tokens=self.max_num_batched_tokens,
            prompt_adapter_config=self.prompt_adapter_config,
            prompt_adapter_manager_cls=self._manager_cls,
        )
        self._adapter_manager = prompt_adapter_manager
        return prompt_adapter_manager.model

    def _load_adapter(
            self, prompt_adapter_request: PromptAdapterRequest
    ) -> PromptAdapterModel:
        try:
            prompt_adapter = (
                self._prompt_adapter_model_cls.from_local_checkpoint(
                    prompt_adapter_request.prompt_adapter_local_path,
                    prompt_adapter_id=prompt_adapter_request.prompt_adapter_id,
                    num_virtual_tokens=prompt_adapter_request.
                    prompt_adapter_num_virtual_tokens,
                    config=self.prompt_adapter_config,
                    device=str(self.device),
                ))
        except Exception as e:
            raise RuntimeError(
                f"Loading prompt_adapter "
                f"{prompt_adapter_request.prompt_adapter_local_path}"
                f" failed") from e
        return prompt_adapter

    def add_dummy_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        return True

    def pin_adapter(self, adapter_id: int) -> bool:
        return self._adapter_manager.pin_adapter(adapter_id)

    def set_active_adapters(self, requests: Set[Any],
                            mapping: Optional[Any]) -> None:
        set_active_adapters_worker(requests, mapping, self._apply_adapters,
                                   self._adapter_manager.set_adapter_mapping)

    def add_adapter(self, adapter_request: Any) -> bool:
        return add_adapter_worker(adapter_request, self.list_adapters,
                                  self._load_adapter,
                                  self._adapter_manager.add_adapter,
                                  self._adapter_manager.activate_adapter)

    def _apply_adapters(self, adapter_requests: Set[Any]) -> None:
        apply_adapters_worker(adapter_requests, self.list_adapters,
                              self._adapter_manager.adapter_slots,
                              self.remove_adapter, self.add_adapter)

    def remove_adapter(self, adapter_id: int) -> bool:
        return self._adapter_manager.remove_adapter(adapter_id)

    def remove_all_adapters(self):
        self._adapter_manager.remove_all_adapters()

    def list_adapters(self) -> Set[int]:
        return list_adapters_worker(self._adapter_manager.list_adapters)

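# Usage sketch (illustrative, not part of this module): how a worker loop
# might drive WorkerPromptAdapterManager. The `worker_model`, `requests`,
# and `mapping` names stand in for objects the worker already holds, and
# the config values are arbitrary assumptions.
#
#   manager = WorkerPromptAdapterManager(
#       max_num_seqs=256,
#       max_num_batched_tokens=8192,
#       device=torch.device("cuda"),
#       prompt_adapter_config=PromptAdapterConfig(max_prompt_adapters=4))
#   model = manager.create_prompt_adapter_manager(worker_model)
#   # Per scheduling step: load the requested adapters, unload the rest.
#   manager.set_active_adapters(requests, mapping)
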
class LRUCacheWorkerPromptAdapterManager(WorkerPromptAdapterManager):
    """WorkerPromptAdapterManager that manages
    prompt_adapter models on the worker side.

    Uses an LRU Cache. Every request, the requested
    prompt_adapters will be loaded (unless they are already loaded)
    and least recently used prompt_adapters will
    be unloaded if the cache is above capacity."""

    _prompt_adapter_manager_cls: Type[
        LRUCachePromptAdapterModelManager] = LRUCachePromptAdapterModelManager

    def create_prompt_adapter_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        prompt_adapter_manager = create_prompt_adapter_manager(
            model,
            max_num_seqs=self.max_num_seqs,
            max_num_batched_tokens=self.max_num_batched_tokens,
            prompt_adapter_config=self.prompt_adapter_config,
            prompt_adapter_manager_cls=self._prompt_adapter_manager_cls)
        self._adapter_manager: LRUCachePromptAdapterModelManager = (
            prompt_adapter_manager)
        return prompt_adapter_manager.model

    def _apply_adapters(
            self, prompt_adapter_requests: Set[PromptAdapterRequest]) -> None:
        prompt_adapters_map = {
            prompt_adapter_request.prompt_adapter_id: prompt_adapter_request
            for prompt_adapter_request in prompt_adapter_requests
            if prompt_adapter_request
        }
        if len(prompt_adapters_map
               ) > self._adapter_manager.prompt_adapter_slots:
            raise RuntimeError(
                f"Number of requested prompt_adapters "
                f"({len(prompt_adapters_map)}) is greater "
                "than the number of GPU prompt_adapter slots "
                f"({self._adapter_manager.prompt_adapter_slots}).")
        for prompt_adapter in prompt_adapters_map.values():
            self.add_adapter(prompt_adapter)

    def add_adapter(self,
                    prompt_adapter_request: PromptAdapterRequest) -> bool:
        if prompt_adapter_request.prompt_adapter_id not in self.list_adapters(
        ):
            # Remove before we load the new prompt_adapter to save memory
            if len(self._adapter_manager) + 1 > self._adapter_manager.capacity:
                self._adapter_manager.remove_oldest_adapter()
            prompt_adapter = self._load_adapter(prompt_adapter_request)
            loaded = self._adapter_manager.add_adapter(prompt_adapter)
        else:
            # If the prompt_adapter is already loaded, just touch it to
            # update its position in the caches
            loaded = self._adapter_manager.get_adapter(
                prompt_adapter_request.prompt_adapter_id) is not None
        self._adapter_manager.activate_adapter(
            prompt_adapter_request.prompt_adapter_id)
        return loaded
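For contrast with the eager manager above, a minimal usage sketch of the LRU variant, assuming a cache capacity of two adapters; the `worker_model` object, the request field values, and the config values are illustrative assumptions, not part of this commit:

    manager = LRUCacheWorkerPromptAdapterManager(
        max_num_seqs=256,
        max_num_batched_tokens=8192,
        device=torch.device("cuda"),
        prompt_adapter_config=PromptAdapterConfig(max_prompt_adapters=2))
    model = manager.create_prompt_adapter_manager(worker_model)

    req_a = PromptAdapterRequest(
        prompt_adapter_name="a",                  # illustrative request;
        prompt_adapter_id=1,                      # req_b and req_c are assumed
        prompt_adapter_local_path="/adapters/a",  # to be built the same way
        prompt_adapter_num_virtual_tokens=8)      # with ids 2 and 3
    manager.add_adapter(req_a)  # loads and activates adapter 1
    manager.add_adapter(req_b)  # loads adapter 2; cache now at capacity
    manager.add_adapter(req_c)  # evicts the least recently used adapter
    manager.add_adapter(req_a)  # reloads adapter 1 if it was evicted, else
                                # just touches it to refresh its LRU position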