# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Define LoRA functionality mixin for model runners.
"""

from contextlib import contextmanager
from typing import Optional, Union

import numpy as np
import torch.nn as nn

from vllm.config import LoRAConfig, ModelConfig, SchedulerConfig
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor.models import supports_lora, supports_multimodal
from vllm.v1.worker.gpu_input_batch import InputBatch as GPUInputBatch
from vllm.v1.worker.tpu_input_batch import InputBatch as TPUInputBatch

InputBatch = Union[TPUInputBatch, GPUInputBatch]

logger = init_logger(__name__)


# Defined as a mixin for GPUModelRunner
class LoRAModelRunnerMixin:

    LORA_WARMUP_RANK = 8

    def load_lora_model(self, model: nn.Module, model_config: ModelConfig,
                        scheduler_config: SchedulerConfig,
                        lora_config: LoRAConfig, device: str) -> nn.Module:
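        """Wrap the model with LoRA layers and attach a LoRA manager.

        Raises ValueError if the model architecture does not support
        LoRA. Returns the LoRA-wrapped model.
        """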
        if not supports_lora(model):
            raise ValueError(
                f"{model.__class__.__name__} does not support LoRA yet.")

        if supports_multimodal(model):
            logger.warning("For multimodal models, vLLM currently only "
                           "supports adding LoRA to the language model.")

        # Use get_text_config() in case of multimodal models
        text_config = model_config.hf_config.get_text_config()

        # Add LoRA Manager to the Model Runner
        self.lora_manager = LRUCacheWorkerLoRAManager(
            scheduler_config.max_num_seqs,
            scheduler_config.max_num_batched_tokens,
            model_config.get_vocab_size(),
            lora_config,
            device,
            model.embedding_modules,
            model.embedding_padding_modules,
            max_position_embeddings=text_config.max_position_embeddings,
        )
        return self.lora_manager.create_lora_manager(model)

    def _set_active_loras(self, prompt_lora_mapping: tuple[int, ...],
                          token_lora_mapping: tuple[int, ...],
                          lora_requests: set[LoRARequest]) -> None:
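        """Activate the given LoRA adapters with the supplied mappings."""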
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")

        # Set is_prefill to True, so we always use the SGMV kernels on
        # non-cuda platforms.
        # On cuda platforms we use the same kernels for prefill and
        # decode and this flag is generally ignored.
        lora_mapping = LoRAMapping(token_lora_mapping,
                                   prompt_lora_mapping,
                                   is_prefill=True)
        self.lora_manager.set_active_adapters(lora_requests, lora_mapping)

    def set_active_loras(self, input_batch: InputBatch,
                         num_scheduled_tokens: np.ndarray) -> None:
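        """Build LoRA mappings from the input batch and activate them."""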
        prompt_lora_mapping: tuple[int, ...]  # of size input_batch.num_reqs
        token_lora_mapping: tuple[
            int, ...]  # of size np.sum(num_scheduled_tokens)
        lora_requests: set[LoRARequest]
        prompt_lora_mapping, token_lora_mapping, lora_requests = \
            input_batch.make_lora_inputs(num_scheduled_tokens)
        return self._set_active_loras(prompt_lora_mapping, token_lora_mapping,
                                      lora_requests)

    @contextmanager
    def maybe_setup_dummy_loras(self, lora_config: Optional[LoRAConfig]):
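        """Context manager that registers dummy LoRA adapters for warmup.

        A no-op when lora_config is None. Otherwise, dummy adapters are
        added on entry and all adapters are removed again on exit.
        """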
        if lora_config is None:
            yield
        else:
            # __enter__ code
            assert self.lora_manager is not None, "LoRA is not enabled"

            num_loras = lora_config.max_loras

            # Make dummy lora requests
            lora_requests: set[LoRARequest] = {
                LoRARequest(lora_name=f"warmup_{lora_id}",
                            lora_int_id=lora_id,
                            lora_path="/not/a/real/path")
                for lora_id in range(1, num_loras + 1)
            }

            with self.lora_manager.dummy_lora_cache():
                # Add the dummy LoRAs here so _set_active_loras doesn't try to
                # load from disk.
                for lr in lora_requests:
                    self.lora_manager.add_dummy_lora(
                        lr, rank=self.LORA_WARMUP_RANK)

                yield

            # __exit__ code
            self.lora_manager.remove_all_adapters()

    @contextmanager
    def maybe_select_dummy_loras(self, lora_config: Optional[LoRAConfig],
                                 num_scheduled_tokens: np.ndarray):
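        """Context manager that activates dummy LoRA adapters for warmup.

        LoRA IDs are assigned to requests cyclically to exercise a
        worst-case adapter distribution: e.g. with 3 LoRAs and 5 requests
        the prompt mapping is (1, 2, 3, 1, 2).
        """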
        if lora_config is None:
            yield
        else:
            # __enter__ code
            assert self.lora_manager is not None, "LoRA is not enabled"

            num_reqs = len(num_scheduled_tokens)
            num_loras = lora_config.max_loras

            # Make prompt lora mapping
            # Assign LoRA IDs cyclically to simulate a worst-case scenario.
            prompt_lora_mapping = (np.arange(num_reqs, dtype=np.int32) %
                                   num_loras) + 1

            # Make token lora mapping
            token_lora_mapping = np.repeat(prompt_lora_mapping,
                                           num_scheduled_tokens)

            # Make dummy lora requests
            lora_requests: set[LoRARequest] = {
                LoRARequest(lora_name=f"warmup_{lora_id}",
                            lora_int_id=lora_id,
                            lora_path="/not/a/real/path")
                for lora_id in range(1, num_loras + 1)
            }

            self._set_active_loras(tuple(prompt_lora_mapping),
                                   tuple(token_lora_mapping), lora_requests)

            yield

    @contextmanager
    def maybe_dummy_run_with_lora(self, lora_config: Optional[LoRAConfig],
                                  num_scheduled_tokens: np.ndarray):
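        """Set up and select dummy LoRA adapters for a dummy/warmup run."""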
        with self.maybe_setup_dummy_loras(
                lora_config), self.maybe_select_dummy_loras(
                    lora_config, num_scheduled_tokens):
            yield

    def add_lora(self, lora_request: LoRARequest) -> bool:
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.add_adapter(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.remove_adapter(lora_id)

    def pin_lora(self, lora_id: int) -> bool:
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.pin_adapter(lora_id)

    def list_loras(self) -> set[int]:
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.list_adapters()