Introduce Stable LoRA ID System for Overlapped Updates and Prefix Caching (#8261)
This commit is contained in:
@@ -16,7 +16,7 @@
|
||||
# and "Punica: Multi-Tenant LoRA Serving"
|
||||
|
||||
import logging
|
||||
from typing import Dict, Iterable, Optional, Set, Tuple
|
||||
from typing import Dict, Iterable, List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -26,6 +26,7 @@ from sglang.srt.lora.backend.base_backend import BaseLoRABackend, get_backend_fr
|
||||
from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer
|
||||
from sglang.srt.lora.lora import LoRAAdapter
|
||||
from sglang.srt.lora.lora_config import LoRAConfig
|
||||
from sglang.srt.lora.lora_registry import LoRARef
|
||||
from sglang.srt.lora.mem_pool import LoRAMemoryPool
|
||||
from sglang.srt.lora.utils import (
|
||||
LoRABatchInfo,
|
||||
@@ -55,6 +56,7 @@ class LoRAManager:
|
||||
tp_rank: int = 0,
|
||||
max_lora_rank: Optional[int] = None,
|
||||
target_modules: Optional[Iterable[str]] = None,
|
||||
lora_paths: Optional[Dict[str, LoRARef]] = None,
|
||||
):
|
||||
self.base_model: torch.nn.Module = base_model
|
||||
self.base_hf_config: AutoConfig = base_hf_config
|
||||
@@ -64,10 +66,6 @@ class LoRAManager:
|
||||
self.device: torch.device = next(self.base_model.parameters()).device
|
||||
self.tp_size: int = tp_size
|
||||
self.tp_rank: int = tp_rank
|
||||
self.max_lora_rank: Optional[int] = max_lora_rank
|
||||
self.target_modules: Optional[Set[str]] = (
|
||||
set(target_modules) if target_modules else None
|
||||
)
|
||||
|
||||
# LoRA backend for running sgemm kernels
|
||||
logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
|
||||
@@ -75,7 +73,11 @@ class LoRAManager:
|
||||
self.lora_backend: BaseLoRABackend = backend_type(lora_backend)
|
||||
|
||||
# Initialize mutable internal state of the LoRAManager.
|
||||
self.init_state()
|
||||
self.init_state(
|
||||
max_lora_rank=max_lora_rank,
|
||||
target_modules=target_modules,
|
||||
lora_paths=lora_paths,
|
||||
)
|
||||
|
||||
def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
|
||||
self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
|
||||
@@ -112,108 +114,87 @@ class LoRAManager:
|
||||
success=success,
|
||||
error_message=error_message,
|
||||
loaded_adapters={
|
||||
name: config.path for name, config in self.configs.items()
|
||||
lora_ref.lora_name: lora_ref.lora_path
|
||||
for lora_ref in self.lora_refs.values()
|
||||
},
|
||||
)
|
||||
|
||||
def load_lora_adapters(self, lora_paths: Dict[str, str]) -> LoRAUpdateResult:
|
||||
"""
|
||||
Load LoRA adapters from the specified paths.
|
||||
|
||||
Args:
|
||||
lora_paths (Dict[str, str]): A dictionary mapping LoRA adapter names to their file paths.
|
||||
If a LoRA adapter is already loaded, it will be skipped with a warning.
|
||||
"""
|
||||
|
||||
results = []
|
||||
for lora_name, lora_path in lora_paths.items():
|
||||
result = self.load_lora_adapter(lora_name, lora_path, update_state=False)
|
||||
results.append(result)
|
||||
|
||||
self.update_state_from_configs()
|
||||
|
||||
return self.create_lora_update_result(
|
||||
success=all(result.success for result in results),
|
||||
error_message="\n".join(
|
||||
result.error_message for result in results if not result.success
|
||||
),
|
||||
)
|
||||
|
||||
def load_lora_adapter(
|
||||
self, lora_name: str, lora_path: str, update_state: bool = True
|
||||
) -> LoRAUpdateResult:
|
||||
def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
|
||||
"""
|
||||
Load a single LoRA adapter from the specified path.
|
||||
|
||||
Args:
|
||||
lora_name (str): The name of the LoRA adapter.
|
||||
lora_path (str): The file path to the LoRA adapter.
|
||||
update_state (bool): Whether to refresh the internal state after loading the adapter. This is useful for batch loading.
|
||||
lora_ref (LoRARef): The LoRARef object containing the LoRA name, path, and ID.
|
||||
"""
|
||||
|
||||
success = True
|
||||
error_message = ""
|
||||
|
||||
if lora_name in self.loras:
|
||||
success = False
|
||||
error_message = f"LoRA adapter {lora_name} is skipped as it is already loaded. If you want to reload it, please unload it first."
|
||||
assert (
|
||||
lora_ref.lora_name is not None and lora_ref.lora_path is not None
|
||||
), "LoRARef must have both lora_name and lora_path set for loading."
|
||||
assert (
|
||||
lora_ref.lora_id not in self.loras
|
||||
), f"LoRA adapter with ID {lora_ref.lora_id} is already loaded. This should have been verified before request is sent to the backend."
|
||||
|
||||
try:
|
||||
new_adapter = LoRAConfig(lora_path)
|
||||
self.validate_new_adapter(lora_name, new_adapter)
|
||||
self.configs[lora_name] = new_adapter
|
||||
# load configs
|
||||
new_adapter = LoRAConfig(lora_ref.lora_path)
|
||||
self.validate_new_adapter(new_adapter, lora_ref)
|
||||
self.configs[lora_ref.lora_id] = new_adapter
|
||||
|
||||
# load weights
|
||||
self.load_lora_weights(lora_ref)
|
||||
|
||||
# keep metadata for displayed messages
|
||||
self.lora_refs[lora_ref.lora_id] = lora_ref
|
||||
except Exception as e:
|
||||
success = False
|
||||
error_message = (
|
||||
f"Failed to load LoRA adapter {lora_name} from {lora_path}: {str(e)}"
|
||||
return self.create_lora_update_result(
|
||||
success=False,
|
||||
error_message=str(e),
|
||||
)
|
||||
|
||||
if update_state:
|
||||
self.update_state_from_configs()
|
||||
return self.create_lora_update_result(success=True)
|
||||
|
||||
return self.create_lora_update_result(
|
||||
success=success,
|
||||
error_message=error_message,
|
||||
)
|
||||
|
||||
def validate_new_adapter(self, lora_name: str, lora_config: LoRAConfig):
|
||||
def validate_new_adapter(self, lora_config: LoRAConfig, lora_ref: LoRARef):
|
||||
"""
|
||||
Validate if an adapter can be loaded into the current LoRA memory pool and generate error if it is incompatible.
|
||||
"""
|
||||
|
||||
incompatible = self.memory_pool and not self.memory_pool.can_support(
|
||||
lora_config
|
||||
)
|
||||
memory_pool = getattr(self, "memory_pool", None)
|
||||
incompatible = memory_pool and not memory_pool.can_support(lora_config)
|
||||
if incompatible:
|
||||
raise ValueError(
|
||||
f"LoRA adapter {lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. "
|
||||
f"LoRA adapter {lora_ref.lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. "
|
||||
"Please ensure that the LoRA adapter's rank is within the configured `--max_lora_rank` and that the target modules are "
|
||||
"included in `--enable_lora_modules`."
|
||||
)
|
||||
|
||||
def unload_lora_adapter(self, lora_name: str) -> LoRAUpdateResult:
|
||||
def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
|
||||
"""
|
||||
Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
|
||||
delete the corresponding LoRA modules.
|
||||
"""
|
||||
|
||||
success = True
|
||||
error_message = ""
|
||||
if lora_name in self.loras:
|
||||
del self.configs[lora_name]
|
||||
else:
|
||||
error_message = f"LoRA adapter {lora_name} is not loaded."
|
||||
success = False
|
||||
adapter = self.configs.get(lora_ref.lora_id, None)
|
||||
assert (
|
||||
adapter is not None
|
||||
), f"LoRA adapter with ID {lora_ref.lora_id} is not loaded. This should have been verified before request is sent to the backend."
|
||||
|
||||
self.update_state_from_configs()
|
||||
try:
|
||||
del self.configs[lora_ref.lora_id]
|
||||
del self.loras[lora_ref.lora_id]
|
||||
del self.lora_refs[lora_ref.lora_id]
|
||||
except Exception as e:
|
||||
return self.create_lora_update_result(
|
||||
success=False,
|
||||
error_message=str(e),
|
||||
)
|
||||
|
||||
return self.create_lora_update_result(
|
||||
success=success,
|
||||
error_message=error_message,
|
||||
)
|
||||
return self.create_lora_update_result(success=True)
|
||||
|
||||
def prepare_lora_batch(self, forward_batch: ForwardBatch):
|
||||
# load active loras into lora memory pool
|
||||
# Load active loras into lora memory pool
|
||||
# TODO (lifuhuang): The naming of `forward_batch.lora_paths` is confusing. It actually contains a set of unique
|
||||
# LoRA IDs, not LoRA paths. While unfortunately we cannot change the name in API for backward compatibility, we
|
||||
# should consider (1) renaming the incorrect usage within the system, and (2) deprecating the parameter name in
|
||||
# the current API schema and introducing a better request schema in the future (e.g., use `model_name`).
|
||||
cur_uids = set(forward_batch.lora_paths)
|
||||
assert len(cur_uids) <= self.max_loras_per_batch
|
||||
self.memory_pool.prepare_lora_batch(cur_uids, self.loras, self.lora_modules)
|
||||
@@ -233,10 +214,10 @@ class LoRAManager:
|
||||
weight_indices = [0] * len(forward_batch.lora_paths)
|
||||
lora_ranks = [0] * self.max_loras_per_batch
|
||||
scalings = [0] * self.max_loras_per_batch
|
||||
for i, lora_path in enumerate(forward_batch.lora_paths):
|
||||
weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
|
||||
if lora_path is not None:
|
||||
lora = self.loras[lora_path]
|
||||
for i, uid in enumerate(forward_batch.lora_paths):
|
||||
weight_indices[i] = self.memory_pool.get_buffer_id(uid)
|
||||
if uid is not None:
|
||||
lora = self.loras[uid]
|
||||
lora_ranks[weight_indices[i]] = lora.config.r
|
||||
scalings[weight_indices[i]] = lora.scaling
|
||||
|
||||
@@ -326,7 +307,7 @@ class LoRAManager:
|
||||
"""
|
||||
Update all LoRA modules to associate them with the latest memory buffer.
|
||||
"""
|
||||
for layer_id, layer_modules in self.lora_modules.items():
|
||||
for layer_id, layer_modules in enumerate(self.lora_modules):
|
||||
for module_name, module in layer_modules.items():
|
||||
if "qkv_proj" in module_name:
|
||||
module.set_lora_info(
|
||||
@@ -353,115 +334,94 @@ class LoRAManager:
|
||||
),
|
||||
)
|
||||
|
||||
def init_state(self):
|
||||
def init_state(
|
||||
self,
|
||||
max_lora_rank: Optional[int] = None,
|
||||
target_modules: Optional[Iterable[str]] = None,
|
||||
lora_paths: Optional[Dict[str, LoRARef]] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the internal (mutable) state of the LoRAManager.
|
||||
|
||||
These states are mutable via the `update_state_from_configs` as LoRA adapters are loaded and unloaded dynamically.
|
||||
When `lora_paths` is provided and not empty, it might be used for inferring LoRA shape info such as
|
||||
the target modules and max_lora_rank.
|
||||
"""
|
||||
|
||||
# Configs of all active LoRA adapters.
|
||||
assert lora_paths or (
|
||||
max_lora_rank is not None and target_modules is not None
|
||||
), "When no initial --lora-paths is provided, you need to specify both --max-lora-rank and --lora-target-modules for LoRA initialization."
|
||||
|
||||
self.init_lora_adapters(lora_paths)
|
||||
self.init_lora_shapes(
|
||||
max_lora_rank=max_lora_rank,
|
||||
target_modules=target_modules,
|
||||
)
|
||||
self.init_lora_weight_names()
|
||||
self.init_lora_modules()
|
||||
self.init_memory_pool()
|
||||
|
||||
def init_lora_adapters(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
|
||||
# Configs of all active LoRA adapters, indexed by LoRA ID.
|
||||
self.configs: Dict[str, LoRAConfig] = {}
|
||||
|
||||
# LoRA adapter weights cached in CPU memory.
|
||||
# LoRA adapter weights cached in CPU memory, indexed by LoRA ID.
|
||||
self.loras: Dict[str, LoRAAdapter] = {}
|
||||
|
||||
# Supported weight names (e.g., qkv_proj) for LoRA A and B respectively.
|
||||
self.lora_weight_names: Tuple[Set[str]] = (set(), set())
|
||||
# Mapping from LoRA ID to LoRARef object.
|
||||
self.lora_refs: Dict[str, LoRARef] = {}
|
||||
|
||||
# Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
|
||||
self.lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]] = {
|
||||
i: {} for i in range(self.base_hf_config.num_hidden_layers)
|
||||
}
|
||||
if lora_paths:
|
||||
for lora_ref in lora_paths.values():
|
||||
result = self.load_lora_adapter(lora_ref)
|
||||
if not result.success:
|
||||
raise RuntimeError(
|
||||
f"Failed to load LoRA adapter {lora_ref.lora_name}: {result.error_message}"
|
||||
)
|
||||
|
||||
# The LoRA memory pool that manages the GPU buffers for active LoRA weights.
|
||||
# It is initialized lazily when the first LoRA adapter is loaded.
|
||||
self.memory_pool: Optional[LoRAMemoryPool] = None
|
||||
def init_lora_shapes(
|
||||
self,
|
||||
max_lora_rank: Optional[int] = None,
|
||||
target_modules: Optional[Iterable[str]] = None,
|
||||
):
|
||||
"""Infer LoRA target modules and max_lora_rank from loaded adapters if not provided."""
|
||||
|
||||
def update_state_from_configs(self):
|
||||
"""
|
||||
Update the internal state of the LoRAManager based on the current `self.configs`. This method
|
||||
should be called whenever `self.configs` is modified (e.g., when new LoRA adapters are loaded).
|
||||
"""
|
||||
|
||||
# Loads / unloads LoRA adapters based on the latest configs.
|
||||
self.update_lora_adapters()
|
||||
# Apply the latest LoRA configurations to the internal state for inferencing.
|
||||
self.apply_lora_configs()
|
||||
|
||||
def apply_lora_configs(self):
|
||||
"""
|
||||
Apply the LoRA configurations to the base model and internal states of the LoRAManager for inferencing.
|
||||
|
||||
Notes:
|
||||
- Currently, this method is effectively only invoked during the initialization phase of the LoRAManager as
|
||||
we do not yet support dynamically updating adapter shape configs, which has a dependency on (1) FlashInfer
|
||||
LoRA backend deprecation and (2) CUDA graph recapture support. We are targeting completing these work in
|
||||
early CY25H2.
|
||||
"""
|
||||
|
||||
if self.memory_pool is None:
|
||||
# Infer max_lora_rank and target_modules if not explicitly specified in server args.
|
||||
if self.target_modules is None:
|
||||
self.target_modules = set()
|
||||
for config in self.configs.values():
|
||||
self.target_modules.update(config.target_modules)
|
||||
|
||||
if self.max_lora_rank is None:
|
||||
self.max_lora_rank = max(
|
||||
[x.hf_config["r"] for x in self.configs.values()],
|
||||
default=0,
|
||||
)
|
||||
|
||||
self.update_lora_weight_names()
|
||||
self.update_lora_modules()
|
||||
self.update_memory_buffers()
|
||||
if target_modules is not None:
|
||||
self.target_modules = set(target_modules)
|
||||
else:
|
||||
# No-op if the memory pool can support the current LoRA configurations.
|
||||
# TODO (lifuhuang): support reinitializing the memory pool when the maximum LoRA rank or target
|
||||
# module is changed once FlashInfer backend is deprecated.
|
||||
assert self.memory_pool.can_support(self.configs.values()), (
|
||||
"LoRA memory pool cannot support the current LoRA configuration. "
|
||||
"This should never happen as we should have validated adapter compatibility. "
|
||||
"Please create a Github issue to report.",
|
||||
self.target_modules = set()
|
||||
for config in self.configs.values():
|
||||
self.target_modules.update(config.target_modules)
|
||||
|
||||
if max_lora_rank is not None:
|
||||
self.max_lora_rank = max_lora_rank
|
||||
else:
|
||||
self.max_lora_rank = max(
|
||||
[x.hf_config["r"] for x in self.configs.values()],
|
||||
default=0,
|
||||
)
|
||||
|
||||
def update_lora_weight_names(self):
|
||||
def init_lora_weight_names(self):
|
||||
"""
|
||||
Add new LoRA weight names if needed based on the current `self.configs`.
|
||||
"""
|
||||
|
||||
# Target lora weight names for lora_a and lora_b modules respectively.
|
||||
lora_A, lora_B = get_normalized_lora_weight_names(self.target_modules)
|
||||
self.lora_weight_names[0].update(lora_A)
|
||||
self.lora_weight_names[1].update(lora_B)
|
||||
self.lora_weight_names: Tuple[Set[str]] = (set(lora_A), set(lora_B))
|
||||
|
||||
def update_lora_adapters(self):
|
||||
def load_lora_weights(self, lora_ref: LoRARef):
|
||||
"""
|
||||
Update the LoRA adapters in CPU memory based on the current `self.configs`.
|
||||
It loads any new adapters that are not already loaded, and unloads any adapters
|
||||
that are no longer in `self.configs` (e.g., unloaded).
|
||||
Load the weights of a LoRA adapter to CPU memory and conducts post-loading validation.
|
||||
"""
|
||||
|
||||
# Load new adapter weights to cpu
|
||||
for name, config in self.configs.items():
|
||||
if name not in self.loras:
|
||||
logger.info(f"Loading weight of LoRA adapter {name} from {config.path}")
|
||||
lora_adapter = LoRAAdapter(
|
||||
name,
|
||||
config,
|
||||
self.base_hf_config,
|
||||
self.load_config,
|
||||
self.lora_backend,
|
||||
)
|
||||
lora_adapter.initialize_weights()
|
||||
self.loras[name] = lora_adapter
|
||||
|
||||
# Clean up unused LoRA adapters, copying the list to avoid modifying the dict during iteration.
|
||||
for name in list(self.loras):
|
||||
if name not in self.configs:
|
||||
logger.info(f"Unloading LoRA adapter {name}")
|
||||
del self.loras[name]
|
||||
lora_adapter = LoRAAdapter(
|
||||
lora_ref.lora_id,
|
||||
self.configs[lora_ref.lora_id],
|
||||
self.base_hf_config,
|
||||
self.load_config,
|
||||
self.lora_backend,
|
||||
)
|
||||
lora_adapter.initialize_weights()
|
||||
self.loras[lora_ref.lora_id] = lora_adapter
|
||||
|
||||
# Additional checks for flashinfer backend
|
||||
# FIXME remove the restrictions after supporting multi-rank for flashinfer backend
|
||||
@@ -472,7 +432,7 @@ class LoRAManager:
|
||||
len(lora_dims) == 1 and len(scalings) == 1
|
||||
), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. "
|
||||
|
||||
def update_memory_buffers(self):
|
||||
def init_memory_pool(self):
|
||||
"""(Re)initialize the LoRA memory pool based on the current configurations."""
|
||||
self.memory_pool = LoRAMemoryPool(
|
||||
base_hf_config=self.base_hf_config,
|
||||
@@ -490,7 +450,12 @@ class LoRAManager:
|
||||
replace_submodule(self.base_model, module_name, lora_module)
|
||||
return lora_module
|
||||
|
||||
def update_lora_modules(self):
|
||||
def init_lora_modules(self):
|
||||
# Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
|
||||
self.lora_modules: List[Dict[str, BaseLayerWithLoRA]] = [
|
||||
{} for _ in range(self.base_hf_config.num_hidden_layers)
|
||||
]
|
||||
|
||||
# Target module names of customized layers defined in python/sglang/srt/layers
|
||||
# e.g., {"qkv_proj", "o_proj"}
|
||||
customized_target_names = get_customized_names_from_hf_names(
|
||||
@@ -511,7 +476,6 @@ class LoRAManager:
|
||||
# The module should be converted if it is included in target_names
|
||||
if module_name.split(".")[-1] in customized_target_names:
|
||||
layer_id = get_layer_id(module_name)
|
||||
if module_name not in self.lora_modules[layer_id]:
|
||||
self.lora_modules[layer_id][module_name] = self.set_lora_module(
|
||||
module_name, module
|
||||
)
|
||||
self.lora_modules[layer_id][module_name] = self.set_lora_module(
|
||||
module_name, module
|
||||
)
|
||||
|
||||
124
python/sglang/srt/lora/lora_registry.py
Normal file
124
python/sglang/srt/lora/lora_registry.py
Normal file
@@ -0,0 +1,124 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field, fields
|
||||
from typing import Dict, List, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class LoRARef:
|
||||
"""
|
||||
Reference record for a LoRA model.
|
||||
|
||||
This object guarantees a unique ``lora_id`` and may include ``lora_name`` and ``lora_path``. The ID
|
||||
eliminates conflicts from reused LoRA names or paths and can be used to generate deterministic cache
|
||||
keys (e.g., radix cache).
|
||||
"""
|
||||
|
||||
lora_id: str = field(default_factory=lambda: uuid4().hex)
|
||||
lora_name: Optional[str] = None
|
||||
lora_path: Optional[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.lora_id is None:
|
||||
raise ValueError("lora_id cannot be None")
|
||||
|
||||
def __str__(self) -> str:
|
||||
parts = [
|
||||
f"{f.name}={value}"
|
||||
for f in fields(self)
|
||||
if (value := getattr(self, f.name)) is not None
|
||||
]
|
||||
return f"{self.__class__.__name__}({', '.join(parts)})"
|
||||
|
||||
|
||||
class LoRARegistry:
|
||||
"""
|
||||
The central registry to keep track of available LoRA adapters.
|
||||
|
||||
TODO (lifuhuang): This registry is intended as the foundation for overlapped lora update. We decided
|
||||
to keep it in a separate PR to keep code review simple and to unblock the radix cache work.
|
||||
"""
|
||||
|
||||
def __init__(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
|
||||
assert lora_paths is None or all(
|
||||
isinstance(lora, LoRARef) for lora in lora_paths.values()
|
||||
), (
|
||||
"server_args.lora_paths should have been normalized to LoRARef objects during server initialization. "
|
||||
"Please file an issue if you see this error."
|
||||
)
|
||||
|
||||
# A dictionary to hold LoRARef objects, mapping from LoRA name to LoRARef.
|
||||
self._registry: Dict[str, LoRARef] = dict(lora_paths or {})
|
||||
|
||||
async def register(self, lora_ref: LoRARef):
|
||||
"""
|
||||
Register a new LoRARef object in the registry.
|
||||
|
||||
Args:
|
||||
lora_ref (LoRARef): The LoRARef object to register.
|
||||
"""
|
||||
if lora_ref.lora_name in self._registry:
|
||||
raise ValueError(
|
||||
f"LoRA with name {lora_ref.lora_name} already exists. Loaded LoRAs: {self._registry.keys()}"
|
||||
)
|
||||
self._registry[lora_ref.lora_name] = lora_ref
|
||||
|
||||
async def unregister(self, lora_name: str) -> str:
|
||||
"""
|
||||
Unregister a LoRARef object from the registry and returns the removed LoRA ID.
|
||||
|
||||
Args:
|
||||
lora_name (str): The name of the LoRA model to unregister.
|
||||
"""
|
||||
lora_ref = self._registry.get(lora_name, None)
|
||||
if lora_ref is None:
|
||||
raise ValueError(
|
||||
f"LoRA with name {lora_name} does not exist. Loaded LoRAs: {self._registry.keys()}"
|
||||
)
|
||||
del self._registry[lora_name]
|
||||
|
||||
return lora_ref.lora_id
|
||||
|
||||
async def acquire(self, lora_name: Union[str, List[str]]) -> Union[str, List[str]]:
|
||||
"""
|
||||
Queries registry for LoRA IDs based on LoRA names and start tracking the usage of the corresponding LoRA adapters
|
||||
by incrementing its counter.
|
||||
|
||||
TODO (lifuhuang): currently it only queries the registry and does not track the usage of LoRA adapters.
|
||||
"""
|
||||
|
||||
async def _acquire_single(name: str) -> str:
|
||||
lora_ref = self._registry.get(name, None)
|
||||
if lora_ref is None:
|
||||
raise ValueError(
|
||||
f"The following requested LoRA adapters are not loaded: {name}\n"
|
||||
f"Loaded adapters: {self._registry.keys()}."
|
||||
)
|
||||
# await self._counters[lora_ref.lora_id].increment()
|
||||
return lora_ref.lora_id
|
||||
|
||||
if isinstance(lora_name, str):
|
||||
lora_id = await _acquire_single(lora_name)
|
||||
return lora_id
|
||||
elif isinstance(lora_name, list):
|
||||
lora_ids = await asyncio.gather(
|
||||
*[_acquire_single(name) for name in lora_name]
|
||||
)
|
||||
return lora_ids
|
||||
else:
|
||||
raise TypeError("lora_name must be either a string or a list of strings.")
|
||||
@@ -153,7 +153,7 @@ class LoRAMemoryPool:
|
||||
self,
|
||||
cur_uids: Set[Optional[str]],
|
||||
lora_adapters: Dict[str, LoRAAdapter],
|
||||
lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]],
|
||||
lora_modules: List[Dict[str, BaseLayerWithLoRA]],
|
||||
):
|
||||
def get_available_buffer_slot():
|
||||
for buffer_id in range(self.max_loras_per_batch):
|
||||
@@ -186,7 +186,7 @@ class LoRAMemoryPool:
|
||||
uid: str,
|
||||
buffer_id: int,
|
||||
lora_adapter: LoRAAdapter,
|
||||
lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]],
|
||||
lora_modules: List[Dict[str, BaseLayerWithLoRA]],
|
||||
):
|
||||
def load_lora_weight_tensor(
|
||||
buffer_view: torch.Tensor, weight: Optional[torch.Tensor]
|
||||
|
||||
Reference in New Issue
Block a user