# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import math
from collections.abc import Callable
from typing import TypeVar

import regex as re
import torch
from torch import nn

from vllm.config.lora import LoRAConfig
from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA, FusedMoE3DWithLoRA, LoRAMapping
from vllm.lora.lora_model import LoRAModel
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.punica_wrapper import get_punica_wrapper
from vllm.lora.utils import (
    from_layer,
    from_layer_logits_processor,
    get_supported_lora_modules,
    is_moe_model,
    process_packed_modules_mapping,
    replace_submodule,
)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
from vllm.model_executor.models.interfaces import is_pooling_model
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.utils import PPMissingLayer
from vllm.utils.cache import LRUCache
from vllm.utils.platform_utils import is_pin_memory_available

logger = init_logger(__name__)

T = TypeVar("T")


class AdapterLRUCache(LRUCache[int, T]):
    def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]):
        super().__init__(capacity)
        self.deactivate_fn = deactivate_fn

    def _on_remove(self, key: int, value: T | None):
        logger.debug("Removing adapter int id: %d", key)
        self.deactivate_fn(key)
        return super()._on_remove(key, value)


class LoRAModelManager:
    """A manager that manages multiple LoRA-fine-tuned models."""

    def __init__(
        self,
        model: SupportsLoRA,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
    ):
        """Create a LoRAModelManager and adapter for a given model.

        Args:
            model: the model to be adapted.
            max_num_seqs: the maximum number of sequences the model can run in
                a single batch.
            max_num_batched_tokens: the maximum number of tokens the model can
                run in a single batch.
            vocab_size: the vocab size of the model.
            lora_config: the LoRA configuration.
            device: the device on which the LoRA weights are placed.
        """
        self.model: SupportsLoRA = model
        self._registered_adapters: dict[int, LoRAModel] = {}
        # Dict instead of a set for compatibility with LRUCache.
        self._active_adapters: dict[int, None] = {}
        self.adapter_type = "LoRA"
        self.lora_config = lora_config
        self.device = device
        self.max_num_seqs = max_num_seqs
        assert self.capacity >= self.lora_slots
        self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
        self.lora_index_to_id: list[int | None] = [None] * self.lora_slots
        self.vocab_size = vocab_size
        self.punica_wrapper = get_punica_wrapper(
            max_num_batched_tokens,
            max_batches=self.max_num_seqs,
            device=self.device,
            max_loras=self.lora_config.max_loras,
        )
        self.supported_lora_modules = get_supported_lora_modules(self.model)
        assert self.supported_lora_modules, (
            f"No supported LoRA modules found in {self.model.__class__.__name__}."
        )
        self.packed_modules_mapping = process_packed_modules_mapping(self.model)
        # Used to indicate whether the model is a multimodal model
        self.supports_mm: bool = (
            supports_multimodal(self.model)
            # In case the model only supports LoRA for
            # text modules (e.g. ChatGLM)
            and hasattr(self.model, "get_mm_mapping")
        )
        self.is_pooling_model = is_pooling_model(self.model)
        self.packed_modules: dict[str, list[str]] = {}
        self.modules: dict[str, BaseLayerWithLoRA] = {}
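        # Cache the most recently applied LoRAMapping so that
        # set_adapter_mapping() can skip redundant punica metadata updates.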
        self._last_mapping: LoRAMapping | None = None
        self._is_3d_moe_model = (
            is_moe_model(self.model) and self.model.is_3d_moe_weight
        )
        self._create_lora_modules()
        self.model.lora_manager = self

    def __len__(self) -> int:
        return len(self._registered_adapters)

    @property
    def capacity(self) -> int:
        return self.lora_config.max_cpu_loras

    @property
    def lora_slots(self) -> int:
        return self.lora_config.max_loras

    @property
    def adapter_slots(self) -> int:
        return self.lora_slots

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Move LoRA into a GPU buffer to be used in the forward pass."""
        if lora_id in self._active_adapters:
            return False
        first_free_slot = next(
            (
                (i, lora_id)
                for i, lora_id in enumerate(self.lora_index_to_id)
                if lora_id is None
            ),
            None,
        )
        if first_free_slot is None:
            raise ValueError("No free lora slots")
        index, _ = first_free_slot
        self._active_adapters[lora_id] = None
        lora_model = self._registered_adapters[lora_id]
        logger.debug(
            "Activating LoRA. int id: %d, slot index: %d", lora_model.id, index
        )
        self.lora_index_to_id[index] = lora_model.id
        for module_name, module in self.modules.items():
            module_lora = self._get_lora_layer_weights(lora_model, module_name)
            if not module_lora:
                module.reset_lora(index)
                continue
            # Note (gnovack) - If MOE lora weights are not split into
            # num_experts chunks, we split them here
            if isinstance(module, FusedMoE3DWithLoRA) and torch.is_tensor(
                module_lora.lora_a
            ):
                # Handle PEFT file format where experts.base_layer is the
                # gate_up_proj and experts is the down_proj
                gate_up_proj_lora = self._get_lora_layer_weights(
                    lora_model, module_name + ".base_layer"
                )
                down_proj_lora = module_lora
                # FIXME Edge case where LoRA is not added to gate_up_proj
                # or down_proj
                assert gate_up_proj_lora is not None
                assert down_proj_lora is not None
                if self._is_3d_moe_model:
                    module_lora.lora_a = [
                        gate_up_proj_lora.lora_a,
                        down_proj_lora.lora_a,
                    ]
                    module_lora.lora_b = [
                        gate_up_proj_lora.lora_b,
                        down_proj_lora.lora_b,
                    ]
                else:
                    # Some 3D MoE models haven't added the `is_3d_moe_weight`
                    # attribute yet, so fall back here
                    num_experts = module_lora.lora_a.shape[0] // module_lora.rank
                    gate_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0)
                    up_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0)
                    gate_proj_b = gate_up_proj_lora.lora_b[::2, ...].chunk(
                        num_experts, dim=-1
                    )
                    up_proj_b = gate_up_proj_lora.lora_b[1::2, ...].chunk(
                        num_experts, dim=-1
                    )
                    down_proj_a = down_proj_lora.lora_a.chunk(num_experts, dim=0)
                    down_proj_b = down_proj_lora.lora_b.chunk(num_experts, dim=-1)

                    lora_a = []
                    lora_b = []
                    for i in range(num_experts):
                        lora_a.append(gate_proj_a[i])
                        lora_a.append(down_proj_a[i])
                        lora_a.append(up_proj_a[i])
                        lora_b.append(gate_proj_b[i])
                        lora_b.append(down_proj_b[i])
                        lora_b.append(up_proj_b[i])
                    module_lora.lora_a = lora_a
                    module_lora.lora_b = lora_b
            module.set_lora(
                index,
                module_lora.lora_a,
                module_lora.lora_b,
            )
        return True

    def _deactivate_adapter(self, lora_id: int):
        try:
            index = self.lora_index_to_id.index(lora_id)
            self.lora_index_to_id[index] = None
        except ValueError:
            pass

    def _add_adapter(self, lora: LoRAModel):
        self._create_merged_loras_inplace(lora)
        self._registered_adapters[lora.id] = lora

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache."""
        raise NotImplementedError(
            "Pinning is not supported in LoRAModelManager. "
" "Use LRUCacheLoRAModelManager for pinning" ) # type: ignore def _set_adapter_mapping(self, mapping: LoRAMapping) -> None: # update lora states self.punica_wrapper.update_metadata( mapping, self.lora_index_to_id, self.lora_slots + 1, self.vocab_size, ) def remove_all_adapters(self): """Remove all LoRAModels from the manager.""" self._registered_adapters.clear() self.lora_index_to_id = [None] * self.lora_slots self._active_adapters.clear() def _create_lora_modules(self): def _parent_module(module_name: str) -> str: # module name is a dot separated name. # for example: # - given an input 'x.y.z' return 'x.y' # - given an input 'x' return '' return module_name.rpartition(".")[0] for module_name, module in self.model.named_modules(remove_duplicate=False): if isinstance(module, PPMissingLayer): continue if not self._match_target_modules(module_name): continue # A temporary approach for multimodal models to support LoRA # TODO: Remove this restriction if self._filter_unsupported_mm_module(module_name): logger.warning( "Regarding multimodal models, vLLM currently only supports " "adding LoRA to language model, %s will be ignored.", module_name, ) continue parts = module_name.split(".")[-1] packed_moduled_lst = self.packed_modules_mapping.get(parts, []) if isinstance(module, FusedMoE): # packed_moduled_lst is used here to just determine whether to # instantiate FusedMoE3DWithLoRA or FusedMoEWithLoRA, and the # difference between these two LoRA layers is whether the # LoRA weights of w1 and w3 have already been fused on disk. packed_moduled_lst = ["w13"] if self._is_3d_moe_model else ["w1", "w3"] new_module = replace_submodule( self.model, module_name, from_layer( module, self.lora_slots, self.lora_config, packed_moduled_lst, self.model.config, ), ) # (yard1): TODO make this more robust if "lm_head" in module_name: logits_processor_module_name = "logits_processor" parent_module = _parent_module(module_name) if parent_module: logits_processor_module_name = ( f"{parent_module}.{logits_processor_module_name}" ) logits_processor_module = self.model.get_submodule( logits_processor_module_name ) new_module = replace_submodule( self.model, logits_processor_module_name, from_layer_logits_processor( logits_processor_module, module, self.lora_slots, self.lora_config, self.model.config, ), ) # In some models, especially multimodal ones, layers with the same # name may have different types, such as nn.Linear and # ReplicatedLinear. The nn.Linear layers cannot be replaced with # LoRA layers, leading to assertion error. The following check # aims to prevent this error if self.supports_mm and not isinstance(new_module, BaseLayerWithLoRA): continue self.register_module(module_name, new_module) self._register_packed_modules(module_name) # All lora layers share the same punica_wrapper based on reference. 
            new_module.set_mapping(self.punica_wrapper)

    def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
        assert isinstance(module, BaseLayerWithLoRA), (
            f"Module {module_name} must be a BaseLayerWithLoRA instance, "
            f"got {type(module)}"
        )
        self.modules[module_name] = module

    def create_dummy_lora(
        self,
        lora_id: int,
        rank: int,
        embedding_modules: dict[str, str] | None = None,
    ) -> LoRAModel:
        """Create a zero-initialized LoRAModel for warmup."""
        model = LoRAModel(lora_id, rank, {})
        for module_name, module in self.model.named_modules():
            if (
                not self._match_target_modules(module_name)
                or not isinstance(module, BaseLayerWithLoRA)
                or self._filter_unsupported_mm_module(module_name)
            ):
                continue
            parts = module_name.split(".")
            if module_name not in self.packed_modules:
                assert embedding_modules is not None
                if parts[-1] in embedding_modules:
                    input_dim = (
                        module.base_layer.org_vocab_size
                        if hasattr(module.base_layer, "org_vocab_size")
                        else module.base_layer.weight.shape[1]
                    )
                    output_dim = (
                        module.base_layer.embedding_dim
                        if hasattr(module.base_layer, "embedding_dim")
                        else module.base_layer.weight.shape[0]
                    )
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        input_dim,
                        output_dim,
                        rank,
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                    )
                    model.loras[module_name] = lora
                elif module.__class__.__name__ == "FusedMoE3DWithLoRA":
                    # Case for 3D MoE models
                    # w2
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        module.w2_input_size,
                        module.w2_output_size,
                        # rank * num_experts
                        rank * module.w2_lora_a_stacked[0].shape[1],
                        module.w2_lora_a_stacked[0].dtype,
                        "cpu",
                    )
                    model.loras[module_name] = lora
                    # w13
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        module.w13_input_size,
                        module.w13_output_size,
                        # rank * num_experts
                        rank * module.w13_lora_a_stacked[0].shape[1],
                        module.w13_lora_a_stacked[0].dtype,
                        "cpu",
                    )
                    model.loras[module_name + ".base_layer"] = lora
                else:
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        module.lora_a_stacked[0].shape[-1],
                        module.lora_b_stacked[0].shape[-2],
                        rank,
                        module.lora_a_stacked[0].dtype,
                        "cpu",
                    )
                    model.loras[module_name] = lora
            else:
                parts = module_name.split(".")
                replacements = self.packed_modules_mapping[parts[-1]]
                subloras: list[LoRALayerWeights | None] = []
                for i, r in enumerate(replacements):
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name + "." + r,
                        module.lora_a_stacked[i].shape[-1],
                        module.lora_b_stacked[i].shape[-2],
                        rank,
                        module.lora_a_stacked[i].dtype,
                        "cpu",
                    )
                    subloras.append(lora)
                if module.__class__.__name__ == "FusedMoEWithLoRA":
                    lora = PackedLoRALayerWeights.pack_moe(subloras, module_name)
                else:
                    lora = PackedLoRALayerWeights.pack(subloras)
                model.loras[module_name] = lora
        return model

    def _match_target_modules(self, module_name: str):
        return any(
            re.match(
                r".*\.{target_module}$".format(target_module=target_module),
                module_name,
            )
            or target_module == module_name
            for target_module in self.supported_lora_modules
        )

    def _filter_unsupported_mm_module(self, module_name: str) -> bool:
        """
        Regarding multimodal models, vLLM currently only supports adding LoRA
        to the language model. LoRA for other modules, such as the vision
        tower, will be filtered out.
""" if self.supports_mm: module_mapping: MultiModelKeys = self.model.get_mm_mapping() prefix_lst = module_mapping.connector + module_mapping.tower_model return any([module_name.startswith(prefix) for prefix in prefix_lst]) return False def _register_packed_modules(self, module_full_name: str) -> None: parts = module_full_name.split(".") module_name = parts[-1] replacements = self.packed_modules_mapping.get(module_name, []) # When replacements is less than or equal to 1, it indicates that this # module is not a packed module. if len(replacements) <= 1: return prefix = ".".join(parts[:-1]) self.packed_modules[module_full_name] = [ prefix + "." + r if prefix else r for r in replacements ] def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None: for module_name, new_module_names in self.packed_modules.items(): replacement_loras: list[LoRALayerWeights | None] = [] replaced_module: set[str] = set() has_replacement = False for r in new_module_names: lora = self._get_lora_layer_weights(lora_model, r) replacement_loras.append(lora) if lora: has_replacement = True replaced_module.add(r) if not has_replacement: continue for i in range(len(replacement_loras)): if replacement_loras[i]: continue replacement_loras[i] = None # HACK Temporary solution for the pool model. if self.is_pooling_model and not lora_model.check_lora_name(module_name): replaced_module_name = module_name.replace("model.", "") if lora_model.check_lora_name(module_name): module_name = replaced_module_name if module_name.endswith(".experts"): lora_model.loras[module_name] = PackedLoRALayerWeights.pack_moe( replacement_loras, module_name ) else: lora_model.loras[module_name] = PackedLoRALayerWeights.pack( replacement_loras ) # Remove the modules that have been replaced. for module in replaced_module: lora_model.loras.pop(module, None) for lora in lora_model.loras.values(): lora.optimize() first_lora: LoRALayerWeights = next(iter(lora_model.loras.values())) assert first_lora.lora_a is not None if isinstance(first_lora.lora_a, list): lora_device = next(iter(first_lora.lora_a)) else: lora_device = first_lora.lora_a.device # Execute pin_memory after LoRA weight merging, mainly because: # 1. Some MoE models have a large number of LoRA weights. If we # perform # pin_memory immediately after loading weights, the # overhead is significant. # 2. The weight packing above (e.g., pack_moe) may invalidate the # pin_memory allocation, so we execute it after packing. pin_memory = str(lora_device) == "cpu" and is_pin_memory_available() if pin_memory: for lora in lora_model.loras.values(): if isinstance(lora.lora_a, list): for index in range(len(lora.lora_a)): if lora.lora_a[index] is None: continue lora.lora_a[index] = lora.lora_a[index].pin_memory() lora.lora_b[index] = lora.lora_b[index].pin_memory() else: lora.lora_a = lora.lora_a.pin_memory() lora.lora_b = lora.lora_b.pin_memory() def _get_lora_layer_weights( self, lora_model: LoRAModel, module_name: str ) -> LoRALayerWeights | None: org_module_name = module_name if self.is_pooling_model and not lora_model.check_lora_name(module_name): # If it's a pool model, and the layer name is not found, # remove the prefix 'model.' and search again. module_name = module_name.replace("model.", "") if lora_model.check_lora_name(module_name): org_module_name = module_name logger.info_once( "For the pool model, successfully loaded the LoRA weights " "after removing the prefix 'model.'." 
                )
        return lora_model.get_lora(org_module_name)

    def deactivate_adapter(self, adapter_id: int) -> bool:
        if adapter_id not in self._active_adapters:
            return False
        self._deactivate_adapter(adapter_id)
        self._active_adapters.pop(adapter_id, None)
        return True

    def add_adapter(self, adapter: LoRAModel) -> bool:
        logger.debug(
            "Adding lora. Model id: %d, int id: %d", adapter.id, adapter.id
        )
        if adapter.id in self._registered_adapters:
            return False
        if len(self._registered_adapters) >= self.capacity:
            raise RuntimeError("No free adapter slots.")
        self._add_adapter(adapter)
        return True

    def set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        if self._last_mapping != mapping:
            self._set_adapter_mapping(mapping)
        self._last_mapping = mapping

    def remove_adapter(self, adapter_id: int) -> bool:
        self.deactivate_adapter(adapter_id)
        if adapter_id not in self._registered_adapters:
            return False
        self._registered_adapters.pop(adapter_id, None)
        return True

    def list_adapters(self) -> dict[int, LoRAModel]:
        return dict(self._registered_adapters)

    def get_adapter(self, adapter_id: int) -> LoRAModel | None:
        return self._registered_adapters.get(adapter_id)


class LoRALRUCache(AdapterLRUCache[LoRAModel]):
    def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int], bool]):
        super().__init__(capacity, deactivate_lora_fn)


class LRUCacheLoRAModelManager(LoRAModelManager):
    """A model manager that manages multiple LoRAs with an LRU cache."""

    def __init__(
        self,
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
    ):
        super().__init__(
            model,
            max_num_seqs,
            max_num_batched_tokens,
            vocab_size,
            lora_config,
            device,
        )
        self._registered_adapters: LoRALRUCache = LoRALRUCache(
            self.capacity, self.deactivate_adapter
        )
        self._active_adapters: LoRALRUCache = LoRALRUCache(
            self.lora_slots, self._deactivate_adapter
        )

    def list_adapters(self) -> dict[int, LoRAModel]:
        """List all registered LoRAModels."""
        return dict(self._registered_adapters.cache)

    def add_adapter(self, lora: LoRAModel) -> bool:
        """Add a LoRAModel to the manager."""
        logger.debug("Adding lora. Model id: %d, int id: %d", lora.id, lora.id)
        if lora.id not in self._registered_adapters:
            self._add_adapter(lora)
            was_added = True
        else:
            # We always touch to update the LRU cache order
            self._registered_adapters.touch(lora.id)
            was_added = False
        return was_added

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        if (
            lora_id not in self._active_adapters
            and len(self._active_adapters) >= self.lora_slots
        ):
            self._active_adapters.remove_oldest()
        result = super().activate_adapter(lora_id)
        # We always touch to update the LRU cache order
        self._active_adapters.touch(lora_id)
        return result

    def remove_oldest_adapter(self) -> bool:
        if len(self._registered_adapters) > 0:
            self._registered_adapters.remove_oldest()
            return True
        return False

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache."""
        self._pin_lora_in_cpu_cache(lora_id)
        self._pin_lora_in_gpu_cache(lora_id)
        return True

    def _pin_lora_in_cpu_cache(self, lora_id: int):
        try:
            self._registered_adapters.pin(lora_id)
        except ValueError as err:
            raise ValueError(
                f"Pinning failed. LoRA {lora_id} is not registered."
            ) from err

    def _pin_lora_in_gpu_cache(self, lora_id: int):
        if lora_id not in self._active_adapters:
            # Move the LoRA to the GPU if it is not already active.
            self.activate_adapter(lora_id)
        self._active_adapters.pin(lora_id)


def create_lora_manager(
    model: nn.Module,
    max_num_seqs: int,
    max_num_batched_tokens: int,
    vocab_size: int,
    lora_config: LoRAConfig,
    device: torch.device,
    lora_manager_cls: type[LoRAModelManager] = LoRAModelManager,
    **kwargs,
) -> LoRAModelManager:
    """Create a LoRA model manager for a given model."""
    if not isinstance(model, SupportsLoRA):
        raise ValueError(f"Model {type(model)} is not supported for LoRA.")
    lora_manager = lora_manager_cls(
        model=model,
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        vocab_size=vocab_size,
        lora_config=lora_config,
        device=device,
        **kwargs,
    )
    return lora_manager
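

# Illustrative usage sketch (a minimal, non-executed example; the model object,
# lora_config, LoRAModel instance, and the numeric values below are placeholder
# assumptions, not vLLM defaults):
#
#     manager = create_lora_manager(
#         model=model,  # an nn.Module implementing SupportsLoRA
#         max_num_seqs=256,
#         max_num_batched_tokens=8192,
#         vocab_size=32000,
#         lora_config=lora_config,
#         device=torch.device("cuda:0"),
#         lora_manager_cls=LRUCacheLoRAModelManager,
#     )
#     manager.add_adapter(lora_model)            # register a loaded LoRAModel
#     manager.activate_adapter(lora_model.id)    # copy weights into a GPU slot
#     manager.set_adapter_mapping(lora_mapping)  # refresh punica metadata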