Introduce Stable LoRA ID System for Overlapped Updates and Prefix Caching (#8261)

Lifu Huang
2025-07-23 00:32:16 -07:00
committed by GitHub
parent e885bfdc6a
commit 8abd3e77fe
11 changed files with 400 additions and 261 deletions


@@ -16,7 +16,7 @@
# and "Punica: Multi-Tenant LoRA Serving"
import logging
from typing import Dict, Iterable, Optional, Set, Tuple
from typing import Dict, Iterable, List, Optional, Set, Tuple
import torch
@@ -26,6 +26,7 @@ from sglang.srt.lora.backend.base_backend import BaseLoRABackend, get_backend_fr
from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer
from sglang.srt.lora.lora import LoRAAdapter
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.lora.lora_registry import LoRARef
from sglang.srt.lora.mem_pool import LoRAMemoryPool
from sglang.srt.lora.utils import (
LoRABatchInfo,
@@ -55,6 +56,7 @@ class LoRAManager:
tp_rank: int = 0,
max_lora_rank: Optional[int] = None,
target_modules: Optional[Iterable[str]] = None,
lora_paths: Optional[Dict[str, LoRARef]] = None,
):
self.base_model: torch.nn.Module = base_model
self.base_hf_config: AutoConfig = base_hf_config
@@ -64,10 +66,6 @@ class LoRAManager:
self.device: torch.device = next(self.base_model.parameters()).device
self.tp_size: int = tp_size
self.tp_rank: int = tp_rank
self.max_lora_rank: Optional[int] = max_lora_rank
self.target_modules: Optional[Set[str]] = (
set(target_modules) if target_modules else None
)
# LoRA backend for running sgemm kernels
logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
@@ -75,7 +73,11 @@ class LoRAManager:
self.lora_backend: BaseLoRABackend = backend_type(lora_backend)
# Initialize mutable internal state of the LoRAManager.
self.init_state()
self.init_state(
max_lora_rank=max_lora_rank,
target_modules=target_modules,
lora_paths=lora_paths,
)
def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
@@ -112,108 +114,87 @@ class LoRAManager:
success=success,
error_message=error_message,
loaded_adapters={
name: config.path for name, config in self.configs.items()
lora_ref.lora_name: lora_ref.lora_path
for lora_ref in self.lora_refs.values()
},
)
def load_lora_adapters(self, lora_paths: Dict[str, str]) -> LoRAUpdateResult:
"""
Load LoRA adapters from the specified paths.
Args:
lora_paths (Dict[str, str]): A dictionary mapping LoRA adapter names to their file paths.
If a LoRA adapter is already loaded, it will be skipped with a warning.
"""
results = []
for lora_name, lora_path in lora_paths.items():
result = self.load_lora_adapter(lora_name, lora_path, update_state=False)
results.append(result)
self.update_state_from_configs()
return self.create_lora_update_result(
success=all(result.success for result in results),
error_message="\n".join(
result.error_message for result in results if not result.success
),
)
def load_lora_adapter(
self, lora_name: str, lora_path: str, update_state: bool = True
) -> LoRAUpdateResult:
def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
"""
Load a single LoRA adapter from the specified path.
Args:
lora_name (str): The name of the LoRA adapter.
lora_path (str): The file path to the LoRA adapter.
update_state (bool): Whether to refresh the internal state after loading the adapter. This is useful for batch loading.
lora_ref (LoRARef): The LoRARef object containing the LoRA name, path, and ID.
"""
success = True
error_message = ""
if lora_name in self.loras:
success = False
error_message = f"LoRA adapter {lora_name} is skipped as it is already loaded. If you want to reload it, please unload it first."
assert (
lora_ref.lora_name is not None and lora_ref.lora_path is not None
), "LoRARef must have both lora_name and lora_path set for loading."
assert (
lora_ref.lora_id not in self.loras
), f"LoRA adapter with ID {lora_ref.lora_id} is already loaded. This should have been verified before request is sent to the backend."
try:
new_adapter = LoRAConfig(lora_path)
self.validate_new_adapter(lora_name, new_adapter)
self.configs[lora_name] = new_adapter
# load configs
new_adapter = LoRAConfig(lora_ref.lora_path)
self.validate_new_adapter(new_adapter, lora_ref)
self.configs[lora_ref.lora_id] = new_adapter
# load weights
self.load_lora_weights(lora_ref)
# keep metadata for displayed messages
self.lora_refs[lora_ref.lora_id] = lora_ref
except Exception as e:
success = False
error_message = (
f"Failed to load LoRA adapter {lora_name} from {lora_path}: {str(e)}"
return self.create_lora_update_result(
success=False,
error_message=str(e),
)
if update_state:
self.update_state_from_configs()
return self.create_lora_update_result(success=True)
return self.create_lora_update_result(
success=success,
error_message=error_message,
)
def validate_new_adapter(self, lora_name: str, lora_config: LoRAConfig):
def validate_new_adapter(self, lora_config: LoRAConfig, lora_ref: LoRARef):
"""
Validate whether an adapter can be loaded into the current LoRA memory pool and raise an error if it is incompatible.
"""
incompatible = self.memory_pool and not self.memory_pool.can_support(
lora_config
)
memory_pool = getattr(self, "memory_pool", None)
incompatible = memory_pool and not memory_pool.can_support(lora_config)
if incompatible:
raise ValueError(
f"LoRA adapter {lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. "
f"LoRA adapter {lora_ref.lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. "
"Please ensure that the LoRA adapter's rank is within the configured `--max_lora_rank` and that the target modules are "
"included in `--enable_lora_modules`."
)
def unload_lora_adapter(self, lora_name: str) -> LoRAUpdateResult:
def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
"""
Unload a LoRA adapter identified by the given LoRARef. This removes the adapter from the memory pool and
deletes the corresponding LoRA modules.
"""
success = True
error_message = ""
if lora_name in self.loras:
del self.configs[lora_name]
else:
error_message = f"LoRA adapter {lora_name} is not loaded."
success = False
adapter = self.configs.get(lora_ref.lora_id, None)
assert (
adapter is not None
), f"LoRA adapter with ID {lora_ref.lora_id} is not loaded. This should have been verified before request is sent to the backend."
self.update_state_from_configs()
try:
del self.configs[lora_ref.lora_id]
del self.loras[lora_ref.lora_id]
del self.lora_refs[lora_ref.lora_id]
except Exception as e:
return self.create_lora_update_result(
success=False,
error_message=str(e),
)
return self.create_lora_update_result(
success=success,
error_message=error_message,
)
return self.create_lora_update_result(success=True)
def prepare_lora_batch(self, forward_batch: ForwardBatch):
# load active loras into lora memory pool
# Load active loras into lora memory pool
# TODO (lifuhuang): The naming of `forward_batch.lora_paths` is confusing. It actually contains a set of unique
# LoRA IDs, not LoRA paths. While unfortunately we cannot change the name in API for backward compatibility, we
# should consider (1) renaming the incorrect usage within the system, and (2) deprecating the parameter name in
# the current API schema and introducing a better request schema in the future (e.g., use `model_name`).
cur_uids = set(forward_batch.lora_paths)
assert len(cur_uids) <= self.max_loras_per_batch
self.memory_pool.prepare_lora_batch(cur_uids, self.loras, self.lora_modules)
@@ -233,10 +214,10 @@ class LoRAManager:
weight_indices = [0] * len(forward_batch.lora_paths)
lora_ranks = [0] * self.max_loras_per_batch
scalings = [0] * self.max_loras_per_batch
for i, lora_path in enumerate(forward_batch.lora_paths):
weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
if lora_path is not None:
lora = self.loras[lora_path]
for i, uid in enumerate(forward_batch.lora_paths):
weight_indices[i] = self.memory_pool.get_buffer_id(uid)
if uid is not None:
lora = self.loras[uid]
lora_ranks[weight_indices[i]] = lora.config.r
scalings[weight_indices[i]] = lora.scaling
@@ -326,7 +307,7 @@ class LoRAManager:
"""
Update all LoRA modules to associate them with the latest memory buffer.
"""
for layer_id, layer_modules in self.lora_modules.items():
for layer_id, layer_modules in enumerate(self.lora_modules):
for module_name, module in layer_modules.items():
if "qkv_proj" in module_name:
module.set_lora_info(
@@ -353,115 +334,94 @@ class LoRAManager:
),
)
def init_state(self):
def init_state(
self,
max_lora_rank: Optional[int] = None,
target_modules: Optional[Iterable[str]] = None,
lora_paths: Optional[Dict[str, LoRARef]] = None,
):
"""
Initialize the internal (mutable) state of the LoRAManager.
These states are mutable via the `update_state_from_configs` as LoRA adapters are loaded and unloaded dynamically.
When `lora_paths` is provided and not empty, it may be used to infer LoRA shape info such as
the target modules and max_lora_rank.
"""
# Configs of all active LoRA adapters.
assert lora_paths or (
max_lora_rank is not None and target_modules is not None
), "When no initial --lora-paths is provided, you need to specify both --max-lora-rank and --lora-target-modules for LoRA initialization."
self.init_lora_adapters(lora_paths)
self.init_lora_shapes(
max_lora_rank=max_lora_rank,
target_modules=target_modules,
)
self.init_lora_weight_names()
self.init_lora_modules()
self.init_memory_pool()
def init_lora_adapters(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
# Configs of all active LoRA adapters, indexed by LoRA ID.
self.configs: Dict[str, LoRAConfig] = {}
# LoRA adapter weights cached in CPU memory.
# LoRA adapter weights cached in CPU memory, indexed by LoRA ID.
self.loras: Dict[str, LoRAAdapter] = {}
# Supported weight names (e.g., qkv_proj) for LoRA A and B respectively.
self.lora_weight_names: Tuple[Set[str]] = (set(), set())
# Mapping from LoRA ID to LoRARef object.
self.lora_refs: Dict[str, LoRARef] = {}
# Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
self.lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]] = {
i: {} for i in range(self.base_hf_config.num_hidden_layers)
}
if lora_paths:
for lora_ref in lora_paths.values():
result = self.load_lora_adapter(lora_ref)
if not result.success:
raise RuntimeError(
f"Failed to load LoRA adapter {lora_ref.lora_name}: {result.error_message}"
)
# The LoRA memory pool that manages the GPU buffers for active LoRA weights.
# It is initialized lazily when the first LoRA adapter is loaded.
self.memory_pool: Optional[LoRAMemoryPool] = None
def init_lora_shapes(
self,
max_lora_rank: Optional[int] = None,
target_modules: Optional[Iterable[str]] = None,
):
"""Infer LoRA target modules and max_lora_rank from loaded adapters if not provided."""
def update_state_from_configs(self):
"""
Update the internal state of the LoRAManager based on the current `self.configs`. This method
should be called whenever `self.configs` is modified (e.g., when new LoRA adapters are loaded).
"""
# Loads / unloads LoRA adapters based on the latest configs.
self.update_lora_adapters()
# Apply the latest LoRA configurations to the internal state for inferencing.
self.apply_lora_configs()
def apply_lora_configs(self):
"""
Apply the LoRA configurations to the base model and internal states of the LoRAManager for inferencing.
Notes:
- Currently, this method is effectively only invoked during the initialization phase of the LoRAManager as
we do not yet support dynamically updating adapter shape configs, which has a dependency on (1) FlashInfer
LoRA backend deprecation and (2) CUDA graph recapture support. We aim to complete this work in
early CY25H2.
"""
if self.memory_pool is None:
# Infer max_lora_rank and target_modules if not explicitly specified in server args.
if self.target_modules is None:
self.target_modules = set()
for config in self.configs.values():
self.target_modules.update(config.target_modules)
if self.max_lora_rank is None:
self.max_lora_rank = max(
[x.hf_config["r"] for x in self.configs.values()],
default=0,
)
self.update_lora_weight_names()
self.update_lora_modules()
self.update_memory_buffers()
if target_modules is not None:
self.target_modules = set(target_modules)
else:
# No-op if the memory pool can support the current LoRA configurations.
# TODO (lifuhuang): support reinitializing the memory pool when the maximum LoRA rank or target
# module is changed once FlashInfer backend is deprecated.
assert self.memory_pool.can_support(self.configs.values()), (
"LoRA memory pool cannot support the current LoRA configuration. "
"This should never happen as we should have validated adapter compatibility. "
"Please create a Github issue to report.",
self.target_modules = set()
for config in self.configs.values():
self.target_modules.update(config.target_modules)
if max_lora_rank is not None:
self.max_lora_rank = max_lora_rank
else:
self.max_lora_rank = max(
[x.hf_config["r"] for x in self.configs.values()],
default=0,
)
def update_lora_weight_names(self):
def init_lora_weight_names(self):
"""
Add new LoRA weight names if needed based on the current `self.configs`.
"""
# Target lora weight names for lora_a and lora_b modules respectively.
lora_A, lora_B = get_normalized_lora_weight_names(self.target_modules)
self.lora_weight_names[0].update(lora_A)
self.lora_weight_names[1].update(lora_B)
self.lora_weight_names: Tuple[Set[str]] = (set(lora_A), set(lora_B))
def update_lora_adapters(self):
def load_lora_weights(self, lora_ref: LoRARef):
"""
Update the LoRA adapters in CPU memory based on the current `self.configs`.
It loads any new adapters that are not already loaded, and unloads any adapters
that are no longer in `self.configs` (e.g., unloaded).
Load the weights of a LoRA adapter into CPU memory and conduct post-loading validation.
"""
# Load new adapter weights to cpu
for name, config in self.configs.items():
if name not in self.loras:
logger.info(f"Loading weight of LoRA adapter {name} from {config.path}")
lora_adapter = LoRAAdapter(
name,
config,
self.base_hf_config,
self.load_config,
self.lora_backend,
)
lora_adapter.initialize_weights()
self.loras[name] = lora_adapter
# Clean up unused LoRA adapters, copying the list to avoid modifying the dict during iteration.
for name in list(self.loras):
if name not in self.configs:
logger.info(f"Unloading LoRA adapter {name}")
del self.loras[name]
lora_adapter = LoRAAdapter(
lora_ref.lora_id,
self.configs[lora_ref.lora_id],
self.base_hf_config,
self.load_config,
self.lora_backend,
)
lora_adapter.initialize_weights()
self.loras[lora_ref.lora_id] = lora_adapter
# Additional checks for flashinfer backend
# FIXME remove the restrictions after supporting multi-rank for flashinfer backend
@@ -472,7 +432,7 @@ class LoRAManager:
len(lora_dims) == 1 and len(scalings) == 1
), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. "
def update_memory_buffers(self):
def init_memory_pool(self):
"""(Re)initialize the LoRA memory pool based on the current configurations."""
self.memory_pool = LoRAMemoryPool(
base_hf_config=self.base_hf_config,
@@ -490,7 +450,12 @@ class LoRAManager:
replace_submodule(self.base_model, module_name, lora_module)
return lora_module
def update_lora_modules(self):
def init_lora_modules(self):
# Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
self.lora_modules: List[Dict[str, BaseLayerWithLoRA]] = [
{} for _ in range(self.base_hf_config.num_hidden_layers)
]
# Target module names of customized layers defined in python/sglang/srt/layers
# e.g., {"qkv_proj", "o_proj"}
customized_target_names = get_customized_names_from_hf_names(
@@ -511,7 +476,6 @@ class LoRAManager:
# The module should be converted if it is included in target_names
if module_name.split(".")[-1] in customized_target_names:
layer_id = get_layer_id(module_name)
if module_name not in self.lora_modules[layer_id]:
self.lora_modules[layer_id][module_name] = self.set_lora_module(
module_name, module
)
self.lora_modules[layer_id][module_name] = self.set_lora_module(
module_name, module
)
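The lora_manager.py changes above re-key the manager's internal tables (self.configs, self.loras, self.lora_refs) by the generated LoRA ID and thread an optional `lora_paths: Dict[str, LoRARef]` argument through `init_state`. Below is a minimal sketch of how a caller might build such a mapping; the adapter names and paths are illustrative and not taken from this commit.

from sglang.srt.lora.lora_registry import LoRARef

# Initial adapters keyed by their user-facing names; each LoRARef mints a unique lora_id.
lora_paths = {
    "sql-adapter": LoRARef(lora_name="sql-adapter", lora_path="/adapters/sql"),
    "chat-adapter": LoRARef(lora_name="chat-adapter", lora_path="/adapters/chat"),
}

# The IDs, not the names, are what LoRAManager uses to index configs, weights, and buffers.
for name, ref in lora_paths.items():
    print(f"{name} -> {ref.lora_id}")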


@@ -0,0 +1,124 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import asyncio
from dataclasses import dataclass, field, fields
from typing import Dict, List, Optional, Union
from uuid import uuid4
@dataclass(frozen=True, slots=True)
class LoRARef:
"""
Reference record for a LoRA model.
This object guarantees a unique ``lora_id`` and may include ``lora_name`` and ``lora_path``. The ID
eliminates conflicts from reused LoRA names or paths and can be used to generate deterministic cache
keys (e.g., radix cache).
"""
lora_id: str = field(default_factory=lambda: uuid4().hex)
lora_name: Optional[str] = None
lora_path: Optional[str] = None
def __post_init__(self):
if self.lora_id is None:
raise ValueError("lora_id cannot be None")
def __str__(self) -> str:
parts = [
f"{f.name}={value}"
for f in fields(self)
if (value := getattr(self, f.name)) is not None
]
return f"{self.__class__.__name__}({', '.join(parts)})"
class LoRARegistry:
"""
The central registry to keep track of available LoRA adapters.
TODO (lifuhuang): This registry is intended as the foundation for overlapped LoRA updates. We decided
to keep it in a separate PR to keep code review simple and to unblock the radix cache work.
"""
def __init__(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
assert lora_paths is None or all(
isinstance(lora, LoRARef) for lora in lora_paths.values()
), (
"server_args.lora_paths should have been normalized to LoRARef objects during server initialization. "
"Please file an issue if you see this error."
)
# A dictionary to hold LoRARef objects, mapping from LoRA name to LoRARef.
self._registry: Dict[str, LoRARef] = dict(lora_paths or {})
async def register(self, lora_ref: LoRARef):
"""
Register a new LoRARef object in the registry.
Args:
lora_ref (LoRARef): The LoRARef object to register.
"""
if lora_ref.lora_name in self._registry:
raise ValueError(
f"LoRA with name {lora_ref.lora_name} already exists. Loaded LoRAs: {self._registry.keys()}"
)
self._registry[lora_ref.lora_name] = lora_ref
async def unregister(self, lora_name: str) -> str:
"""
Unregister a LoRARef object from the registry and return the removed LoRA ID.
Args:
lora_name (str): The name of the LoRA model to unregister.
"""
lora_ref = self._registry.get(lora_name, None)
if lora_ref is None:
raise ValueError(
f"LoRA with name {lora_name} does not exist. Loaded LoRAs: {self._registry.keys()}"
)
del self._registry[lora_name]
return lora_ref.lora_id
async def acquire(self, lora_name: Union[str, List[str]]) -> Union[str, List[str]]:
"""
Query the registry for LoRA IDs based on LoRA names and start tracking the usage of the corresponding
LoRA adapters by incrementing their counters.
TODO (lifuhuang): currently it only queries the registry and does not track the usage of LoRA adapters.
"""
async def _acquire_single(name: str) -> str:
lora_ref = self._registry.get(name, None)
if lora_ref is None:
raise ValueError(
f"The following requested LoRA adapters are not loaded: {name}\n"
f"Loaded adapters: {self._registry.keys()}."
)
# await self._counters[lora_ref.lora_id].increment()
return lora_ref.lora_id
if isinstance(lora_name, str):
lora_id = await _acquire_single(lora_name)
return lora_id
elif isinstance(lora_name, list):
lora_ids = await asyncio.gather(
*[_acquire_single(name) for name in lora_name]
)
return lora_ids
else:
raise TypeError("lora_name must be either a string or a list of strings.")
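A short usage sketch of the registry above, assuming the module path implied by the import in lora_manager.py; the adapter name and path are illustrative. Requests carry human-readable names, and `acquire` resolves them to the stable IDs that the rest of the system keys on.

import asyncio

from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry

async def main() -> None:
    registry = LoRARegistry()

    # Dynamic load: register a LoRARef under its user-facing name.
    ref = LoRARef(lora_name="sql-adapter", lora_path="/adapters/sql")
    await registry.register(ref)

    # Incoming request: resolve the name to the stable ID.
    lora_id = await registry.acquire("sql-adapter")
    assert lora_id == ref.lora_id

    # Batch form: a list of names resolves to a list of IDs.
    assert await registry.acquire(["sql-adapter"]) == [ref.lora_id]

    # Dynamic unload: remove the mapping and get back the retired ID.
    assert await registry.unregister("sql-adapter") == ref.lora_id

asyncio.run(main())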


@@ -153,7 +153,7 @@ class LoRAMemoryPool:
self,
cur_uids: Set[Optional[str]],
lora_adapters: Dict[str, LoRAAdapter],
lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]],
lora_modules: List[Dict[str, BaseLayerWithLoRA]],
):
def get_available_buffer_slot():
for buffer_id in range(self.max_loras_per_batch):
@@ -186,7 +186,7 @@ class LoRAMemoryPool:
uid: str,
buffer_id: int,
lora_adapter: LoRAAdapter,
lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]],
lora_modules: List[Dict[str, BaseLayerWithLoRA]],
):
def load_lora_weight_tensor(
buffer_view: torch.Tensor, weight: Optional[torch.Tensor]