Introduce Stable LoRA ID System for Overlapped Updates and Prefix Caching (#8261)

This commit is contained in:
Lifu Huang
2025-07-23 00:32:16 -07:00
committed by GitHub
parent e885bfdc6a
commit 8abd3e77fe
11 changed files with 400 additions and 261 deletions

View File

@@ -16,7 +16,7 @@
# and "Punica: Multi-Tenant LoRA Serving"
import logging
from typing import Dict, Iterable, Optional, Set, Tuple
from typing import Dict, Iterable, List, Optional, Set, Tuple
import torch
@@ -26,6 +26,7 @@ from sglang.srt.lora.backend.base_backend import BaseLoRABackend, get_backend_fr
from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer
from sglang.srt.lora.lora import LoRAAdapter
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.lora.lora_registry import LoRARef
from sglang.srt.lora.mem_pool import LoRAMemoryPool
from sglang.srt.lora.utils import (
LoRABatchInfo,
@@ -55,6 +56,7 @@ class LoRAManager:
tp_rank: int = 0,
max_lora_rank: Optional[int] = None,
target_modules: Optional[Iterable[str]] = None,
lora_paths: Optional[Dict[str, LoRARef]] = None,
):
self.base_model: torch.nn.Module = base_model
self.base_hf_config: AutoConfig = base_hf_config
@@ -64,10 +66,6 @@ class LoRAManager:
self.device: torch.device = next(self.base_model.parameters()).device
self.tp_size: int = tp_size
self.tp_rank: int = tp_rank
self.max_lora_rank: Optional[int] = max_lora_rank
self.target_modules: Optional[Set[str]] = (
set(target_modules) if target_modules else None
)
# LoRA backend for running sgemm kernels
logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
@@ -75,7 +73,11 @@ class LoRAManager:
self.lora_backend: BaseLoRABackend = backend_type(lora_backend)
# Initialize mutable internal state of the LoRAManager.
self.init_state()
self.init_state(
max_lora_rank=max_lora_rank,
target_modules=target_modules,
lora_paths=lora_paths,
)
def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
@@ -112,108 +114,87 @@ class LoRAManager:
success=success,
error_message=error_message,
loaded_adapters={
name: config.path for name, config in self.configs.items()
lora_ref.lora_name: lora_ref.lora_path
for lora_ref in self.lora_refs.values()
},
)
def load_lora_adapters(self, lora_paths: Dict[str, str]) -> LoRAUpdateResult:
"""
Load LoRA adapters from the specified paths.
Args:
lora_paths (Dict[str, str]): A dictionary mapping LoRA adapter names to their file paths.
If a LoRA adapter is already loaded, it will be skipped with a warning.
"""
results = []
for lora_name, lora_path in lora_paths.items():
result = self.load_lora_adapter(lora_name, lora_path, update_state=False)
results.append(result)
self.update_state_from_configs()
return self.create_lora_update_result(
success=all(result.success for result in results),
error_message="\n".join(
result.error_message for result in results if not result.success
),
)
def load_lora_adapter(
self, lora_name: str, lora_path: str, update_state: bool = True
) -> LoRAUpdateResult:
def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
"""
Load a single LoRA adapter from the specified path.
Args:
lora_name (str): The name of the LoRA adapter.
lora_path (str): The file path to the LoRA adapter.
update_state (bool): Whether to refresh the internal state after loading the adapter. This is useful for batch loading.
lora_ref (LoRARef): The LoRARef object containing the LoRA name, path, and ID.
"""
success = True
error_message = ""
if lora_name in self.loras:
success = False
error_message = f"LoRA adapter {lora_name} is skipped as it is already loaded. If you want to reload it, please unload it first."
assert (
lora_ref.lora_name is not None and lora_ref.lora_path is not None
), "LoRARef must have both lora_name and lora_path set for loading."
assert (
lora_ref.lora_id not in self.loras
), f"LoRA adapter with ID {lora_ref.lora_id} is already loaded. This should have been verified before request is sent to the backend."
try:
new_adapter = LoRAConfig(lora_path)
self.validate_new_adapter(lora_name, new_adapter)
self.configs[lora_name] = new_adapter
# load configs
new_adapter = LoRAConfig(lora_ref.lora_path)
self.validate_new_adapter(new_adapter, lora_ref)
self.configs[lora_ref.lora_id] = new_adapter
# load weights
self.load_lora_weights(lora_ref)
# keep metadata for displayed messages
self.lora_refs[lora_ref.lora_id] = lora_ref
except Exception as e:
success = False
error_message = (
f"Failed to load LoRA adapter {lora_name} from {lora_path}: {str(e)}"
return self.create_lora_update_result(
success=False,
error_message=str(e),
)
if update_state:
self.update_state_from_configs()
return self.create_lora_update_result(success=True)
return self.create_lora_update_result(
success=success,
error_message=error_message,
)
def validate_new_adapter(self, lora_name: str, lora_config: LoRAConfig):
def validate_new_adapter(self, lora_config: LoRAConfig, lora_ref: LoRARef):
"""
Validate if an adapter can be loaded into the current LoRA memory pool and generate error if it is incompatible.
"""
incompatible = self.memory_pool and not self.memory_pool.can_support(
lora_config
)
memory_pool = getattr(self, "memory_pool", None)
incompatible = memory_pool and not memory_pool.can_support(lora_config)
if incompatible:
raise ValueError(
f"LoRA adapter {lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. "
f"LoRA adapter {lora_ref.lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. "
"Please ensure that the LoRA adapter's rank is within the configured `--max_lora_rank` and that the target modules are "
"included in `--enable_lora_modules`."
)
def unload_lora_adapter(self, lora_name: str) -> LoRAUpdateResult:
def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
"""
Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
delete the corresponding LoRA modules.
"""
success = True
error_message = ""
if lora_name in self.loras:
del self.configs[lora_name]
else:
error_message = f"LoRA adapter {lora_name} is not loaded."
success = False
adapter = self.configs.get(lora_ref.lora_id, None)
assert (
adapter is not None
), f"LoRA adapter with ID {lora_ref.lora_id} is not loaded. This should have been verified before request is sent to the backend."
self.update_state_from_configs()
try:
del self.configs[lora_ref.lora_id]
del self.loras[lora_ref.lora_id]
del self.lora_refs[lora_ref.lora_id]
except Exception as e:
return self.create_lora_update_result(
success=False,
error_message=str(e),
)
return self.create_lora_update_result(
success=success,
error_message=error_message,
)
return self.create_lora_update_result(success=True)
def prepare_lora_batch(self, forward_batch: ForwardBatch):
# load active loras into lora memory pool
# Load active loras into lora memory pool
# TODO (lifuhuang): The naming of `forward_batch.lora_paths` is confusing. It actually contains a set of unique
# LoRA IDs, not LoRA paths. While unfortunately we cannot change the name in API for backward compatibility, we
# should consider (1) renaming the incorrect usage within the system, and (2) deprecating the parameter name in
# the current API schema and introducing a better request schema in the future (e.g., use `model_name`).
cur_uids = set(forward_batch.lora_paths)
assert len(cur_uids) <= self.max_loras_per_batch
self.memory_pool.prepare_lora_batch(cur_uids, self.loras, self.lora_modules)
@@ -233,10 +214,10 @@ class LoRAManager:
weight_indices = [0] * len(forward_batch.lora_paths)
lora_ranks = [0] * self.max_loras_per_batch
scalings = [0] * self.max_loras_per_batch
for i, lora_path in enumerate(forward_batch.lora_paths):
weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
if lora_path is not None:
lora = self.loras[lora_path]
for i, uid in enumerate(forward_batch.lora_paths):
weight_indices[i] = self.memory_pool.get_buffer_id(uid)
if uid is not None:
lora = self.loras[uid]
lora_ranks[weight_indices[i]] = lora.config.r
scalings[weight_indices[i]] = lora.scaling
@@ -326,7 +307,7 @@ class LoRAManager:
"""
Update all LoRA modules to associate them with the latest memory buffer.
"""
for layer_id, layer_modules in self.lora_modules.items():
for layer_id, layer_modules in enumerate(self.lora_modules):
for module_name, module in layer_modules.items():
if "qkv_proj" in module_name:
module.set_lora_info(
@@ -353,115 +334,94 @@ class LoRAManager:
),
)
def init_state(self):
def init_state(
self,
max_lora_rank: Optional[int] = None,
target_modules: Optional[Iterable[str]] = None,
lora_paths: Optional[Dict[str, LoRARef]] = None,
):
"""
Initialize the internal (mutable) state of the LoRAManager.
These states are mutable via the `update_state_from_configs` as LoRA adapters are loaded and unloaded dynamically.
When `lora_paths` is provided and not empty, it might be used for inferring LoRA shape info such as
the target modules and max_lora_rank.
"""
# Configs of all active LoRA adapters.
assert lora_paths or (
max_lora_rank is not None and target_modules is not None
), "When no initial --lora-paths is provided, you need to specify both --max-lora-rank and --lora-target-modules for LoRA initialization."
self.init_lora_adapters(lora_paths)
self.init_lora_shapes(
max_lora_rank=max_lora_rank,
target_modules=target_modules,
)
self.init_lora_weight_names()
self.init_lora_modules()
self.init_memory_pool()
def init_lora_adapters(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
# Configs of all active LoRA adapters, indexed by LoRA ID.
self.configs: Dict[str, LoRAConfig] = {}
# LoRA adapter weights cached in CPU memory.
# LoRA adapter weights cached in CPU memory, indexed by LoRA ID.
self.loras: Dict[str, LoRAAdapter] = {}
# Supported weight names (e.g., qkv_proj) for LoRA A and B respectively.
self.lora_weight_names: Tuple[Set[str]] = (set(), set())
# Mapping from LoRA ID to LoRARef object.
self.lora_refs: Dict[str, LoRARef] = {}
# Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
self.lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]] = {
i: {} for i in range(self.base_hf_config.num_hidden_layers)
}
if lora_paths:
for lora_ref in lora_paths.values():
result = self.load_lora_adapter(lora_ref)
if not result.success:
raise RuntimeError(
f"Failed to load LoRA adapter {lora_ref.lora_name}: {result.error_message}"
)
# The LoRA memory pool that manages the GPU buffers for active LoRA weights.
# It is initialized lazily when the first LoRA adapter is loaded.
self.memory_pool: Optional[LoRAMemoryPool] = None
def init_lora_shapes(
self,
max_lora_rank: Optional[int] = None,
target_modules: Optional[Iterable[str]] = None,
):
"""Infer LoRA target modules and max_lora_rank from loaded adapters if not provided."""
def update_state_from_configs(self):
"""
Update the internal state of the LoRAManager based on the current `self.configs`. This method
should be called whenever `self.configs` is modified (e.g., when new LoRA adapters are loaded).
"""
# Loads / unloads LoRA adapters based on the latest configs.
self.update_lora_adapters()
# Apply the latest LoRA configurations to the internal state for inferencing.
self.apply_lora_configs()
def apply_lora_configs(self):
"""
Apply the LoRA configurations to the base model and internal states of the LoRAManager for inferencing.
Notes:
- Currently, this method is effectively only invoked during the initialization phase of the LoRAManager as
we do not yet support dynamically updating adapter shape configs, which has a dependency on (1) FlashInfer
LoRA backend deprecation and (2) CUDA graph recapture support. We are targeting completing these work in
early CY25H2.
"""
if self.memory_pool is None:
# Infer max_lora_rank and target_modules if not explicitly specified in server args.
if self.target_modules is None:
self.target_modules = set()
for config in self.configs.values():
self.target_modules.update(config.target_modules)
if self.max_lora_rank is None:
self.max_lora_rank = max(
[x.hf_config["r"] for x in self.configs.values()],
default=0,
)
self.update_lora_weight_names()
self.update_lora_modules()
self.update_memory_buffers()
if target_modules is not None:
self.target_modules = set(target_modules)
else:
# No-op if the memory pool can support the current LoRA configurations.
# TODO (lifuhuang): support reinitializing the memory pool when the maximum LoRA rank or target
# module is changed once FlashInfer backend is deprecated.
assert self.memory_pool.can_support(self.configs.values()), (
"LoRA memory pool cannot support the current LoRA configuration. "
"This should never happen as we should have validated adapter compatibility. "
"Please create a Github issue to report.",
self.target_modules = set()
for config in self.configs.values():
self.target_modules.update(config.target_modules)
if max_lora_rank is not None:
self.max_lora_rank = max_lora_rank
else:
self.max_lora_rank = max(
[x.hf_config["r"] for x in self.configs.values()],
default=0,
)
def update_lora_weight_names(self):
def init_lora_weight_names(self):
"""
Add new LoRA weight names if needed based on the current `self.configs`.
"""
# Target lora weight names for lora_a and lora_b modules respectively.
lora_A, lora_B = get_normalized_lora_weight_names(self.target_modules)
self.lora_weight_names[0].update(lora_A)
self.lora_weight_names[1].update(lora_B)
self.lora_weight_names: Tuple[Set[str]] = (set(lora_A), set(lora_B))
def update_lora_adapters(self):
def load_lora_weights(self, lora_ref: LoRARef):
"""
Update the LoRA adapters in CPU memory based on the current `self.configs`.
It loads any new adapters that are not already loaded, and unloads any adapters
that are no longer in `self.configs` (e.g., unloaded).
Load the weights of a LoRA adapter to CPU memory and conducts post-loading validation.
"""
# Load new adapter weights to cpu
for name, config in self.configs.items():
if name not in self.loras:
logger.info(f"Loading weight of LoRA adapter {name} from {config.path}")
lora_adapter = LoRAAdapter(
name,
config,
self.base_hf_config,
self.load_config,
self.lora_backend,
)
lora_adapter.initialize_weights()
self.loras[name] = lora_adapter
# Clean up unused LoRA adapters, copying the list to avoid modifying the dict during iteration.
for name in list(self.loras):
if name not in self.configs:
logger.info(f"Unloading LoRA adapter {name}")
del self.loras[name]
lora_adapter = LoRAAdapter(
lora_ref.lora_id,
self.configs[lora_ref.lora_id],
self.base_hf_config,
self.load_config,
self.lora_backend,
)
lora_adapter.initialize_weights()
self.loras[lora_ref.lora_id] = lora_adapter
# Additional checks for flashinfer backend
# FIXME remove the restrictions after supporting multi-rank for flashinfer backend
@@ -472,7 +432,7 @@ class LoRAManager:
len(lora_dims) == 1 and len(scalings) == 1
), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. "
def update_memory_buffers(self):
def init_memory_pool(self):
"""(Re)initialize the LoRA memory pool based on the current configurations."""
self.memory_pool = LoRAMemoryPool(
base_hf_config=self.base_hf_config,
@@ -490,7 +450,12 @@ class LoRAManager:
replace_submodule(self.base_model, module_name, lora_module)
return lora_module
def update_lora_modules(self):
def init_lora_modules(self):
# Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
self.lora_modules: List[Dict[str, BaseLayerWithLoRA]] = [
{} for _ in range(self.base_hf_config.num_hidden_layers)
]
# Target module names of customized layers defined in python/sglang/srt/layers
# e.g., {"qkv_proj", "o_proj"}
customized_target_names = get_customized_names_from_hf_names(
@@ -511,7 +476,6 @@ class LoRAManager:
# The module should be converted if it is included in target_names
if module_name.split(".")[-1] in customized_target_names:
layer_id = get_layer_id(module_name)
if module_name not in self.lora_modules[layer_id]:
self.lora_modules[layer_id][module_name] = self.set_lora_module(
module_name, module
)
self.lora_modules[layer_id][module_name] = self.set_lora_module(
module_name, module
)

View File

@@ -0,0 +1,124 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import asyncio
from dataclasses import dataclass, field, fields
from typing import Dict, List, Optional, Union
from uuid import uuid4
@dataclass(frozen=True, slots=True)
class LoRARef:
"""
Reference record for a LoRA model.
This object guarantees a unique ``lora_id`` and may include ``lora_name`` and ``lora_path``. The ID
eliminates conflicts from reused LoRA names or paths and can be used to generate deterministic cache
keys (e.g., radix cache).
"""
lora_id: str = field(default_factory=lambda: uuid4().hex)
lora_name: Optional[str] = None
lora_path: Optional[str] = None
def __post_init__(self):
if self.lora_id is None:
raise ValueError("lora_id cannot be None")
def __str__(self) -> str:
parts = [
f"{f.name}={value}"
for f in fields(self)
if (value := getattr(self, f.name)) is not None
]
return f"{self.__class__.__name__}({', '.join(parts)})"
class LoRARegistry:
"""
The central registry to keep track of available LoRA adapters.
TODO (lifuhuang): This registry is intended as the foundation for overlapped lora update. We decided
to keep it in a separate PR to keep code review simple and to unblock the radix cache work.
"""
def __init__(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
assert lora_paths is None or all(
isinstance(lora, LoRARef) for lora in lora_paths.values()
), (
"server_args.lora_paths should have been normalized to LoRARef objects during server initialization. "
"Please file an issue if you see this error."
)
# A dictionary to hold LoRARef objects, mapping from LoRA name to LoRARef.
self._registry: Dict[str, LoRARef] = dict(lora_paths or {})
async def register(self, lora_ref: LoRARef):
"""
Register a new LoRARef object in the registry.
Args:
lora_ref (LoRARef): The LoRARef object to register.
"""
if lora_ref.lora_name in self._registry:
raise ValueError(
f"LoRA with name {lora_ref.lora_name} already exists. Loaded LoRAs: {self._registry.keys()}"
)
self._registry[lora_ref.lora_name] = lora_ref
async def unregister(self, lora_name: str) -> str:
"""
Unregister a LoRARef object from the registry and returns the removed LoRA ID.
Args:
lora_name (str): The name of the LoRA model to unregister.
"""
lora_ref = self._registry.get(lora_name, None)
if lora_ref is None:
raise ValueError(
f"LoRA with name {lora_name} does not exist. Loaded LoRAs: {self._registry.keys()}"
)
del self._registry[lora_name]
return lora_ref.lora_id
async def acquire(self, lora_name: Union[str, List[str]]) -> Union[str, List[str]]:
"""
Queries registry for LoRA IDs based on LoRA names and start tracking the usage of the corresponding LoRA adapters
by incrementing its counter.
TODO (lifuhuang): currently it only queries the registry and does not track the usage of LoRA adapters.
"""
async def _acquire_single(name: str) -> str:
lora_ref = self._registry.get(name, None)
if lora_ref is None:
raise ValueError(
f"The following requested LoRA adapters are not loaded: {name}\n"
f"Loaded adapters: {self._registry.keys()}."
)
# await self._counters[lora_ref.lora_id].increment()
return lora_ref.lora_id
if isinstance(lora_name, str):
lora_id = await _acquire_single(lora_name)
return lora_id
elif isinstance(lora_name, list):
lora_ids = await asyncio.gather(
*[_acquire_single(name) for name in lora_name]
)
return lora_ids
else:
raise TypeError("lora_name must be either a string or a list of strings.")

View File

@@ -153,7 +153,7 @@ class LoRAMemoryPool:
self,
cur_uids: Set[Optional[str]],
lora_adapters: Dict[str, LoRAAdapter],
lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]],
lora_modules: List[Dict[str, BaseLayerWithLoRA]],
):
def get_available_buffer_slot():
for buffer_id in range(self.max_loras_per_batch):
@@ -186,7 +186,7 @@ class LoRAMemoryPool:
uid: str,
buffer_id: int,
lora_adapter: LoRAAdapter,
lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]],
lora_modules: List[Dict[str, BaseLayerWithLoRA]],
):
def load_lora_weight_tensor(
buffer_view: torch.Tensor, weight: Optional[torch.Tensor]

View File

@@ -22,6 +22,7 @@ from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from sglang.srt.lora.lora_registry import LoRARef
from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.multimodal.mm_utils import has_valid_data
from sglang.srt.sampling.sampling_params import SamplingParams
@@ -1067,19 +1068,36 @@ class LoadLoRAAdapterReqInput:
lora_name: str
# The path of loading.
lora_path: str
# The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`.
lora_id: Optional[str] = None
def to_ref(self) -> LoRARef:
return LoRARef(
lora_id=self.lora_id,
lora_name=self.lora_name,
lora_path=self.lora_path,
)
@dataclass
class UnloadLoRAAdapterReqInput:
# The name of lora module to unload.
lora_name: str
# The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`.
lora_id: Optional[str] = None
def to_ref(self) -> LoRARef:
return LoRARef(
lora_id=self.lora_id,
lora_name=self.lora_name,
)
@dataclass
class LoRAUpdateResult:
success: bool
error_message: Optional[str] = None
loaded_adapters: Dict[str, str] = field(default_factory=dict)
loaded_adapters: Dict[str, LoRARef] = field(default_factory=dict)
LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult

View File

@@ -247,7 +247,7 @@ class Scheduler(
self.pp_size = server_args.pp_size
self.dp_size = server_args.dp_size
self.schedule_policy = server_args.schedule_policy
self.lora_paths = server_args.lora_paths
self.enable_lora = server_args.enable_lora
self.max_loras_per_batch = server_args.max_loras_per_batch
self.enable_overlap = not server_args.disable_overlap_schedule
self.skip_tokenizer_init = server_args.skip_tokenizer_init
@@ -1706,13 +1706,13 @@ class Scheduler(
self.chunked_req.init_next_round_input()
self.chunked_req = adder.add_chunked_req(self.chunked_req)
if self.lora_paths:
if self.enable_lora:
lora_set = set([req.lora_path for req in self.running_batch.reqs])
# Get requests from the waiting queue to a new prefill batch
for req in self.waiting_queue:
if (
self.lora_paths
self.enable_lora
and len(
lora_set
| set([req.lora_path for req in adder.can_run_list])
@@ -2466,12 +2466,6 @@ class Scheduler(
"""In-place loading a new lora adapter from disk or huggingface."""
result = self.tp_worker.load_lora_adapter(recv_req)
if result.success:
flush_cache_success = self.flush_cache()
assert flush_cache_success, "Cache flush failed after loading lora adapter."
else:
logger.error(result.error_message)
return result
def unload_lora_adapter(
@@ -2480,14 +2474,6 @@ class Scheduler(
"""Unload the lora adapter."""
result = self.tp_worker.unload_lora_adapter(recv_req)
if result.success:
flush_cache_success = self.flush_cache()
assert (
flush_cache_success
), "Cache flush failed after unloading LoRA weights"
else:
logger.error(result.error_message)
return result
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):

View File

@@ -62,6 +62,7 @@ from sglang.srt.hf_transformers_utils import (
get_tokenizer,
get_tokenizer_from_processor,
)
from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
from sglang.srt.managers.io_struct import (
AbortReq,
BatchEmbeddingOut,
@@ -242,11 +243,11 @@ class TokenizerManager:
revision=server_args.revision,
)
# Initialize loaded loRA adapters with the initial lora paths in the server_args.
# This list will be updated when new LoRA adapters are loaded or unloaded dynamically.
self.loaded_lora_adapters: Dict[str, str] = dict(
self.server_args.lora_paths or {}
)
# Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
# The registry dynamically updates as adapters are loaded / unloaded during runtime. It
# serves as the source of truth for available adapters and maps user-friendly LoRA names
# to internally used unique LoRA IDs.
self.lora_registry = LoRARegistry(self.server_args.lora_paths or {})
# Store states
self.no_create_loop = False
@@ -523,6 +524,10 @@ class TokenizerManager:
else:
mm_inputs = None
if self.server_args.enable_lora and obj.lora_path:
# Replace the user-friendly LoRA names in `lora_path` with their corresponding unique LoRA IDs.
obj.lora_path = await self.lora_registry.acquire(obj.lora_path)
self._validate_one_request(obj, input_ids)
return self._create_tokenized_object(
obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
@@ -574,8 +579,6 @@ class TokenizerManager:
"The server is not configured to enable custom logit processor. "
"Please set `--enable-custom-logits-processor` to enable this feature."
)
if self.server_args.enable_lora and obj.lora_path:
self._validate_lora_adapters(obj)
def _validate_input_ids_in_vocab(
self, input_ids: List[int], vocab_size: int
@@ -689,21 +692,6 @@ class TokenizerManager:
"Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`."
)
def _validate_lora_adapters(self, obj: GenerateReqInput):
"""Validate that the requested LoRA adapters are loaded."""
requested_adapters = (
set(obj.lora_path) if isinstance(obj.lora_path, list) else {obj.lora_path}
)
loaded_adapters = (
self.loaded_lora_adapters.keys() if self.loaded_lora_adapters else set()
)
unloaded_adapters = requested_adapters - loaded_adapters
if unloaded_adapters:
raise ValueError(
f"The following requested LoRA adapters are not loaded: {unloaded_adapters}\n"
f"Loaded adapters: {loaded_adapters}."
)
def _send_one_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -1054,8 +1042,18 @@ class TokenizerManager:
)
async with self.model_update_lock.writer_lock:
# Generate new uniquely identifiable LoRARef object.
new_adapter = LoRARef(
lora_name=obj.lora_name,
lora_path=obj.lora_path,
)
# Register the new adapter in the registry.
obj.lora_id = new_adapter.lora_id
result = (await self.update_lora_adapter_communicator(obj))[0]
self.loaded_lora_adapters = result.loaded_adapters
if result.success:
await self.lora_registry.register(new_adapter)
return result
async def unload_lora_adapter(
@@ -1069,6 +1067,10 @@ class TokenizerManager:
"LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
)
assert (
obj.lora_name is not None
), "lora_name must be provided to unload LoRA adapter"
# TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
# with dp_size > 1.
assert (
@@ -1080,8 +1082,9 @@ class TokenizerManager:
)
async with self.model_update_lock.writer_lock:
obj.lora_id = await self.lora_registry.unregister(obj.lora_name)
result = (await self.update_lora_adapter_communicator(obj))[0]
self.loaded_lora_adapters = result.loaded_adapters
return result
async def get_weights_by_name(
@@ -1309,7 +1312,7 @@ class TokenizerManager:
filename = os.path.join(
self.crash_dump_folder,
os.getenv("HOSTNAME", None),
f'crash_dump_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.pkl',
f"crash_dump_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.pkl",
)
os.makedirs(os.path.dirname(filename), exist_ok=True)

View File

@@ -293,11 +293,9 @@ class TpModelWorker:
return parameter
def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
result = self.model_runner.load_lora_adapter(
recv_req.lora_name, recv_req.lora_path
)
result = self.model_runner.load_lora_adapter(recv_req.to_ref())
return result
def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
result = self.model_runner.unload_lora_adapter(recv_req.lora_name)
result = self.model_runner.unload_lora_adapter(recv_req.to_ref())
return result

View File

@@ -68,6 +68,7 @@ from sglang.srt.layers.sampler import Sampler
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
from sglang.srt.layers.utils import is_sm100_supported
from sglang.srt.lora.lora_manager import LoRAManager
from sglang.srt.lora.lora_registry import LoRARef
from sglang.srt.managers.schedule_batch import (
GLOBAL_SERVER_ARGS_KEYS,
global_server_args_dict,
@@ -890,44 +891,38 @@ class ModelRunner:
tp_rank=self.tp_rank,
max_lora_rank=self.server_args.max_lora_rank,
target_modules=self.server_args.lora_target_modules,
lora_paths=self.server_args.lora_paths,
)
result = self.lora_manager.load_lora_adapters(self.server_args.lora_paths or {})
if result.success:
logger.info(
f"LoRA manager ready. Loaded LoRA adapters: {', '.join(result.loaded_adapters)}"
)
else:
raise RuntimeError(f"Failed to load LoRA adapters: {result.error_message}")
def load_lora_adapter(self, lora_name: str, lora_path: str):
def load_lora_adapter(self, lora_ref: LoRARef):
"""Load a new lora adapter from disk or huggingface."""
logger.info(
f"LoRA adapter loading starts: name={lora_name}, path={lora_path}. "
f"LoRA adapter loading starts: {lora_ref}. "
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
)
result = self.lora_manager.load_lora_adapter(lora_name, lora_path)
result = self.lora_manager.load_lora_adapter(lora_ref)
logger.info(
f"LoRA adapter loading completes: name={lora_name}, path={lora_path}. "
f"LoRA adapter loading completes: {lora_ref}. "
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
)
return result
def unload_lora_adapter(self, lora_name: str):
def unload_lora_adapter(self, lora_ref: LoRARef):
"""Unload a lora adapter that was previously loaded during initialization or dynamic loading."""
logger.info(
f"LoRA adapter unloading starts: name={lora_name}. "
f"LoRA adapter unloading starts: {lora_ref}. "
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
)
result = self.lora_manager.unload_lora_adapter(lora_name)
result = self.lora_manager.unload_lora_adapter(lora_ref)
logger.info(
f"LoRA adapter unloading completes: name={lora_name}. "
f"LoRA adapter unloading completes: {lora_ref}. "
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
)

View File

@@ -20,10 +20,10 @@ import logging
import os
import random
import tempfile
from token import OP
from typing import List, Literal, Optional, Union
from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
from sglang.srt.lora.lora_registry import LoRARef
from sglang.srt.reasoning_parser import ReasoningParser
from sglang.srt.utils import (
LORA_TARGET_ALL_MODULES,
@@ -145,7 +145,7 @@ class ServerArgs:
enable_lora: Optional[bool] = None
max_lora_rank: Optional[int] = None
lora_target_modules: Optional[Union[set[str], List[str]]] = None
lora_paths: Optional[Union[dict[str, str], List[str]]] = None
lora_paths: Optional[Union[dict[str, str], dict[str, LoRARef], List[str]]] = None
max_loras_per_batch: int = 8
lora_backend: str = "triton"
@@ -1843,9 +1843,24 @@ class ServerArgs:
for lora_path in lora_paths:
if "=" in lora_path:
name, path = lora_path.split("=", 1)
self.lora_paths[name] = path
self.lora_paths[name] = LoRARef(lora_name=name, lora_path=path)
else:
self.lora_paths[lora_path] = lora_path
self.lora_paths[lora_path] = LoRARef(
lora_name=lora_path,
lora_path=lora_path,
)
elif isinstance(self.lora_paths, dict):
self.lora_paths = {
k: LoRARef(lora_name=k, lora_path=v)
for k, v in self.lora_paths.items()
}
elif self.lora_paths is None:
self.lora_paths = {}
else:
raise ValueError(
f"Invalid type for --lora-paths: {type(self.lora_paths)}. "
"Expected a list or a dictionary."
)
# Expand target modules
if self.lora_target_modules: