diff --git a/docs/advanced_features/lora.ipynb b/docs/advanced_features/lora.ipynb index 708508134..cccf9d749 100644 --- a/docs/advanced_features/lora.ipynb +++ b/docs/advanced_features/lora.ipynb @@ -29,7 +29,7 @@ "\n", "* `enable_lora`: Enable LoRA support for the model. This argument is automatically set to True if `--lora-paths` is provided for backward compatibility.\n", "\n", - "* `lora_paths`: A mapping from each adaptor's name to its path, in the form of `{name}={path} {name}={path}`.\n", + "* `lora_paths`: The list of LoRA adapters to load. Each adapter must be specified in one of the following formats: | = | JSON with schema {\"lora_name\":str,\"lora_path\":str,\"pinned\":bool}.\n", "\n", "* `max_loras_per_batch`: Maximum number of adaptors used by each batch. This argument can affect the amount of GPU memory reserved for multi-LoRA serving, so it should be set to a smaller value when memory is scarce. Defaults to be 8.\n", "\n", @@ -372,6 +372,15 @@ "print(f\"Output from lora1 (updated): \\n{response.json()[1]['text']}\\n\")" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "terminate_process(server_process)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -387,7 +396,40 @@ "\n", "This can improve performance in scenarios where the same adapter is frequently used across requests, by avoiding repeated memory transfers and reinitialization overhead. However, since GPU pool slots are limited, pinning adapters reduces the flexibility of the system to dynamically load other adapters on demand. If too many adapters are pinned, it may lead to degraded performance, or in the most extreme case (`Number of pinned adapters == max-loras-per-batch`), halt all unpinned requests. Therefore, currently SGLang limits maximal number of pinned adapters to `max-loras-per-batch - 1` to prevent unexpected starvations. \n", "\n", - "In the example below, we unload `lora1` and reload it as a `pinned` adapter:" + "In the example below, we start a server with `lora1` loaded as pinned, `lora2` and `lora3` loaded as regular (unpinned) adapters. Please note that, we intentionally specify `lora2` and `lora3` in two different formats to demonstrate that both are supported." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "server_process, port = launch_server_cmd(\n", + " \"\"\"\n", + " python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n", + " --enable-lora \\\n", + " --cuda-graph-max-bs 8 \\\n", + " --max-loras-per-batch 3 --lora-backend triton \\\n", + " --max-lora-rank 256 \\\n", + " --lora-target-modules all \\\n", + " --lora-paths \\\n", + " {\"lora_name\":\"lora0\",\"lora_path\":\"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16\",\"pinned\":true} \\\n", + " {\"lora_name\":\"lora1\",\"lora_path\":\"algoprog/fact-generation-llama-3.1-8b-instruct-lora\"} \\\n", + " lora2=philschmid/code-llama-3-1-8b-text-to-sql-lora\n", + " \"\"\"\n", + ")\n", + "\n", + "\n", + "url = f\"http://127.0.0.1:{port}\"\n", + "wait_for_server(url)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can also specify adapter as pinned during dynamic adapter loading. In the example below, we reload `lora2` as pinned adapter:" ] }, { @@ -407,7 +449,7 @@ " url + \"/load_lora_adapter\",\n", " json={\n", " \"lora_name\": \"lora1\",\n", - " \"lora_path\": lora1,\n", + " \"lora_path\": \"algoprog/fact-generation-llama-3.1-8b-instruct-lora\",\n", " \"pinned\": True, # Pin the adapter to GPU\n", " },\n", ")" @@ -417,7 +459,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Verify that the result is identical as before:" + "Verify that the results are expected:" ] }, { @@ -431,17 +473,19 @@ " \"text\": [\n", " \"List 3 countries and their capitals.\",\n", " \"List 3 countries and their capitals.\",\n", + " \"List 3 countries and their capitals.\",\n", " ],\n", " \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n", " # The first input uses lora0, and the second input uses lora1\n", - " \"lora_path\": [\"lora0\", \"lora1\"],\n", + " \"lora_path\": [\"lora0\", \"lora1\", \"lora2\"],\n", "}\n", "response = requests.post(\n", " url + \"/generate\",\n", " json=json_data,\n", ")\n", - "print(f\"Output from lora0: \\n{response.json()[0]['text']}\\n\")\n", - "print(f\"Output from lora1 (pinned): \\n{response.json()[1]['text']}\\n\")" + "print(f\"Output from lora0 (pinned): \\n{response.json()[0]['text']}\\n\")\n", + "print(f\"Output from lora1 (pinned): \\n{response.json()[1]['text']}\\n\")\n", + "print(f\"Output from lora2 (not pinned): \\n{response.json()[2]['text']}\\n\")" ] }, { diff --git a/docs/advanced_features/server_arguments.md b/docs/advanced_features/server_arguments.md index c63b8a604..2fedb8d53 100644 --- a/docs/advanced_features/server_arguments.md +++ b/docs/advanced_features/server_arguments.md @@ -179,7 +179,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s | `--enable-lora` | Enable LoRA support for the model. This argument is automatically set to True if `--lora-paths` is provided for backward compatibility. | False | | `--max-lora-rank` | The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup. | None | | `--lora-target-modules` | The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup. You can also set it to `all` to enable LoRA for all supported modules. However, enabling LoRA on additional modules introduces a minor performance overhead. If your application is performance-sensitive, we recommend only specifying the modules for which you plan to load adapters. | None | -| `--lora-paths` | The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}. | None | +| `--lora-paths` | The list of LoRA adapters to load. Each adapter must be specified in one of the following formats: | = | JSON with schema {"lora_name":str,"lora_path":str,"pinned":bool} | None | | `--max-loras-per-batch` | Maximum number of adapters for a running batch, include base-only request. | 8 | | `--max-loaded-loras` | If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `--max-loras-per-batch`. | None | | `--lora-backend` | Choose the kernel backend for multi-LoRA serving. | triton | diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index c2a3eaabc..ef1120d1e 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -55,7 +55,7 @@ class LoRAManager: tp_rank: int = 0, max_lora_rank: Optional[int] = None, target_modules: Optional[Iterable[str]] = None, - lora_paths: Optional[Dict[str, LoRARef]] = None, + lora_paths: Optional[List[LoRARef]] = None, ): self.base_model: torch.nn.Module = base_model self.base_hf_config: AutoConfig = base_hf_config @@ -370,7 +370,7 @@ class LoRAManager: self, max_lora_rank: Optional[int] = None, target_modules: Optional[Iterable[str]] = None, - lora_paths: Optional[Dict[str, LoRARef]] = None, + lora_paths: Optional[List[LoRARef]] = None, ): """ Initialize the internal (mutable) state of the LoRAManager. @@ -392,7 +392,7 @@ class LoRAManager: self.init_memory_pool() self.update_lora_info() - def init_lora_adapters(self, lora_paths: Optional[Dict[str, LoRARef]] = None): + def init_lora_adapters(self, lora_paths: Optional[List[LoRARef]] = None): # Configs of all active LoRA adapters, indexed by LoRA ID. self.configs: Dict[str, LoRAConfig] = {} @@ -406,7 +406,7 @@ class LoRAManager: self.num_pinned_loras: int = 0 if lora_paths: - for lora_ref in lora_paths.values(): + for lora_ref in lora_paths: result = self.load_lora_adapter(lora_ref) if not result.success: raise RuntimeError( diff --git a/python/sglang/srt/lora/lora_registry.py b/python/sglang/srt/lora/lora_registry.py index 535ab47b4..51d2b0e66 100644 --- a/python/sglang/srt/lora/lora_registry.py +++ b/python/sglang/srt/lora/lora_registry.py @@ -59,9 +59,9 @@ class LoRARegistry: update / eventual consistency model between the tokenizer manager process and the scheduler processes. """ - def __init__(self, lora_paths: Optional[Dict[str, LoRARef]] = None): + def __init__(self, lora_paths: Optional[List[LoRARef]] = None): assert lora_paths is None or all( - isinstance(lora, LoRARef) for lora in lora_paths.values() + isinstance(lora, LoRARef) for lora in lora_paths ), ( "server_args.lora_paths should have been normalized to LoRARef objects during server initialization. " "Please file an issue if you see this error." @@ -78,7 +78,7 @@ class LoRARegistry: # Initialize the registry with provided LoRA paths, if present. if lora_paths: - for lora_ref in lora_paths.values(): + for lora_ref in lora_paths: self._register_adapter(lora_ref) async def register(self, lora_ref: LoRARef): diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index b0416a065..adfdd0541 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -298,7 +298,7 @@ class TokenizerManager: # The registry dynamically updates as adapters are loaded / unloaded during runtime. It # serves as the source of truth for available adapters and maps user-friendly LoRA names # to internally used unique LoRA IDs. - self.lora_registry = LoRARegistry(self.server_args.lora_paths or {}) + self.lora_registry = LoRARegistry(self.server_args.lora_paths) # Lock to serialize LoRA update operations. # Please note that, unlike `model_update_lock`, this does not block inference, allowing # LoRA updates and inference to overlap. diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index b6a98e05f..36606e97a 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -153,7 +153,9 @@ class ServerArgs: enable_lora: Optional[bool] = None max_lora_rank: Optional[int] = None lora_target_modules: Optional[Union[set[str], List[str]]] = None - lora_paths: Optional[Union[dict[str, str], dict[str, LoRARef], List[str]]] = None + lora_paths: Optional[ + Union[dict[str, str], List[dict[str, str]], List[str], List[LoRARef]] + ] = None max_loaded_loras: Optional[int] = None max_loras_per_batch: int = 8 lora_backend: str = "triton" @@ -1319,7 +1321,7 @@ class ServerArgs: nargs="*", default=None, action=LoRAPathAction, - help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}.", + help='The list of LoRA adapters to load. Each adapter must be specified in one of the following formats: | = | JSON with schema {"lora_name":str,"lora_path":str,"pinned":bool}', ) parser.add_argument( "--max-loras-per-batch", @@ -2086,28 +2088,42 @@ class ServerArgs: ) if self.enable_lora: - # Normalize lora_paths to a dictionary if it is a list. - # TODO (lifuhuang): support specifying pinned adapters in server_args. if isinstance(self.lora_paths, list): lora_paths = self.lora_paths - self.lora_paths = {} + self.lora_paths = [] for lora_path in lora_paths: - if "=" in lora_path: - name, path = lora_path.split("=", 1) - self.lora_paths[name] = LoRARef( - lora_name=name, lora_path=path, pinned=False + if isinstance(lora_path, str): + if "=" in lora_path: + name, path = lora_path.split("=", 1) + lora_ref = LoRARef( + lora_name=name, lora_path=path, pinned=False + ) + else: + lora_ref = LoRARef( + lora_name=lora_path, lora_path=lora_path, pinned=False + ) + elif isinstance(lora_path, dict): + assert ( + "lora_name" in lora_path and "lora_path" in lora_path + ), f"When providing LoRA paths as a list of dict, each dict should contain 'lora_name' and 'lora_path' keys. Got: {lora_path}" + lora_ref = LoRARef( + lora_name=lora_path["lora_name"], + lora_path=lora_path["lora_path"], + pinned=lora_path.get("pinned", False), ) else: - self.lora_paths[lora_path] = LoRARef( - lora_name=lora_path, lora_path=lora_path, pinned=False + raise ValueError( + f"Invalid type for item in --lora-paths list: {type(lora_path)}. " + "Expected a string or a dictionary." ) + self.lora_paths.append(lora_ref) elif isinstance(self.lora_paths, dict): - self.lora_paths = { - k: LoRARef(lora_name=k, lora_path=v, pinned=False) + self.lora_paths = [ + LoRARef(lora_name=k, lora_path=v, pinned=False) for k, v in self.lora_paths.items() - } + ] elif self.lora_paths is None: - self.lora_paths = {} + self.lora_paths = [] else: raise ValueError( f"Invalid type for --lora-paths: {type(self.lora_paths)}. " @@ -2134,9 +2150,7 @@ class ServerArgs: "max_loaded_loras should be greater than or equal to max_loras_per_batch. " f"max_loaded_loras={self.max_loaded_loras}, max_loras_per_batch={self.max_loras_per_batch}" ) - assert ( - not self.lora_paths or len(self.lora_paths) <= self.max_loaded_loras - ), ( + assert len(self.lora_paths) <= self.max_loaded_loras, ( "The number of LoRA paths should not exceed max_loaded_loras. " f"max_loaded_loras={self.max_loaded_loras}, lora_paths={len(self.lora_paths)}" ) @@ -2357,13 +2371,22 @@ class PortArgs: class LoRAPathAction(argparse.Action): def __call__(self, parser, namespace, values, option_string=None): - setattr(namespace, self.dest, {}) - for lora_path in values: - if "=" in lora_path: - name, path = lora_path.split("=", 1) - getattr(namespace, self.dest)[name] = path - else: - getattr(namespace, self.dest)[lora_path] = lora_path + lora_paths = [] + if values: + assert isinstance(values, list), "Expected a list of LoRA paths." + for lora_path in values: + lora_path = lora_path.strip() + if lora_path.startswith("{") and lora_path.endswith("}"): + obj = json.loads(lora_path) + assert "lora_path" in obj and "lora_name" in obj, ( + f"{repr(lora_path)} looks like a JSON str, " + "but it does not contain 'lora_name' and 'lora_path' keys." + ) + lora_paths.append(obj) + else: + lora_paths.append(lora_path) + + setattr(namespace, self.dest, lora_paths) class DeprecatedAction(argparse.Action): diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index 248ba7285..96081b2c3 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -491,7 +491,7 @@ class SRTRunner: tp_size: int = 1, model_impl: str = "auto", port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER, - lora_paths: List[str] = None, + lora_paths: Optional[Union[List[str], List[dict[str, str]]]] = None, max_loras_per_batch: int = 4, attention_backend: Optional[str] = None, prefill_attention_backend: Optional[str] = None, diff --git a/test/srt/lora/test_lora_update.py b/test/srt/lora/test_lora_update.py index e33fccc02..3c01858c7 100644 --- a/test/srt/lora/test_lora_update.py +++ b/test/srt/lora/test_lora_update.py @@ -12,6 +12,7 @@ # limitations under the License. # ============================================================================== +import json import multiprocessing as mp import unittest from dataclasses import dataclass @@ -89,8 +90,35 @@ BASIC_TESTS = [ "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", "pbevan11/llama-3.1-8b-ocr-correction", ], - initial_adapters=["philschmid/code-llama-3-1-8b-text-to-sql-lora"], + initial_adapters=[ + # Testing 3 supported lora-path formats. + "philschmid/code-llama-3-1-8b-text-to-sql-lora", + "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16=Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", + { + "lora_name": "pbevan11/llama-3.1-8b-ocr-correction", + "lora_path": "pbevan11/llama-3.1-8b-ocr-correction", + "pinned": False, + }, + ], op_sequence=[ + Operation( + type=OperationType.FORWARD, + data=create_batch_data( + [ + "philschmid/code-llama-3-1-8b-text-to-sql-lora", + "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", + "pbevan11/llama-3.1-8b-ocr-correction", + ] + ), + ), + Operation( + type=OperationType.UNLOAD, + data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", + ), + Operation( + type=OperationType.UNLOAD, + data="pbevan11/llama-3.1-8b-ocr-correction", + ), Operation( type=OperationType.FORWARD, data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"), @@ -147,6 +175,10 @@ BASIC_TESTS = [ type=OperationType.UNLOAD, data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", ), + Operation( + type=OperationType.UNLOAD, + data="pbevan11/llama-3.1-8b-ocr-correction", + ), Operation( type=OperationType.FORWARD, data=create_batch_data( @@ -157,18 +189,12 @@ BASIC_TESTS = [ Operation( type=OperationType.FORWARD, data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"), - ), - Operation( - type=OperationType.LOAD, - data="philschmid/code-llama-3-1-8b-text-to-sql-lora", + expected_error="not loaded", ), Operation( type=OperationType.FORWARD, data=create_batch_data( - [ - "philschmid/code-llama-3-1-8b-text-to-sql-lora", - "pbevan11/llama-3.1-8b-ocr-correction", - ] + None, ), ), ], @@ -705,7 +731,7 @@ class LoRAUpdateTestSessionBase: *, testcase: Optional[TestCase], model_path: str, - lora_paths: list[str], + lora_paths: List[Union[str, dict]], max_loras_per_batch: int, max_loaded_loras: Optional[int] = None, max_lora_rank: Optional[int], @@ -727,7 +753,17 @@ class LoRAUpdateTestSessionBase: self.cuda_graph_max_bs = cuda_graph_max_bs self.enable_lora = enable_lora - self.expected_adapters = set(lora_paths or []) + self.expected_adapters = set() + if self.lora_paths: + for adapter in self.lora_paths: + if isinstance(adapter, dict): + lora_name = adapter["lora_name"] + elif "=" in adapter: + lora_name = adapter.split("=")[0] + else: + lora_name = adapter + self.expected_adapters.add(lora_name) + self.handle = None # Will be set in __enter__ def __enter__(self): @@ -926,7 +962,11 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase): if self.enable_lora: other_args.append("--enable-lora") if self.lora_paths: - other_args.extend(["--lora-paths"] + self.lora_paths) + other_args.append("--lora-paths") + for lora_path in self.lora_paths: + if isinstance(lora_path, dict): + lora_path = json.dumps(lora_path) + other_args.append(lora_path) if self.disable_cuda_graph: other_args.append("--disable-cuda-graph") if self.max_lora_rank is not None: @@ -1093,7 +1133,7 @@ class TestLoRADynamicUpdate(CustomTestCase): self, mode: LoRAUpdateTestSessionMode, base: str, - initial_adapters: List[str], + initial_adapters: List[Union[str, dict]], op_sequence: List[Operation], max_loras_per_batch: int, max_loaded_loras: Optional[int] = None,