From 8abd3e77feca9ed740356c1b879e524d09482fb2 Mon Sep 17 00:00:00 2001 From: Lifu Huang Date: Wed, 23 Jul 2025 00:32:16 -0700 Subject: [PATCH] Introduce Stable LoRA ID System for Overlapped Updates and Prefix Caching (#8261) --- python/sglang/srt/lora/lora_manager.py | 306 ++++++++---------- python/sglang/srt/lora/lora_registry.py | 124 +++++++ python/sglang/srt/lora/mem_pool.py | 4 +- python/sglang/srt/managers/io_struct.py | 20 +- python/sglang/srt/managers/scheduler.py | 20 +- .../sglang/srt/managers/tokenizer_manager.py | 53 +-- python/sglang/srt/managers/tp_worker.py | 6 +- .../sglang/srt/model_executor/model_runner.py | 25 +- python/sglang/srt/server_args.py | 23 +- test/srt/models/lora/test_lora_eviction.py | 78 +++-- test/srt/run_suite.py | 2 +- 11 files changed, 400 insertions(+), 261 deletions(-) create mode 100644 python/sglang/srt/lora/lora_registry.py diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 85fd24616..719c52ef8 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -16,7 +16,7 @@ # and "Punica: Multi-Tenant LoRA Serving" import logging -from typing import Dict, Iterable, Optional, Set, Tuple +from typing import Dict, Iterable, List, Optional, Set, Tuple import torch @@ -26,6 +26,7 @@ from sglang.srt.lora.backend.base_backend import BaseLoRABackend, get_backend_fr from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer from sglang.srt.lora.lora import LoRAAdapter from sglang.srt.lora.lora_config import LoRAConfig +from sglang.srt.lora.lora_registry import LoRARef from sglang.srt.lora.mem_pool import LoRAMemoryPool from sglang.srt.lora.utils import ( LoRABatchInfo, @@ -55,6 +56,7 @@ class LoRAManager: tp_rank: int = 0, max_lora_rank: Optional[int] = None, target_modules: Optional[Iterable[str]] = None, + lora_paths: Optional[Dict[str, LoRARef]] = None, ): self.base_model: torch.nn.Module = base_model self.base_hf_config: AutoConfig = base_hf_config @@ -64,10 +66,6 @@ class LoRAManager: self.device: torch.device = next(self.base_model.parameters()).device self.tp_size: int = tp_size self.tp_rank: int = tp_rank - self.max_lora_rank: Optional[int] = max_lora_rank - self.target_modules: Optional[Set[str]] = ( - set(target_modules) if target_modules else None - ) # LoRA backend for running sgemm kernels logger.info(f"Using {lora_backend} as backend of LoRA kernels.") @@ -75,7 +73,11 @@ class LoRAManager: self.lora_backend: BaseLoRABackend = backend_type(lora_backend) # Initialize mutable internal state of the LoRAManager. - self.init_state() + self.init_state( + max_lora_rank=max_lora_rank, + target_modules=target_modules, + lora_paths=lora_paths, + ) def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int): self.max_bs_in_cuda_graph = max_bs_in_cuda_graph @@ -112,108 +114,87 @@ class LoRAManager: success=success, error_message=error_message, loaded_adapters={ - name: config.path for name, config in self.configs.items() + lora_ref.lora_name: lora_ref.lora_path + for lora_ref in self.lora_refs.values() }, ) - def load_lora_adapters(self, lora_paths: Dict[str, str]) -> LoRAUpdateResult: - """ - Load LoRA adapters from the specified paths. - - Args: - lora_paths (Dict[str, str]): A dictionary mapping LoRA adapter names to their file paths. - If a LoRA adapter is already loaded, it will be skipped with a warning. 
- """ - - results = [] - for lora_name, lora_path in lora_paths.items(): - result = self.load_lora_adapter(lora_name, lora_path, update_state=False) - results.append(result) - - self.update_state_from_configs() - - return self.create_lora_update_result( - success=all(result.success for result in results), - error_message="\n".join( - result.error_message for result in results if not result.success - ), - ) - - def load_lora_adapter( - self, lora_name: str, lora_path: str, update_state: bool = True - ) -> LoRAUpdateResult: + def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult: """ Load a single LoRA adapter from the specified path. Args: - lora_name (str): The name of the LoRA adapter. - lora_path (str): The file path to the LoRA adapter. - update_state (bool): Whether to refresh the internal state after loading the adapter. This is useful for batch loading. + lora_ref (LoRARef): The LoRARef object containing the LoRA name, path, and ID. """ - - success = True - error_message = "" - - if lora_name in self.loras: - success = False - error_message = f"LoRA adapter {lora_name} is skipped as it is already loaded. If you want to reload it, please unload it first." + assert ( + lora_ref.lora_name is not None and lora_ref.lora_path is not None + ), "LoRARef must have both lora_name and lora_path set for loading." + assert ( + lora_ref.lora_id not in self.loras + ), f"LoRA adapter with ID {lora_ref.lora_id} is already loaded. This should have been verified before request is sent to the backend." try: - new_adapter = LoRAConfig(lora_path) - self.validate_new_adapter(lora_name, new_adapter) - self.configs[lora_name] = new_adapter + # load configs + new_adapter = LoRAConfig(lora_ref.lora_path) + self.validate_new_adapter(new_adapter, lora_ref) + self.configs[lora_ref.lora_id] = new_adapter + + # load weights + self.load_lora_weights(lora_ref) + + # keep metadata for displayed messages + self.lora_refs[lora_ref.lora_id] = lora_ref except Exception as e: - success = False - error_message = ( - f"Failed to load LoRA adapter {lora_name} from {lora_path}: {str(e)}" + return self.create_lora_update_result( + success=False, + error_message=str(e), ) - if update_state: - self.update_state_from_configs() + return self.create_lora_update_result(success=True) - return self.create_lora_update_result( - success=success, - error_message=error_message, - ) - - def validate_new_adapter(self, lora_name: str, lora_config: LoRAConfig): + def validate_new_adapter(self, lora_config: LoRAConfig, lora_ref: LoRARef): """ Validate if an adapter can be loaded into the current LoRA memory pool and generate error if it is incompatible. """ - incompatible = self.memory_pool and not self.memory_pool.can_support( - lora_config - ) + memory_pool = getattr(self, "memory_pool", None) + incompatible = memory_pool and not memory_pool.can_support(lora_config) if incompatible: raise ValueError( - f"LoRA adapter {lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. " + f"LoRA adapter {lora_ref.lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. " "Please ensure that the LoRA adapter's rank is within the configured `--max_lora_rank` and that the target modules are " "included in `--enable_lora_modules`." ) - def unload_lora_adapter(self, lora_name: str) -> LoRAUpdateResult: + def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult: """ Unload LoRA adapters by their names. 
This will remove the adapters from the memory pool and delete the corresponding LoRA modules. """ - success = True - error_message = "" - if lora_name in self.loras: - del self.configs[lora_name] - else: - error_message = f"LoRA adapter {lora_name} is not loaded." - success = False + adapter = self.configs.get(lora_ref.lora_id, None) + assert ( + adapter is not None + ), f"LoRA adapter with ID {lora_ref.lora_id} is not loaded. This should have been verified before request is sent to the backend." - self.update_state_from_configs() + try: + del self.configs[lora_ref.lora_id] + del self.loras[lora_ref.lora_id] + del self.lora_refs[lora_ref.lora_id] + except Exception as e: + return self.create_lora_update_result( + success=False, + error_message=str(e), + ) - return self.create_lora_update_result( - success=success, - error_message=error_message, - ) + return self.create_lora_update_result(success=True) def prepare_lora_batch(self, forward_batch: ForwardBatch): - # load active loras into lora memory pool + # Load active loras into lora memory pool + # TODO (lifuhuang): The naming of `forward_batch.lora_paths` is confusing. It actually contains a set of unique + # LoRA IDs, not LoRA paths. While unfortunately we cannot change the name in API for backward compatibility, we + # should consider (1) renaming the incorrect usage within the system, and (2) deprecating the parameter name in + # the current API schema and introducing a better request schema in the future (e.g., use `model_name`). cur_uids = set(forward_batch.lora_paths) assert len(cur_uids) <= self.max_loras_per_batch self.memory_pool.prepare_lora_batch(cur_uids, self.loras, self.lora_modules) @@ -233,10 +214,10 @@ class LoRAManager: weight_indices = [0] * len(forward_batch.lora_paths) lora_ranks = [0] * self.max_loras_per_batch scalings = [0] * self.max_loras_per_batch - for i, lora_path in enumerate(forward_batch.lora_paths): - weight_indices[i] = self.memory_pool.get_buffer_id(lora_path) - if lora_path is not None: - lora = self.loras[lora_path] + for i, uid in enumerate(forward_batch.lora_paths): + weight_indices[i] = self.memory_pool.get_buffer_id(uid) + if uid is not None: + lora = self.loras[uid] lora_ranks[weight_indices[i]] = lora.config.r scalings[weight_indices[i]] = lora.scaling @@ -326,7 +307,7 @@ class LoRAManager: """ Update all LoRA modules to associate them with the latest memory buffer. """ - for layer_id, layer_modules in self.lora_modules.items(): + for layer_id, layer_modules in enumerate(self.lora_modules): for module_name, module in layer_modules.items(): if "qkv_proj" in module_name: module.set_lora_info( @@ -353,115 +334,94 @@ class LoRAManager: ), ) - def init_state(self): + def init_state( + self, + max_lora_rank: Optional[int] = None, + target_modules: Optional[Iterable[str]] = None, + lora_paths: Optional[Dict[str, LoRARef]] = None, + ): """ Initialize the internal (mutable) state of the LoRAManager. - These states are mutable via the `update_state_from_configs` as LoRA adapters are loaded and unloaded dynamically. + When `lora_paths` is provided and not empty, it might be used for inferring LoRA shape info such as + the target modules and max_lora_rank. """ - # Configs of all active LoRA adapters. + assert lora_paths or ( + max_lora_rank is not None and target_modules is not None + ), "When no initial --lora-paths is provided, you need to specify both --max-lora-rank and --lora-target-modules for LoRA initialization." 
+ + self.init_lora_adapters(lora_paths) + self.init_lora_shapes( + max_lora_rank=max_lora_rank, + target_modules=target_modules, + ) + self.init_lora_weight_names() + self.init_lora_modules() + self.init_memory_pool() + + def init_lora_adapters(self, lora_paths: Optional[Dict[str, LoRARef]] = None): + # Configs of all active LoRA adapters, indexed by LoRA ID. self.configs: Dict[str, LoRAConfig] = {} - # LoRA adapter weights cached in CPU memory. + # LoRA adapter weights cached in CPU memory, indexed by LoRA ID. self.loras: Dict[str, LoRAAdapter] = {} - # Supported weight names (e.g., qkv_proj) for LoRA A and B respectively. - self.lora_weight_names: Tuple[Set[str]] = (set(), set()) + # Mapping from LoRA ID to LoRARef object. + self.lora_refs: Dict[str, LoRARef] = {} - # Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module. - self.lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]] = { - i: {} for i in range(self.base_hf_config.num_hidden_layers) - } + if lora_paths: + for lora_ref in lora_paths.values(): + result = self.load_lora_adapter(lora_ref) + if not result.success: + raise RuntimeError( + f"Failed to load LoRA adapter {lora_ref.lora_name}: {result.error_message}" + ) - # The LoRA memory pool that manages the GPU buffers for active LoRA weights. - # It is initialized lazily when the first LoRA adapter is loaded. - self.memory_pool: Optional[LoRAMemoryPool] = None + def init_lora_shapes( + self, + max_lora_rank: Optional[int] = None, + target_modules: Optional[Iterable[str]] = None, + ): + """Infer LoRA target modules and max_lora_rank from loaded adapters if not provided.""" - def update_state_from_configs(self): - """ - Update the internal state of the LoRAManager based on the current `self.configs`. This method - should be called whenever `self.configs` is modified (e.g., when new LoRA adapters are loaded). - """ - - # Loads / unloads LoRA adapters based on the latest configs. - self.update_lora_adapters() - # Apply the latest LoRA configurations to the internal state for inferencing. - self.apply_lora_configs() - - def apply_lora_configs(self): - """ - Apply the LoRA configurations to the base model and internal states of the LoRAManager for inferencing. - - Notes: - - Currently, this method is effectively only invoked during the initialization phase of the LoRAManager as - we do not yet support dynamically updating adapter shape configs, which has a dependency on (1) FlashInfer - LoRA backend deprecation and (2) CUDA graph recapture support. We are targeting completing these work in - early CY25H2. - """ - - if self.memory_pool is None: - # Infer max_lora_rank and target_modules if not explicitly specified in server args. - if self.target_modules is None: - self.target_modules = set() - for config in self.configs.values(): - self.target_modules.update(config.target_modules) - - if self.max_lora_rank is None: - self.max_lora_rank = max( - [x.hf_config["r"] for x in self.configs.values()], - default=0, - ) - - self.update_lora_weight_names() - self.update_lora_modules() - self.update_memory_buffers() + if target_modules is not None: + self.target_modules = set(target_modules) else: - # No-op if the memory pool can support the current LoRA configurations. - # TODO (lifuhuang): support reinitializing the memory pool when the maximum LoRA rank or target - # module is changed once FlashInfer backend is deprecated. 
- assert self.memory_pool.can_support(self.configs.values()), ( - "LoRA memory pool cannot support the current LoRA configuration. " - "This should never happen as we should have validated adapter compatibility. " - "Please create a Github issue to report.", + self.target_modules = set() + for config in self.configs.values(): + self.target_modules.update(config.target_modules) + + if max_lora_rank is not None: + self.max_lora_rank = max_lora_rank + else: + self.max_lora_rank = max( + [x.hf_config["r"] for x in self.configs.values()], + default=0, ) - def update_lora_weight_names(self): + def init_lora_weight_names(self): """ Add new LoRA weight names if needed based on the current `self.configs`. """ # Target lora weight names for lora_a and lora_b modules respectively. lora_A, lora_B = get_normalized_lora_weight_names(self.target_modules) - self.lora_weight_names[0].update(lora_A) - self.lora_weight_names[1].update(lora_B) + self.lora_weight_names: Tuple[Set[str]] = (set(lora_A), set(lora_B)) - def update_lora_adapters(self): + def load_lora_weights(self, lora_ref: LoRARef): """ - Update the LoRA adapters in CPU memory based on the current `self.configs`. - It loads any new adapters that are not already loaded, and unloads any adapters - that are no longer in `self.configs` (e.g., unloaded). + Load the weights of a LoRA adapter to CPU memory and conducts post-loading validation. """ - - # Load new adapter weights to cpu - for name, config in self.configs.items(): - if name not in self.loras: - logger.info(f"Loading weight of LoRA adapter {name} from {config.path}") - lora_adapter = LoRAAdapter( - name, - config, - self.base_hf_config, - self.load_config, - self.lora_backend, - ) - lora_adapter.initialize_weights() - self.loras[name] = lora_adapter - - # Clean up unused LoRA adapters, copying the list to avoid modifying the dict during iteration. - for name in list(self.loras): - if name not in self.configs: - logger.info(f"Unloading LoRA adapter {name}") - del self.loras[name] + lora_adapter = LoRAAdapter( + lora_ref.lora_id, + self.configs[lora_ref.lora_id], + self.base_hf_config, + self.load_config, + self.lora_backend, + ) + lora_adapter.initialize_weights() + self.loras[lora_ref.lora_id] = lora_adapter # Additional checks for flashinfer backend # FIXME remove the restrictions after supporting multi-rank for flashinfer backend @@ -472,7 +432,7 @@ class LoRAManager: len(lora_dims) == 1 and len(scalings) == 1 ), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. " - def update_memory_buffers(self): + def init_memory_pool(self): """(Re)initialize the LoRA memory pool based on the current configurations.""" self.memory_pool = LoRAMemoryPool( base_hf_config=self.base_hf_config, @@ -490,7 +450,12 @@ class LoRAManager: replace_submodule(self.base_model, module_name, lora_module) return lora_module - def update_lora_modules(self): + def init_lora_modules(self): + # Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module. 
+ self.lora_modules: List[Dict[str, BaseLayerWithLoRA]] = [ + {} for _ in range(self.base_hf_config.num_hidden_layers) + ] + # Target module names of customized layers defined in python/sglang/srt/layers # e.g., {"qkv_proj", "o_proj"} customized_target_names = get_customized_names_from_hf_names( @@ -511,7 +476,6 @@ class LoRAManager: # The module should be converted if it is included in target_names if module_name.split(".")[-1] in customized_target_names: layer_id = get_layer_id(module_name) - if module_name not in self.lora_modules[layer_id]: - self.lora_modules[layer_id][module_name] = self.set_lora_module( - module_name, module - ) + self.lora_modules[layer_id][module_name] = self.set_lora_module( + module_name, module + ) diff --git a/python/sglang/srt/lora/lora_registry.py b/python/sglang/srt/lora/lora_registry.py new file mode 100644 index 000000000..b596c7371 --- /dev/null +++ b/python/sglang/srt/lora/lora_registry.py @@ -0,0 +1,124 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + + +import asyncio +from dataclasses import dataclass, field, fields +from typing import Dict, List, Optional, Union +from uuid import uuid4 + + +@dataclass(frozen=True, slots=True) +class LoRARef: + """ + Reference record for a LoRA model. + + This object guarantees a unique ``lora_id`` and may include ``lora_name`` and ``lora_path``. The ID + eliminates conflicts from reused LoRA names or paths and can be used to generate deterministic cache + keys (e.g., radix cache). + """ + + lora_id: str = field(default_factory=lambda: uuid4().hex) + lora_name: Optional[str] = None + lora_path: Optional[str] = None + + def __post_init__(self): + if self.lora_id is None: + raise ValueError("lora_id cannot be None") + + def __str__(self) -> str: + parts = [ + f"{f.name}={value}" + for f in fields(self) + if (value := getattr(self, f.name)) is not None + ] + return f"{self.__class__.__name__}({', '.join(parts)})" + + +class LoRARegistry: + """ + The central registry to keep track of available LoRA adapters. + + TODO (lifuhuang): This registry is intended as the foundation for overlapped lora update. We decided + to keep it in a separate PR to keep code review simple and to unblock the radix cache work. + """ + + def __init__(self, lora_paths: Optional[Dict[str, LoRARef]] = None): + assert lora_paths is None or all( + isinstance(lora, LoRARef) for lora in lora_paths.values() + ), ( + "server_args.lora_paths should have been normalized to LoRARef objects during server initialization. " + "Please file an issue if you see this error." + ) + + # A dictionary to hold LoRARef objects, mapping from LoRA name to LoRARef. + self._registry: Dict[str, LoRARef] = dict(lora_paths or {}) + + async def register(self, lora_ref: LoRARef): + """ + Register a new LoRARef object in the registry. + + Args: + lora_ref (LoRARef): The LoRARef object to register. 
+ """ + if lora_ref.lora_name in self._registry: + raise ValueError( + f"LoRA with name {lora_ref.lora_name} already exists. Loaded LoRAs: {self._registry.keys()}" + ) + self._registry[lora_ref.lora_name] = lora_ref + + async def unregister(self, lora_name: str) -> str: + """ + Unregister a LoRARef object from the registry and returns the removed LoRA ID. + + Args: + lora_name (str): The name of the LoRA model to unregister. + """ + lora_ref = self._registry.get(lora_name, None) + if lora_ref is None: + raise ValueError( + f"LoRA with name {lora_name} does not exist. Loaded LoRAs: {self._registry.keys()}" + ) + del self._registry[lora_name] + + return lora_ref.lora_id + + async def acquire(self, lora_name: Union[str, List[str]]) -> Union[str, List[str]]: + """ + Queries registry for LoRA IDs based on LoRA names and start tracking the usage of the corresponding LoRA adapters + by incrementing its counter. + + TODO (lifuhuang): currently it only queries the registry and does not track the usage of LoRA adapters. + """ + + async def _acquire_single(name: str) -> str: + lora_ref = self._registry.get(name, None) + if lora_ref is None: + raise ValueError( + f"The following requested LoRA adapters are not loaded: {name}\n" + f"Loaded adapters: {self._registry.keys()}." + ) + # await self._counters[lora_ref.lora_id].increment() + return lora_ref.lora_id + + if isinstance(lora_name, str): + lora_id = await _acquire_single(lora_name) + return lora_id + elif isinstance(lora_name, list): + lora_ids = await asyncio.gather( + *[_acquire_single(name) for name in lora_name] + ) + return lora_ids + else: + raise TypeError("lora_name must be either a string or a list of strings.") diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py index 1b36cac5e..ae856246d 100644 --- a/python/sglang/srt/lora/mem_pool.py +++ b/python/sglang/srt/lora/mem_pool.py @@ -153,7 +153,7 @@ class LoRAMemoryPool: self, cur_uids: Set[Optional[str]], lora_adapters: Dict[str, LoRAAdapter], - lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]], + lora_modules: List[Dict[str, BaseLayerWithLoRA]], ): def get_available_buffer_slot(): for buffer_id in range(self.max_loras_per_batch): @@ -186,7 +186,7 @@ class LoRAMemoryPool: uid: str, buffer_id: int, lora_adapter: LoRAAdapter, - lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]], + lora_modules: List[Dict[str, BaseLayerWithLoRA]], ): def load_lora_weight_tensor( buffer_view: torch.Tensor, weight: Optional[torch.Tensor] diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 8e1d1075a..3d18e1af4 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -22,6 +22,7 @@ from dataclasses import dataclass, field from enum import Enum from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from sglang.srt.lora.lora_registry import LoRARef from sglang.srt.managers.schedule_batch import BaseFinishReason from sglang.srt.multimodal.mm_utils import has_valid_data from sglang.srt.sampling.sampling_params import SamplingParams @@ -1067,19 +1068,36 @@ class LoadLoRAAdapterReqInput: lora_name: str # The path of loading. lora_path: str + # The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`. 
+ lora_id: Optional[str] = None + + def to_ref(self) -> LoRARef: + return LoRARef( + lora_id=self.lora_id, + lora_name=self.lora_name, + lora_path=self.lora_path, + ) @dataclass class UnloadLoRAAdapterReqInput: # The name of lora module to unload. lora_name: str + # The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`. + lora_id: Optional[str] = None + + def to_ref(self) -> LoRARef: + return LoRARef( + lora_id=self.lora_id, + lora_name=self.lora_name, + ) @dataclass class LoRAUpdateResult: success: bool error_message: Optional[str] = None - loaded_adapters: Dict[str, str] = field(default_factory=dict) + loaded_adapters: Dict[str, LoRARef] = field(default_factory=dict) LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index e6dd80d71..c3b5fc2e8 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -247,7 +247,7 @@ class Scheduler( self.pp_size = server_args.pp_size self.dp_size = server_args.dp_size self.schedule_policy = server_args.schedule_policy - self.lora_paths = server_args.lora_paths + self.enable_lora = server_args.enable_lora self.max_loras_per_batch = server_args.max_loras_per_batch self.enable_overlap = not server_args.disable_overlap_schedule self.skip_tokenizer_init = server_args.skip_tokenizer_init @@ -1706,13 +1706,13 @@ class Scheduler( self.chunked_req.init_next_round_input() self.chunked_req = adder.add_chunked_req(self.chunked_req) - if self.lora_paths: + if self.enable_lora: lora_set = set([req.lora_path for req in self.running_batch.reqs]) # Get requests from the waiting queue to a new prefill batch for req in self.waiting_queue: if ( - self.lora_paths + self.enable_lora and len( lora_set | set([req.lora_path for req in adder.can_run_list]) @@ -2466,12 +2466,6 @@ class Scheduler( """In-place loading a new lora adapter from disk or huggingface.""" result = self.tp_worker.load_lora_adapter(recv_req) - - if result.success: - flush_cache_success = self.flush_cache() - assert flush_cache_success, "Cache flush failed after loading lora adapter." - else: - logger.error(result.error_message) return result def unload_lora_adapter( @@ -2480,14 +2474,6 @@ class Scheduler( """Unload the lora adapter.""" result = self.tp_worker.unload_lora_adapter(recv_req) - - if result.success: - flush_cache_success = self.flush_cache() - assert ( - flush_cache_success - ), "Cache flush failed after unloading LoRA weights" - else: - logger.error(result.error_message) return result def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput): diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 631d23f17..0f65fa925 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -62,6 +62,7 @@ from sglang.srt.hf_transformers_utils import ( get_tokenizer, get_tokenizer_from_processor, ) +from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry from sglang.srt.managers.io_struct import ( AbortReq, BatchEmbeddingOut, @@ -242,11 +243,11 @@ class TokenizerManager: revision=server_args.revision, ) - # Initialize loaded loRA adapters with the initial lora paths in the server_args. - # This list will be updated when new LoRA adapters are loaded or unloaded dynamically. 
- self.loaded_lora_adapters: Dict[str, str] = dict( - self.server_args.lora_paths or {} - ) + # Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`. + # The registry dynamically updates as adapters are loaded / unloaded during runtime. It + # serves as the source of truth for available adapters and maps user-friendly LoRA names + # to internally used unique LoRA IDs. + self.lora_registry = LoRARegistry(self.server_args.lora_paths or {}) # Store states self.no_create_loop = False @@ -523,6 +524,10 @@ class TokenizerManager: else: mm_inputs = None + if self.server_args.enable_lora and obj.lora_path: + # Replace the user-friendly LoRA names in `lora_path` with their corresponding unique LoRA IDs. + obj.lora_path = await self.lora_registry.acquire(obj.lora_path) + self._validate_one_request(obj, input_ids) return self._create_tokenized_object( obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids @@ -574,8 +579,6 @@ class TokenizerManager: "The server is not configured to enable custom logit processor. " "Please set `--enable-custom-logits-processor` to enable this feature." ) - if self.server_args.enable_lora and obj.lora_path: - self._validate_lora_adapters(obj) def _validate_input_ids_in_vocab( self, input_ids: List[int], vocab_size: int @@ -689,21 +692,6 @@ class TokenizerManager: "Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`." ) - def _validate_lora_adapters(self, obj: GenerateReqInput): - """Validate that the requested LoRA adapters are loaded.""" - requested_adapters = ( - set(obj.lora_path) if isinstance(obj.lora_path, list) else {obj.lora_path} - ) - loaded_adapters = ( - self.loaded_lora_adapters.keys() if self.loaded_lora_adapters else set() - ) - unloaded_adapters = requested_adapters - loaded_adapters - if unloaded_adapters: - raise ValueError( - f"The following requested LoRA adapters are not loaded: {unloaded_adapters}\n" - f"Loaded adapters: {loaded_adapters}." - ) - def _send_one_request( self, obj: Union[GenerateReqInput, EmbeddingReqInput], @@ -1054,8 +1042,18 @@ class TokenizerManager: ) async with self.model_update_lock.writer_lock: + # Generate new uniquely identifiable LoRARef object. + new_adapter = LoRARef( + lora_name=obj.lora_name, + lora_path=obj.lora_path, + ) + + # Register the new adapter in the registry. + obj.lora_id = new_adapter.lora_id result = (await self.update_lora_adapter_communicator(obj))[0] - self.loaded_lora_adapters = result.loaded_adapters + if result.success: + await self.lora_registry.register(new_adapter) + return result async def unload_lora_adapter( @@ -1069,6 +1067,10 @@ class TokenizerManager: "LoRA is not enabled. Please set `--enable-lora` to enable LoRA." ) + assert ( + obj.lora_name is not None + ), "lora_name must be provided to unload LoRA adapter" + # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works # with dp_size > 1. 
assert ( @@ -1080,8 +1082,9 @@ class TokenizerManager: ) async with self.model_update_lock.writer_lock: + obj.lora_id = await self.lora_registry.unregister(obj.lora_name) result = (await self.update_lora_adapter_communicator(obj))[0] - self.loaded_lora_adapters = result.loaded_adapters + return result async def get_weights_by_name( @@ -1309,7 +1312,7 @@ class TokenizerManager: filename = os.path.join( self.crash_dump_folder, os.getenv("HOSTNAME", None), - f'crash_dump_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.pkl', + f"crash_dump_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.pkl", ) os.makedirs(os.path.dirname(filename), exist_ok=True) diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index ff20ea01e..d0939ffca 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -293,11 +293,9 @@ class TpModelWorker: return parameter def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput): - result = self.model_runner.load_lora_adapter( - recv_req.lora_name, recv_req.lora_path - ) + result = self.model_runner.load_lora_adapter(recv_req.to_ref()) return result def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput): - result = self.model_runner.unload_lora_adapter(recv_req.lora_name) + result = self.model_runner.unload_lora_adapter(recv_req.to_ref()) return result diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 4f0b1d64c..9e6d14aac 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -68,6 +68,7 @@ from sglang.srt.layers.sampler import Sampler from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model from sglang.srt.layers.utils import is_sm100_supported from sglang.srt.lora.lora_manager import LoRAManager +from sglang.srt.lora.lora_registry import LoRARef from sglang.srt.managers.schedule_batch import ( GLOBAL_SERVER_ARGS_KEYS, global_server_args_dict, @@ -890,44 +891,38 @@ class ModelRunner: tp_rank=self.tp_rank, max_lora_rank=self.server_args.max_lora_rank, target_modules=self.server_args.lora_target_modules, + lora_paths=self.server_args.lora_paths, ) - result = self.lora_manager.load_lora_adapters(self.server_args.lora_paths or {}) - if result.success: - logger.info( - f"LoRA manager ready. Loaded LoRA adapters: {', '.join(result.loaded_adapters)}" - ) - else: - raise RuntimeError(f"Failed to load LoRA adapters: {result.error_message}") - def load_lora_adapter(self, lora_name: str, lora_path: str): + def load_lora_adapter(self, lora_ref: LoRARef): """Load a new lora adapter from disk or huggingface.""" logger.info( - f"LoRA adapter loading starts: name={lora_name}, path={lora_path}. " + f"LoRA adapter loading starts: {lora_ref}. " f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" ) - result = self.lora_manager.load_lora_adapter(lora_name, lora_path) + result = self.lora_manager.load_lora_adapter(lora_ref) logger.info( - f"LoRA adapter loading completes: name={lora_name}, path={lora_path}. " + f"LoRA adapter loading completes: {lora_ref}. " f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" ) return result - def unload_lora_adapter(self, lora_name: str): + def unload_lora_adapter(self, lora_ref: LoRARef): """Unload a lora adapter that was previously loaded during initialization or dynamic loading.""" logger.info( - f"LoRA adapter unloading starts: name={lora_name}. 
" + f"LoRA adapter unloading starts: {lora_ref}. " f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" ) - result = self.lora_manager.unload_lora_adapter(lora_name) + result = self.lora_manager.unload_lora_adapter(lora_ref) logger.info( - f"LoRA adapter unloading completes: name={lora_name}. " + f"LoRA adapter unloading completes: {lora_ref}. " f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" ) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 400a1bf99..1625f2c3a 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -20,10 +20,10 @@ import logging import os import random import tempfile -from token import OP from typing import List, Literal, Optional, Union from sglang.srt.hf_transformers_utils import check_gguf_file, get_config +from sglang.srt.lora.lora_registry import LoRARef from sglang.srt.reasoning_parser import ReasoningParser from sglang.srt.utils import ( LORA_TARGET_ALL_MODULES, @@ -145,7 +145,7 @@ class ServerArgs: enable_lora: Optional[bool] = None max_lora_rank: Optional[int] = None lora_target_modules: Optional[Union[set[str], List[str]]] = None - lora_paths: Optional[Union[dict[str, str], List[str]]] = None + lora_paths: Optional[Union[dict[str, str], dict[str, LoRARef], List[str]]] = None max_loras_per_batch: int = 8 lora_backend: str = "triton" @@ -1843,9 +1843,24 @@ class ServerArgs: for lora_path in lora_paths: if "=" in lora_path: name, path = lora_path.split("=", 1) - self.lora_paths[name] = path + self.lora_paths[name] = LoRARef(lora_name=name, lora_path=path) else: - self.lora_paths[lora_path] = lora_path + self.lora_paths[lora_path] = LoRARef( + lora_name=lora_path, + lora_path=lora_path, + ) + elif isinstance(self.lora_paths, dict): + self.lora_paths = { + k: LoRARef(lora_name=k, lora_path=v) + for k, v in self.lora_paths.items() + } + elif self.lora_paths is None: + self.lora_paths = {} + else: + raise ValueError( + f"Invalid type for --lora-paths: {type(self.lora_paths)}. " + "Expected a list or a dictionary." + ) # Expand target modules if self.lora_target_modules: diff --git a/test/srt/models/lora/test_lora_eviction.py b/test/srt/models/lora/test_lora_eviction.py index e74af0a0e..b352da2d5 100644 --- a/test/srt/models/lora/test_lora_eviction.py +++ b/test/srt/models/lora/test_lora_eviction.py @@ -12,6 +12,7 @@ # limitations under the License. # ============================================================================== +import contextlib import multiprocessing as mp import unittest from typing import Dict, List, Tuple @@ -39,6 +40,16 @@ ADAPTERS = [ BASE_MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct" +@contextlib.contextmanager +def dynamically_loaded_adapter(runner, lora_path: str, lora_name: str): + """A context manager to load and automatically unload a LoRA adapter.""" + try: + runner.load_lora_adapter(lora_name=lora_name, lora_path=lora_path) + yield + finally: + runner.unload_lora_adapter(lora_name=lora_name) + + class TestLoRAEviction(CustomTestCase): def test_lora_eviction_with_different_target_modules(self): """ @@ -51,55 +62,80 @@ class TestLoRAEviction(CustomTestCase): self._run_test(ADAPTERS, output_history, reverse=False) self._run_test(ADAPTERS, output_history, reverse=True) + def test_lora_eviction_with_reused_lora_name(self): + """ + Test LoRA eviction with reused LoRA names. + + This test runs inference against two LoRA adapters with the same name to ensure that the eviction behavior + works correctly when reusing LoRA names. 
+ """ + output_history = {} + self._run_test(ADAPTERS, output_history, reuse_lora_name=True, repeat=1) + self._run_test(ADAPTERS, output_history, reuse_lora_name=False, repeat=1) + def _run_test( self, lora_paths: List[str], output_history: Dict[Tuple[str, str], str], - reverse: bool, + reverse: bool = False, repeat: int = 2, + reuse_lora_name: bool = False, ): + REUSED_LORA_NAME = "lora" max_new_tokens = 256 backend = "triton" torch_dtype = torch.float16 base_path = BASE_MODEL assert len(lora_paths) >= 2 + initial_lora_paths = lora_paths if not reuse_lora_name else None # Initialize runners with SRTRunner( base_path, torch_dtype=torch_dtype, model_type="generation", - lora_paths=lora_paths, + lora_paths=initial_lora_paths, max_loras_per_batch=1, lora_backend=backend, disable_radix_cache=True, + enable_lora=True, + max_lora_rank=256, + lora_target_modules=["all"], ) as srt_runner: adapter_sequence = lora_paths if not reverse else lora_paths[::-1] for i in range(repeat): - for j, adapter in enumerate(adapter_sequence): + for j, lora_path in enumerate(adapter_sequence): print( - f"\n========== Testing LoRA eviction with adapter '{adapter}' (#{j+1}/{len(adapter_sequence)}), reversed: {reverse}, repeat: {i+1}/{repeat} ---" + f"\n========== Testing LoRA eviction with adapter '{lora_path}' (#{j + 1}/{len(adapter_sequence)}), reuse_lora_name: {reuse_lora_name}, reversed: {reverse}, repeat: {i + 1}/{repeat} ---" ) - for prompt in PROMPTS: - print("\nprompt:\n", prompt) - srt_outputs = srt_runner.forward( - [prompt], - max_new_tokens=max_new_tokens, - lora_paths=[adapter], - ) - output = srt_outputs.output_strs[0].strip() - print("\noutput:\n", output) - prev_output = output_history.get((adapter, prompt)) - if prev_output is not None: - self.assertEqual( - prev_output, - output, - f"Output mismatch for adapter {adapter} and prompt '{prompt}' on repeat {j + 1}, previous: '{prev_output}', current: '{output}'.", + lora_name = REUSED_LORA_NAME if reuse_lora_name else lora_path + context = ( + dynamically_loaded_adapter(srt_runner, lora_path, lora_name) + if reuse_lora_name + else contextlib.nullcontext() + ) + with context: + for prompt in PROMPTS: + print("\nprompt:\n", prompt) + srt_outputs = srt_runner.forward( + [prompt], + max_new_tokens=max_new_tokens, + lora_paths=[lora_name], ) - else: - output_history[(adapter, prompt)] = output + output = srt_outputs.output_strs[0].strip() + print("\noutput:\n", output) + + prev_output = output_history.get((lora_path, prompt)) + if prev_output is not None: + self.assertEqual( + prev_output, + output, + f"Output mismatch for adapter {lora_path} and prompt '{prompt}' on repeat {j + 1}, previous: '{prev_output}', current: '{output}'.", + ) + else: + output_history[(lora_path, prompt)] = output if __name__ == "__main__": diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 0e62760ab..6a96cf598 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -14,7 +14,7 @@ class TestFile: suites = { "per-commit": [ TestFile("models/lora/test_lora.py", 200), - TestFile("models/lora/test_lora_eviction.py", 120), + TestFile("models/lora/test_lora_eviction.py", 200), TestFile("models/lora/test_lora_backend.py", 99), TestFile("models/lora/test_multi_lora_backend.py", 60), TestFile("models/lora/test_lora_cuda_graph.py", 250),
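
The core of this change is the LoRARef/LoRARegistry pair added in python/sglang/srt/lora/lora_registry.py: every adapter gets an immutable, auto-generated lora_id, and the registry maps user-facing LoRA names to those IDs so that reused names or paths no longer collide across loads and unloads. Below is a minimal usage sketch of the API introduced above; the adapter paths are placeholders and the registry is exercised directly rather than through the server.

import asyncio

from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry


async def main():
    registry = LoRARegistry()

    # Each LoRARef mints a unique, stable `lora_id` (uuid4 hex) at construction time.
    # The adapter paths below are placeholders.
    sql_adapter = LoRARef(lora_name="sql", lora_path="/adapters/sql")
    chat_adapter = LoRARef(lora_name="chat", lora_path="/adapters/chat")
    await registry.register(sql_adapter)
    await registry.register(chat_adapter)

    # `acquire` translates user-facing names into stable IDs; it accepts either a
    # single name or a list of names (as used for `GenerateReqInput.lora_path`).
    single_id = await registry.acquire("sql")
    batch_ids = await registry.acquire(["sql", "chat"])
    assert single_id == sql_adapter.lora_id
    assert batch_ids == [sql_adapter.lora_id, chat_adapter.lora_id]

    # Unregistering returns the ID that the scheduler/worker should unload.
    removed_id = await registry.unregister("sql")
    assert removed_id == sql_adapter.lora_id


if __name__ == "__main__":
    asyncio.run(main())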
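ServerArgs now normalizes every --lora-paths value into LoRARef objects before they reach the TokenizerManager and LoRAManager, whether the entries are given as name=path pairs, as bare paths, or programmatically as a dict. The following is a standalone sketch mirroring that normalization logic; the adapter paths are hypothetical.

from typing import Dict, List, Union

from sglang.srt.lora.lora_registry import LoRARef


def normalize_lora_paths(
    lora_paths: Union[List[str], Dict[str, str], None],
) -> Dict[str, LoRARef]:
    """Mirror of the ServerArgs normalization: every entry becomes a LoRARef."""
    if lora_paths is None:
        return {}
    if isinstance(lora_paths, list):
        normalized = {}
        for entry in lora_paths:
            if "=" in entry:
                name, path = entry.split("=", 1)
            else:
                name = path = entry
            normalized[name] = LoRARef(lora_name=name, lora_path=path)
        return normalized
    if isinstance(lora_paths, dict):
        return {k: LoRARef(lora_name=k, lora_path=v) for k, v in lora_paths.items()}
    raise ValueError(f"Invalid type for --lora-paths: {type(lora_paths)}.")


# Example: `--lora-paths sql=/adapters/sql /adapters/chat` (paths are placeholders).
refs = normalize_lora_paths(["sql=/adapters/sql", "/adapters/chat"])
assert set(refs) == {"sql", "/adapters/chat"}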
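On the dynamic load/unload path, the TokenizerManager mints the lora_id up front, attaches it to the request, and registers the name-to-ID mapping only after the backend reports success; the TpModelWorker reconstructs a LoRARef from the request via to_ref() before calling into ModelRunner. A condensed sketch of that hand-off follows, with the scheduler/worker round trip stubbed out and a placeholder adapter path.

import asyncio

from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
from sglang.srt.managers.io_struct import LoadLoRAAdapterReqInput


async def load_adapter(registry: LoRARegistry, lora_name: str, lora_path: str) -> LoRARef:
    # TokenizerManager side: mint a unique ID for this (name, path) pair and
    # attach it to the request before it is sent to the scheduler.
    new_adapter = LoRARef(lora_name=lora_name, lora_path=lora_path)
    req = LoadLoRAAdapterReqInput(lora_name=lora_name, lora_path=lora_path)
    req.lora_id = new_adapter.lora_id

    # Scheduler / TpModelWorker side: the request is converted back into a LoRARef
    # and handed to ModelRunner.load_lora_adapter(). The round trip is stubbed here.
    lora_ref = req.to_ref()
    backend_success = True  # stand-in for the LoRAUpdateResult returned by the worker

    # TokenizerManager side: register the name -> ID mapping only on success, so a
    # failed load never becomes visible to subsequent generate requests.
    if backend_success:
        await registry.register(new_adapter)
    return lora_ref


# Example (hypothetical adapter path):
asyncio.run(load_adapter(LoRARegistry(), "sql", "/adapters/sql"))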
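The commit title also calls out prefix caching: because lora_id is unique per loaded adapter, unlike lora_name or lora_path which can be reused, it can safely participate in cache keys. The snippet below is a toy illustration of that idea only, not SGLang's actual radix cache code.

from uuid import uuid4

# Toy prefix cache keyed by (lora_id, token prefix). Purely illustrative:
# it shows why a stable ID is a safer cache key than a reusable LoRA name.
cache = {}


def cache_key(lora_id: str, token_ids: list[int]) -> tuple:
    return (lora_id, tuple(token_ids))


# Two different adapters loaded (at different times) under the same name "sql".
first_sql_id = uuid4().hex
second_sql_id = uuid4().hex

prompt = [101, 2023, 2003]
cache[cache_key(first_sql_id, prompt)] = "kv-blocks-for-first-adapter"

# After unloading "sql" and loading a different adapter under the same name,
# the new lora_id misses the old entry instead of silently reusing stale KV state.
assert cache_key(second_sql_id, prompt) not in cache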