From 8675bdf2461550a19192c334a46d55c58f314dbf Mon Sep 17 00:00:00 2001
From: Lifu Huang
Date: Sun, 3 Aug 2025 00:02:23 -0700
Subject: [PATCH] Support limiting max loaded loras in CPU. (#8650)

---
 docs/backend/lora.ipynb | 2 +
 docs/backend/server_arguments.md | 1 +
 python/sglang/srt/lora/lora_registry.py | 7 +
 python/sglang/srt/managers/io_struct.py | 2 +-
 .../sglang/srt/managers/tokenizer_manager.py | 132 ++++++++++--------
 python/sglang/srt/server_args.py | 20 +++
 python/sglang/test/runners.py | 2 +
 test/srt/models/lora/test_lora_update.py | 55 +++++++-
 8 files changed, 163 insertions(+), 58 deletions(-)

diff --git a/docs/backend/lora.ipynb b/docs/backend/lora.ipynb
index 8626d3e71..4967b9c75 100644
--- a/docs/backend/lora.ipynb
+++ b/docs/backend/lora.ipynb
@@ -33,6 +33,8 @@
     "\n",
     "* `max_loras_per_batch`: Maximum number of adaptors used by each batch. This argument can affect the amount of GPU memory reserved for multi-LoRA serving, so it should be set to a smaller value when memory is scarce. Defaults to be 8.\n",
     "\n",
+    "* `max_loaded_loras`: If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `max-loras-per-batch`.\n",
+    "\n",
     "* `lora_backend`: The backend of running GEMM kernels for Lora modules. It can be one of `triton` or `flashinfer`, and set to `triton` by default. For better performance and stability, we recommend using the Triton LoRA backend. In the future, faster backend built upon Cutlass or Cuda kernels will be added.\n",
     "\n",
     "* `max_lora_rank`: The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup.\n",
diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md
index 047458123..bff9dbcdc 100644
--- a/docs/backend/server_arguments.md
+++ b/docs/backend/server_arguments.md
@@ -181,6 +181,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | `--lora-target-modules` | The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup. You can also set it to `all` to enable LoRA for all supported modules. However, enabling LoRA on additional modules introduces a minor performance overhead. If your application is performance-sensitive, we recommend only specifying the modules for which you plan to load adapters. | None |
 | `--lora-paths` | The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}. | None |
 | `--max-loras-per-batch` | Maximum number of adapters for a running batch, include base-only request. | 8 |
+| `--max-loaded-loras` | If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `--max-loras-per-batch`. | None |
 | `--lora-backend` | Choose the kernel backend for multi-LoRA serving. | triton |
 
 ## Kernel backend
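For reference, a minimal usage sketch (not part of the patch) of how the new limit surfaces through the dynamic-loading HTTP endpoints. It assumes a server launched with `--enable-lora --max-loras-per-batch 2 --max-loaded-loras 2` listening on `127.0.0.1:30000`; the adapter names are arbitrary labels, and the response fields are assumed to mirror `LoRAUpdateResult`:

```python
import requests

BASE_URL = "http://127.0.0.1:30000"  # assumed local server started with --max-loaded-loras 2


def load_adapter(name: str, path: str) -> dict:
    # /load_lora_adapter takes the same fields as LoadLoRAAdapterReqInput.
    resp = requests.post(
        f"{BASE_URL}/load_lora_adapter",
        json={"lora_name": name, "lora_path": path},
    )
    return resp.json()


# Loading up to max_loaded_loras adapters succeeds.
print(load_adapter("sql", "philschmid/code-llama-3-1-8b-text-to-sql-lora"))
print(load_adapter("alpha16", "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"))

# A third adapter exceeds the CPU limit; the result reports success=False
# and error_message explains the capacity violation.
result = load_adapter("ocr", "pbevan11/llama-3.1-8b-ocr-correction")
print(result.get("success"), result.get("error_message"))

# Unloading an adapter frees a slot, after which loading succeeds again.
requests.post(f"{BASE_URL}/unload_lora_adapter", json={"lora_name": "alpha16"})
print(load_adapter("ocr", "pbevan11/llama-3.1-8b-ocr-correction"))
```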
diff --git a/python/sglang/srt/lora/lora_registry.py b/python/sglang/srt/lora/lora_registry.py
index c063fefae..bb2fc5659 100644
--- a/python/sglang/srt/lora/lora_registry.py
+++ b/python/sglang/srt/lora/lora_registry.py
@@ -186,3 +186,10 @@ class LoRARegistry:
         self._registry[lora_ref.lora_name] = lora_ref
         self._counters[lora_ref.lora_id] = ConcurrentCounter()
         return lora_ref
+
+    @property
+    def num_registered_loras(self) -> int:
+        """
+        Returns the total number of LoRA adapters currently registered.
+        """
+        return len(self._registry)
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index c8d325f9e..2b5f19c71 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -1097,7 +1097,7 @@ class UnloadLoRAAdapterReqInput:
 class LoRAUpdateResult:
     success: bool
     error_message: Optional[str] = None
-    loaded_adapters: Dict[str, LoRARef] = field(default_factory=dict)
+    loaded_adapters: Optional[Dict[str, LoRARef]] = None
 
 
 LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
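A rough sketch (not part of the patch) of how the two small changes above are meant to be consumed: `num_registered_loras` gives callers a cheap way to check a CPU-side cap, and because `loaded_adapters` is now `Optional`, failure results no longer carry an adapter dict. The helper functions below are made up for illustration; only the imported classes come from the patch:

```python
from typing import Optional

from sglang.srt.lora.lora_registry import LoRARegistry
from sglang.srt.managers.io_struct import LoRAUpdateResult


def has_capacity(registry: LoRARegistry, max_loaded_loras: Optional[int]) -> bool:
    # Illustrative: a new adapter may only be loaded while the registry holds
    # fewer than max_loaded_loras entries; None means "no limit".
    return max_loaded_loras is None or registry.num_registered_loras < max_loaded_loras


def describe(result: LoRAUpdateResult) -> str:
    # Illustrative: with loaded_adapters now Optional, branch on success
    # before touching the adapter dict.
    if not result.success:
        return f"LoRA update failed: {result.error_message}"
    return f"{len(result.loaded_adapters or {})} adapter(s) loaded"
```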
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 76a31e334..89326bf06 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -1084,76 +1084,98 @@ class TokenizerManager:
         _: Optional[fastapi.Request] = None,
     ) -> LoadLoRAAdapterReqOutput:
         self.auto_create_handle_loop()
-        if not self.server_args.enable_lora:
-            raise ValueError(
-                "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
-            )
-
-        # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
-        # with dp_size > 1.
-        assert (
-            self.server_args.dp_size == 1
-        ), "dp_size must be 1 for dynamic lora loading"
-        logger.info(
-            "Start load Lora adapter. Lora name=%s, path=%s",
-            obj.lora_name,
-            obj.lora_path,
-        )
-
-        async with self.lora_update_lock:
-            # Generate new uniquely identifiable LoRARef object.
-            new_adapter = LoRARef(
-                lora_name=obj.lora_name,
-                lora_path=obj.lora_path,
-            )
-
-            # Trigger the actual loading operation at the backend processes.
-            obj.lora_id = new_adapter.lora_id
-            result = (await self.update_lora_adapter_communicator(obj))[0]
-
-            # Register the LoRA adapter only after loading is successful.
-            if result.success:
-                await self.lora_registry.register(new_adapter)
-
-            return result
+
+        try:
+            if not self.server_args.enable_lora:
+                raise ValueError(
+                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
+                )
+
+            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
+            # with dp_size > 1.
+            assert (
+                self.server_args.dp_size == 1
+            ), "dp_size must be 1 for dynamic lora loading"
+            logger.info(
+                "Start load Lora adapter. Lora name=%s, path=%s",
+                obj.lora_name,
+                obj.lora_path,
+            )
+
+            async with self.lora_update_lock:
+                if (
+                    self.server_args.max_loaded_loras is not None
+                    and self.lora_registry.num_registered_loras
+                    >= self.server_args.max_loaded_loras
+                ):
+                    raise ValueError(
+                        f"Cannot load LoRA adapter {obj.lora_name} at path {obj.lora_path}. "
+                        f"Maximum number of loaded LoRA adapters is {self.server_args.max_loaded_loras}. "
+                        "Please unload some LoRA adapters before loading new ones."
+                    )
+
+                # Generate new uniquely identifiable LoRARef object.
+                new_adapter = LoRARef(
+                    lora_name=obj.lora_name,
+                    lora_path=obj.lora_path,
+                )
+
+                # Trigger the actual loading operation at the backend processes.
+                obj.lora_id = new_adapter.lora_id
+                result = (await self.update_lora_adapter_communicator(obj))[0]
+
+                # Register the LoRA adapter only after loading is successful.
+                if result.success:
+                    await self.lora_registry.register(new_adapter)
+
+                return result
+        except ValueError as e:
+            return LoadLoRAAdapterReqOutput(
+                success=False,
+                error_message=str(e),
+            )
 
     async def unload_lora_adapter(
         self,
         obj: UnloadLoRAAdapterReqInput,
         _: Optional[fastapi.Request] = None,
     ) -> UnloadLoRAAdapterReqOutput:
         self.auto_create_handle_loop()
-        if not self.server_args.enable_lora:
-            raise ValueError(
-                "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
-            )
-
-        assert (
-            obj.lora_name is not None
-        ), "lora_name must be provided to unload LoRA adapter"
-
-        # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
-        # with dp_size > 1.
-        assert (
-            self.server_args.dp_size == 1
-        ), "dp_size must be 1 for dynamic lora loading"
-        logger.info(
-            "Start unload Lora adapter. Lora name=%s",
-            obj.lora_name,
-        )
-
-        async with self.lora_update_lock:
-            # Unregister the LoRA adapter from the registry to stop new requests for this adapter
-            # from being started.
-            lora_id = await self.lora_registry.unregister(obj.lora_name)
-            obj.lora_id = lora_id
-
-            # Initiate the actual unloading operation at the backend processes only after all
-            # ongoing requests using this LoRA adapter are finished.
-            await self.lora_registry.wait_for_unload(lora_id)
-            result = (await self.update_lora_adapter_communicator(obj))[0]
-
-            return result
+
+        try:
+            if not self.server_args.enable_lora:
+                raise ValueError(
+                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
+                )
+
+            assert (
+                obj.lora_name is not None
+            ), "lora_name must be provided to unload LoRA adapter"
+
+            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
+            # with dp_size > 1.
+            assert (
+                self.server_args.dp_size == 1
+            ), "dp_size must be 1 for dynamic lora loading"
+            logger.info(
+                "Start unload Lora adapter. Lora name=%s",
+                obj.lora_name,
+            )
+
+            async with self.lora_update_lock:
+                # Unregister the LoRA adapter from the registry to stop new requests for this adapter
+                # from being started.
+                lora_id = await self.lora_registry.unregister(obj.lora_name)
+                obj.lora_id = lora_id
+
+                # Initiate the actual unloading operation at the backend processes only after all
+                # ongoing requests using this LoRA adapter are finished.
+                await self.lora_registry.wait_for_unload(lora_id)
+                result = (await self.update_lora_adapter_communicator(obj))[0]
+
+                return result
+        except ValueError as e:
+            return UnloadLoRAAdapterReqOutput(success=False, error_message=str(e))
 
     async def get_weights_by_name(
         self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
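One consequence of the refactor above is that configuration and capacity problems now come back as a failed `LoRAUpdateResult` instead of a raised exception, so a caller that wants eviction-on-full behaviour has to implement it itself. Below is a rough sketch of such a policy against the `TokenizerManager` methods shown above; `pick_victim` is a hypothetical callback, and the request dataclasses are assumed to be constructible from just a name and path:

```python
from sglang.srt.managers.io_struct import (
    LoadLoRAAdapterReqInput,
    UnloadLoRAAdapterReqInput,
)


async def load_with_eviction(tokenizer_manager, name: str, path: str, pick_victim):
    """Try to load an adapter; if the CPU pool is full, evict one and retry once."""
    result = await tokenizer_manager.load_lora_adapter(
        LoadLoRAAdapterReqInput(lora_name=name, lora_path=path)
    )
    if result.success:
        return result

    # The capacity violation is reported through error_message rather than raised.
    if "Maximum number of loaded LoRA adapters" in (result.error_message or ""):
        victim = pick_victim()  # hypothetical: an adapter the caller no longer needs
        await tokenizer_manager.unload_lora_adapter(
            UnloadLoRAAdapterReqInput(lora_name=victim)
        )
        result = await tokenizer_manager.load_lora_adapter(
            LoadLoRAAdapterReqInput(lora_name=name, lora_path=path)
        )
    return result
```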
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 3b52f5801..7f3fd88b1 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -149,6 +149,7 @@ class ServerArgs:
     max_lora_rank: Optional[int] = None
     lora_target_modules: Optional[Union[set[str], List[str]]] = None
     lora_paths: Optional[Union[dict[str, str], dict[str, LoRARef], List[str]]] = None
+    max_loaded_loras: Optional[int] = None
     max_loras_per_batch: int = 8
     lora_backend: str = "triton"
 
@@ -1237,6 +1238,12 @@ class ServerArgs:
             default=8,
             help="Maximum number of adapters for a running batch, include base-only request.",
         )
+        parser.add_argument(
+            "--max-loaded-loras",
+            type=int,
+            default=ServerArgs.max_loaded_loras,
+            help="If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `--max-loras-per-batch`.",
+        )
         parser.add_argument(
             "--lora-backend",
             type=str,
@@ -2008,6 +2015,19 @@ class ServerArgs:
                 self.max_lora_rank and self.lora_target_modules
             ), "When no initial --lora-paths is provided, you need to specify both --max-lora-rank and --lora-target-modules for LoRA initialization."
 
+        # Validate max_loaded_loras
+        if self.max_loaded_loras is not None:
+            assert self.max_loaded_loras >= self.max_loras_per_batch, (
+                "max_loaded_loras should be greater than or equal to max_loras_per_batch. "
+                f"max_loaded_loras={self.max_loaded_loras}, max_loras_per_batch={self.max_loras_per_batch}"
+            )
+            assert (
+                not self.lora_paths or len(self.lora_paths) <= self.max_loaded_loras
+            ), (
+                "The number of LoRA paths should not exceed max_loaded_loras. 
" + f"max_loaded_loras={self.max_loaded_loras}, lora_paths={len(self.lora_paths)}" + ) + def validate_disagg_tp_size(self, prefill_tp: int, decode_tp: int): larger_tp = max(decode_tp, prefill_tp) smaller_tp = min(decode_tp, prefill_tp) diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index 129b4ebb9..ee49584a0 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -514,6 +514,7 @@ class SRTRunner: max_lora_rank: Optional[int] = None, lora_target_modules: Optional[List[str]] = None, enable_lora: Optional[bool] = None, + max_loaded_loras: Optional[int] = None, ): self.model_type = model_type self.is_generation = model_type == "generation" @@ -556,6 +557,7 @@ class SRTRunner: max_lora_rank=max_lora_rank, lora_target_modules=lora_target_modules, enable_lora=enable_lora, + max_loaded_loras=max_loaded_loras, **spec_kwargs, ) diff --git a/test/srt/models/lora/test_lora_update.py b/test/srt/models/lora/test_lora_update.py index 83392b924..ef5a4c845 100644 --- a/test/srt/models/lora/test_lora_update.py +++ b/test/srt/models/lora/test_lora_update.py @@ -70,6 +70,7 @@ class TestCase: max_lora_rank: Optional[int] = None lora_target_modules: Optional[List] = None max_new_tokens: int = 32 + max_loaded_loras: Optional[int] = None def create_batch_data(adapters: Union[str, list]) -> List[tuple[str, str]]: @@ -559,7 +560,43 @@ MAX_LORA_RANK_TESTS = [ ], ), ] -ALL_TESTS = BASIC_TESTS + TARGET_MODULE_TESTS + MAX_LORA_RANK_TESTS +MAX_LOADED_LORAS_TESTS = [ + TestCase( + description="Test max_loaded_loras limit", + base="meta-llama/Llama-3.1-8B-Instruct", + max_loras_per_batch=2, + max_loaded_loras=2, + all_adapters=[ + "philschmid/code-llama-3-1-8b-text-to-sql-lora", + "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", + "pbevan11/llama-3.1-8b-ocr-correction", + ], + initial_adapters=["philschmid/code-llama-3-1-8b-text-to-sql-lora"], + op_sequence=[ + Operation( + type=OperationType.LOAD, + data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", + ), + Operation( + type=OperationType.LOAD, + data="pbevan11/llama-3.1-8b-ocr-correction", + expected_error="Maximum number of loaded LoRA adapters", + ), + Operation( + type=OperationType.UNLOAD, + data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", + ), + Operation( + type=OperationType.LOAD, + data="pbevan11/llama-3.1-8b-ocr-correction", + ), + ], + ), +] + +ALL_TESTS = ( + BASIC_TESTS + TARGET_MODULE_TESTS + MAX_LORA_RANK_TESTS + MAX_LOADED_LORAS_TESTS +) class LoRAUpdateTestSessionMode(Enum): @@ -579,6 +616,7 @@ class LoRAUpdateTestSessionBase: model_path: str, lora_paths: list[str], max_loras_per_batch: int, + max_loaded_loras: Optional[int] = None, max_lora_rank: Optional[int], enable_lora: Optional[bool] = None, lora_target_modules: Optional[List[str]] = None, @@ -592,6 +630,7 @@ class LoRAUpdateTestSessionBase: self.max_lora_rank = max_lora_rank self.lora_target_modules = lora_target_modules self.max_loras_per_batch = max_loras_per_batch + self.max_loaded_loras = max_loaded_loras self.lora_backend = lora_backend self.disable_cuda_graph = disable_cuda_graph self.cuda_graph_max_bs = cuda_graph_max_bs @@ -654,6 +693,7 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase): torch_dtype=torch.float16, mem_fraction_static=MEM_FRACTION_STATIC, max_loras_per_batch=self.max_loras_per_batch, + max_loaded_loras=self.max_loaded_loras, disable_cuda_graph=self.disable_cuda_graph, cuda_graph_max_bs=self.cuda_graph_max_bs, disable_radix_cache=True, @@ -774,6 +814,8 @@ class 
LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase): other_args.extend(["--max-lora-rank", str(self.max_lora_rank)]) if self.lora_target_modules is not None: other_args.extend(["--lora-target-modules"] + self.lora_target_modules) + if self.max_loaded_loras is not None: + other_args.extend(["--max-loaded-loras", str(self.max_loaded_loras)]) # launch external server self.handle = popen_launch_server( @@ -898,8 +940,9 @@ class TestLoRADynamicUpdate(CustomTestCase): mode: LoRAUpdateTestSessionMode, base: str, initial_adapters: List[str], - max_loras_per_batch: int, op_sequence: List[Operation], + max_loras_per_batch: int, + max_loaded_loras: Optional[int] = None, enable_lora: Optional[bool] = None, max_lora_rank: Optional[int] = None, lora_target_modules: Optional[List[str]] = None, @@ -917,6 +960,7 @@ class TestLoRADynamicUpdate(CustomTestCase): model_path=base, lora_paths=initial_adapters, max_loras_per_batch=max_loras_per_batch, + max_loaded_loras=max_loaded_loras, max_lora_rank=max_lora_rank, lora_target_modules=lora_target_modules, enable_lora=enable_lora, @@ -972,6 +1016,7 @@ class TestLoRADynamicUpdate(CustomTestCase): enable_lora=test_case.enable_lora, base=test_case.base, max_loras_per_batch=test_case.max_loras_per_batch, + max_loaded_loras=test_case.max_loaded_loras, op_sequence=test_case.op_sequence, max_new_tokens=test_case.max_new_tokens, max_lora_rank=test_case.max_lora_rank, @@ -985,6 +1030,12 @@ class TestLoRADynamicUpdate(CustomTestCase): if x.type == OperationType.FORWARD and x.expected_error is None ] + if not forward_ops: + print( + f"No forward operations found in test case {case_idx}. Skipping static pass." + ) + continue + print("=" * 100) print(f"\n--- Running static pass with {len(forward_ops)} operations ---") static_output = self._run_operation_sequence(