From 4e3defe5a77e14d70ad4ebfb3115ce507789f6e9 Mon Sep 17 00:00:00 2001 From: Lifu Huang Date: Sat, 19 Jul 2025 15:38:09 -0700 Subject: [PATCH] Support start up LoRA server without initial adapters (#8019) --- docs/backend/lora.ipynb | 271 ++++++++---------- docs/backend/server_arguments.md | 3 +- python/sglang/srt/lora/lora_manager.py | 6 +- .../sglang/srt/managers/tokenizer_manager.py | 10 +- .../srt/model_executor/cuda_graph_runner.py | 11 +- .../srt/model_executor/forward_batch_info.py | 2 +- .../sglang/srt/model_executor/model_runner.py | 8 +- python/sglang/srt/server_args.py | 74 +++-- python/sglang/srt/utils.py | 14 + python/sglang/test/runners.py | 2 + test/srt/models/lora/test_lora_update.py | 82 +++++- test/srt/run_suite.py | 2 +- 12 files changed, 290 insertions(+), 195 deletions(-) diff --git a/docs/backend/lora.ipynb b/docs/backend/lora.ipynb index 6c089b654..8626d3e71 100644 --- a/docs/backend/lora.ipynb +++ b/docs/backend/lora.ipynb @@ -27,6 +27,8 @@ "source": [ "The following server arguments are relevant for multi-LoRA serving:\n", "\n", + "* `enable_lora`: Enable LoRA support for the model. This argument is automatically set to True if `--lora-paths` is provided for backward compatibility.\n", + "\n", "* `lora_paths`: A mapping from each adaptor's name to its path, in the form of `{name}={path} {name}={path}`.\n", "\n", "* `max_loras_per_batch`: Maximum number of adaptors used by each batch. This argument can affect the amount of GPU memory reserved for multi-LoRA serving, so it should be set to a smaller value when memory is scarce. Defaults to be 8.\n", @@ -35,7 +37,7 @@ "\n", "* `max_lora_rank`: The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup.\n", "\n", - "* `lora_target_modules`: The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup.\n", + "* `lora_target_modules`: The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup. You can also set it to `all` to enable LoRA for all supported modules. However, enabling LoRA on additional modules introduces a minor performance overhead. If your application is performance-sensitive, we recommend only specifying the modules for which you plan to load adapters.\n", "\n", "* `tp_size`: LoRA serving along with Tensor Parallelism is supported by SGLang. `tp_size` controls the number of GPUs for tensor parallelism. More details on the tensor sharding strategy can be found in [S-Lora](https://arxiv.org/pdf/2311.03285) paper.\n", "\n", @@ -79,6 +81,7 @@ "server_process, port = launch_server_cmd(\n", " \"\"\"\n", "python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n", + " --enable-lora \\\n", " --lora-paths lora0=algoprog/fact-generation-llama-3.1-8b-instruct-lora \\\n", " --max-loras-per-batch 1 --lora-backend triton \\\n", " --disable-radix-cache\n", @@ -98,7 +101,7 @@ "json_data = {\n", " \"text\": [\n", " \"List 3 countries and their capitals.\",\n", - " \"AI is a field of computer science focused on\",\n", + " \"List 3 countries and their capitals.\",\n", " ],\n", " \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n", " # The first input uses lora0, and the second input uses the base model\n", @@ -137,6 +140,7 @@ "server_process, port = launch_server_cmd(\n", " \"\"\"\n", "python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n", + " --enable-lora \\\n", " --lora-paths lora0=algoprog/fact-generation-llama-3.1-8b-instruct-lora \\\n", " lora1=Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16 \\\n", " --max-loras-per-batch 2 --lora-backend triton \\\n", @@ -157,7 +161,7 @@ "json_data = {\n", " \"text\": [\n", " \"List 3 countries and their capitals.\",\n", - " \"AI is a field of computer science focused on\",\n", + " \"List 3 countries and their capitals.\",\n", " ],\n", " \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n", " # The first input uses lora0, and the second input uses lora1\n", @@ -191,147 +195,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Basic Usage\n", - "\n", "Instead of specifying all adapters during server startup via `--lora-paths`. You can also load & unload LoRA adapters dynamically via the `/load_lora_adapter` and `/unload_lora_adapter` API.\n", "\n", - "(Please note that, currently we still require you to specify at least one adapter in `--lora-paths` to enable the LoRA feature, this limitation will be lifted soon.)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "server_process, port = launch_server_cmd(\n", - " \"\"\"\n", - " python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n", - " --lora-paths lora0=philschmid/code-llama-3-1-8b-text-to-sql-lora \\\n", - " --cuda-graph-max-bs 2 \\\n", - " --max-loras-per-batch 2 --lora-backend triton \\\n", - " --disable-radix-cache\n", - " \"\"\"\n", - ")\n", - "\n", - "url = f\"http://127.0.0.1:{port}\"\n", - "wait_for_server(url)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "response = requests.post(\n", - " url + \"/load_lora_adapter\",\n", - " json={\n", - " \"lora_name\": \"lora1\",\n", - " \"lora_path\": \"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16\",\n", - " },\n", - ")\n", - "\n", - "if response.status_code == 200:\n", - " print(\"LoRA adapter loaded successfully.\", response.json())\n", - "else:\n", - " print(\"Failed to load LoRA adapter.\", response.json())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "response = requests.post(\n", - " url + \"/generate\",\n", - " json={\n", - " \"text\": [\n", - " \"List 3 countries and their capitals.\",\n", - " \"List 3 countries and their capitals.\",\n", - " ],\n", - " \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n", - " \"lora_path\": [\"lora0\", \"lora1\"],\n", - " },\n", - ")\n", - "print(f\"Output from lora0: {response.json()[0]['text']}\")\n", - "print(f\"Output from lora1: {response.json()[1]['text']}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "response = requests.post(\n", - " url + \"/unload_lora_adapter\",\n", - " json={\n", - " \"lora_name\": \"lora0\",\n", - " },\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "response = requests.post(\n", - " url + \"/load_lora_adapter\",\n", - " json={\n", - " \"lora_name\": \"lora2\",\n", - " \"lora_path\": \"pbevan11/llama-3.1-8b-ocr-correction\",\n", - " },\n", - ")\n", - "\n", - "if response.status_code == 200:\n", - " print(\"LoRA adapter loaded successfully.\", response.json())\n", - "else:\n", - " print(\"Failed to load LoRA adapter.\", response.json())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "response = requests.post(\n", - " url + \"/generate\",\n", - " json={\n", - " \"text\": [\n", - " \"List 3 countries and their capitals.\",\n", - " \"List 3 countries and their capitals.\",\n", - " ],\n", - " \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n", - " \"lora_path\": [\"lora1\", \"lora2\"],\n", - " },\n", - ")\n", - "print(f\"Output from lora1: {response.json()[0]['text']}\")\n", - "print(f\"Output from lora2: {response.json()[1]['text']}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "terminate_process(server_process)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Advanced: hosting adapters of different shapes\n", - "\n", - "In some cases, you may want to load LoRA adapters with different ranks or target modules (e.g., `q_proj`, `k_proj`) simultaneously. To ensure the server can accommodate all expected LoRA shapes, it's recommended to explicitly specify `--max-lora-rank` and/or `--lora-target-modules` at startup.\n", - "\n", - "For backward compatibility, SGLang will infer these values from `--lora-paths` if they are not explicitly provided. This means it's safe to omit them **only if** all dynamically loaded adapters share the same shape (rank and target modules) as those in the initial `--lora-paths` or are strictly \"smaller\"." + "When using dynamic LoRA loading, it's recommended to explicitly specify both `--max-lora-rank` and `--lora-target-modules` at startup. For backward compatibility, SGLang will infer these values from `--lora-paths` if they are not explicitly provided. However, in that case, you would have to ensure that all dynamically loaded adapters share the same shape (rank and target modules) as those in the initial `--lora-paths` or are strictly \"smaller\"." ] }, { @@ -342,19 +208,20 @@ "source": [ "lora0 = \"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16\" # rank - 4, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj\n", "lora1 = \"algoprog/fact-generation-llama-3.1-8b-instruct-lora\" # rank - 64, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj\n", + "lora0_new = \"philschmid/code-llama-3-1-8b-text-to-sql-lora\" # rank - 256, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj\n", "\n", "\n", "# The `--target-lora-modules` param below is technically not needed, as the server will infer it from lora0 which already has all the target modules specified.\n", "# We are adding it here just to demonstrate usage.\n", "server_process, port = launch_server_cmd(\n", - " f\"\"\"\n", + " \"\"\"\n", " python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n", - " --lora-paths lora0={lora0} \\\n", + " --enable-lora \\\n", " --cuda-graph-max-bs 2 \\\n", " --max-loras-per-batch 2 --lora-backend triton \\\n", " --disable-radix-cache\n", - " --max-lora-rank 64\n", - " --lora-target-modules q_proj k_proj v_proj o_proj down_proj up_proj gate_proj\n", + " --max-lora-rank 256\n", + " --lora-target-modules all\n", " \"\"\"\n", ")\n", "\n", @@ -362,6 +229,40 @@ "wait_for_server(url)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load adapter lora0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "response = requests.post(\n", + " url + \"/load_lora_adapter\",\n", + " json={\n", + " \"lora_name\": \"lora0\",\n", + " \"lora_path\": lora0,\n", + " },\n", + ")\n", + "\n", + "if response.status_code == 200:\n", + " print(\"LoRA adapter loaded successfully.\", response.json())\n", + "else:\n", + " print(\"Failed to load LoRA adapter.\", response.json())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load adapter lora1:" + ] + }, { "cell_type": "code", "execution_count": null, @@ -382,6 +283,13 @@ " print(\"Failed to load LoRA adapter.\", response.json())" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Check inference output:" + ] + }, { "cell_type": "code", "execution_count": null, @@ -392,7 +300,7 @@ "json_data = {\n", " \"text\": [\n", " \"List 3 countries and their capitals.\",\n", - " \"AI is a field of computer science focused on\",\n", + " \"List 3 countries and their capitals.\",\n", " ],\n", " \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n", " # The first input uses lora0, and the second input uses lora1\n", @@ -402,8 +310,73 @@ " url + \"/generate\",\n", " json=json_data,\n", ")\n", - "print(f\"Output from lora0: {response.json()[0]['text']}\")\n", - "print(f\"Output from lora1: {response.json()[1]['text']}\")" + "print(f\"Output from lora0: \\n{response.json()[0]['text']}\\n\")\n", + "print(f\"Output from lora1 (updated): \\n{response.json()[1]['text']}\\n\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Unload lora0 and replace it with a different adapter:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "response = requests.post(\n", + " url + \"/unload_lora_adapter\",\n", + " json={\n", + " \"lora_name\": \"lora0\",\n", + " },\n", + ")\n", + "\n", + "response = requests.post(\n", + " url + \"/load_lora_adapter\",\n", + " json={\n", + " \"lora_name\": \"lora0\",\n", + " \"lora_path\": lora0_new,\n", + " },\n", + ")\n", + "\n", + "if response.status_code == 200:\n", + " print(\"LoRA adapter loaded successfully.\", response.json())\n", + "else:\n", + " print(\"Failed to load LoRA adapter.\", response.json())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Check output again:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "url = f\"http://127.0.0.1:{port}\"\n", + "json_data = {\n", + " \"text\": [\n", + " \"List 3 countries and their capitals.\",\n", + " \"List 3 countries and their capitals.\",\n", + " ],\n", + " \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n", + " # The first input uses lora0, and the second input uses lora1\n", + " \"lora_path\": [\"lora0\", \"lora1\"],\n", + "}\n", + "response = requests.post(\n", + " url + \"/generate\",\n", + " json=json_data,\n", + ")\n", + "print(f\"Output from lora0: \\n{response.json()[0]['text']}\\n\")\n", + "print(f\"Output from lora1 (updated): \\n{response.json()[1]['text']}\\n\")" ] }, { diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md index 6320a6e61..d7c5ff520 100644 --- a/docs/backend/server_arguments.md +++ b/docs/backend/server_arguments.md @@ -176,8 +176,9 @@ Please consult the documentation below and [server_args.py](https://github.com/s | Arguments | Description | Defaults | |-----------|-------------|----------| +| `--enable-lora` | Enable LoRA support for the model. This argument is automatically set to True if `--lora-paths` is provided for backward compatibility. | False | | `--max-lora-rank` | The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup. | None | -| `--lora-target-modules` | The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup. | None | +| `--lora-target-modules` | The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup. You can also set it to `all` to enable LoRA for all supported modules. However, enabling LoRA on additional modules introduces a minor performance overhead. If your application is performance-sensitive, we recommend only specifying the modules for which you plan to load adapters. | None | | `--lora-paths` | The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}. | None | | `--max-loras-per-batch` | Maximum number of adapters for a running batch, include base-only request. | 8 | | `--lora-backend` | Choose the kernel backend for multi-LoRA serving. | triton | diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 96102d1ef..85fd24616 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -186,9 +186,9 @@ class LoRAManager: ) if incompatible: raise ValueError( - f"LoRA adapter {lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration." - "We are still working on supporting dynamically updating LoRA shapes. If you expect to use adapters of different shapes, " - "You can specify expected configs via --max_lora_rank and --enable_lora_modules." + f"LoRA adapter {lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. " + "Please ensure that the LoRA adapter's rank is within the configured `--max_lora_rank` and that the target modules are " + "included in `--enable_lora_modules`." ) def unload_lora_adapter(self, lora_name: str) -> LoRAUpdateResult: diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 7ba07f675..631d23f17 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -574,7 +574,7 @@ class TokenizerManager: "The server is not configured to enable custom logit processor. " "Please set `--enable-custom-logits-processor` to enable this feature." ) - if self.server_args.lora_paths and obj.lora_path: + if self.server_args.enable_lora and obj.lora_path: self._validate_lora_adapters(obj) def _validate_input_ids_in_vocab( @@ -1037,6 +1037,10 @@ class TokenizerManager: _: Optional[fastapi.Request] = None, ) -> LoadLoRAAdapterReqOutput: self.auto_create_handle_loop() + if not self.server_args.enable_lora: + raise ValueError( + "LoRA is not enabled. Please set `--enable-lora` to enable LoRA." + ) # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works # with dp_size > 1. @@ -1060,6 +1064,10 @@ class TokenizerManager: _: Optional[fastapi.Request] = None, ) -> UnloadLoRAAdapterReqOutput: self.auto_create_handle_loop() + if not self.server_args.enable_lora: + raise ValueError( + "LoRA is not enabled. Please set `--enable-lora` to enable LoRA." + ) # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works # with dp_size > 1. diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 1f654ca7e..520a631c5 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -264,7 +264,7 @@ class CudaGraphRunner: if self.enable_torch_compile: set_torch_compile_config() - if self.model_runner.server_args.lora_paths is not None: + if self.model_runner.server_args.enable_lora: self.model_runner.lora_manager.init_cuda_graph_batch_info(self.max_bs) # Graph inputs @@ -510,11 +510,10 @@ class CudaGraphRunner: spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL ) - if self.model_runner.server_args.lora_paths is not None: - # Currently, if the lora_path in `lora_paths` is None, the lora backend will use a - # different logic to handle lora, so we need to set `lora_paths` to a list of non-None - # values if lora is enabled. - lora_paths = [next(iter(self.model_runner.server_args.lora_paths))] * bs + if self.model_runner.server_args.enable_lora: + # It is safe to capture CUDA graph using empty LoRA path, as the LoRA kernels will always be launched whenever + # `--enable-lora` is set to True (and return immediately if the LoRA path is empty for perf optimization). + lora_paths = [None] * bs else: lora_paths = None diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index fde60e0e5..6f3ea5474 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -418,7 +418,7 @@ class ForwardBatch: ret._compute_mrope_positions(model_runner, batch) # Init lora information - if model_runner.server_args.lora_paths is not None: + if model_runner.server_args.enable_lora: model_runner.lora_manager.prepare_lora_batch(ret) TboForwardBatchPreparer.prepare( diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index bbd5b0000..4f0b1d64c 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -304,11 +304,7 @@ class ModelRunner: self.apply_torch_tp() # Init lora - # TODO (lifuhuang): when we support dynamic LoRA loading / unloading, we should add - # a new server arg `enable_lora` to control whether to init LoRA manager to be more - # explicit, as it is perfectly valid to start a server with an empty lora_paths and - # load LoRA adapters dynamically later. - if server_args.lora_paths is not None: + if server_args.enable_lora: self.init_lora_manager() # Init memory pool and attention backends @@ -895,7 +891,7 @@ class ModelRunner: max_lora_rank=self.server_args.max_lora_rank, target_modules=self.server_args.lora_target_modules, ) - result = self.lora_manager.load_lora_adapters(self.server_args.lora_paths) + result = self.lora_manager.load_lora_adapters(self.server_args.lora_paths or {}) if result.success: logger.info( f"LoRA manager ready. Loaded LoRA adapters: {', '.join(result.loaded_adapters)}" diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 24292bcd7..6464f9f40 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -26,6 +26,8 @@ from typing import List, Literal, Optional, Union from sglang.srt.hf_transformers_utils import check_gguf_file, get_config from sglang.srt.reasoning_parser import ReasoningParser from sglang.srt.utils import ( + LORA_TARGET_ALL_MODULES, + SUPPORTED_LORA_TARGET_MODULES, configure_ipv6, get_device, get_device_memory_capacity, @@ -140,8 +142,9 @@ class ServerArgs: preferred_sampling_params: Optional[str] = None # LoRA + enable_lora: Optional[bool] = None max_lora_rank: Optional[int] = None - lora_target_modules: Optional[List[str]] = None + lora_target_modules: Optional[Union[set[str], List[str]]] = None lora_paths: Optional[Union[dict[str, str], List[str]]] = None max_loras_per_batch: int = 8 lora_backend: str = "triton" @@ -1148,6 +1151,12 @@ class ServerArgs: ) # LoRA + parser.add_argument( + "--enable-lora", + default=ServerArgs.enable_lora, + action="store_true", + help="Enable LoRA support for the model. This argument is automatically set to True if `--lora-paths` is provided for backward compatibility.", + ) parser.add_argument( "--max-lora-rank", default=ServerArgs.max_lora_rank, @@ -1157,18 +1166,12 @@ class ServerArgs: parser.add_argument( "--lora-target-modules", type=str, - choices=[ - "q_proj", - "k_proj", - "v_proj", - "o_proj", - "gate_proj", - "up_proj", - "down_proj", - ], + choices=SUPPORTED_LORA_TARGET_MODULES + [LORA_TARGET_ALL_MODULES], nargs="*", default=None, - help="The union set of all target modules where LoRA should be applied. If not specified, it will be automatically inferred from the adapters provided in --lora-paths.", + help="The union set of all target modules where LoRA should be applied. If not specified, " + "it will be automatically inferred from the adapters provided in --lora-paths. If 'all' is specified, " + "all supported modules will be targeted.", ) parser.add_argument( "--lora-paths", @@ -1816,15 +1819,46 @@ class ServerArgs: None, }, "moe_dense_tp_size only support 1 and None currently" - if isinstance(self.lora_paths, list): - lora_paths = self.lora_paths - self.lora_paths = {} - for lora_path in lora_paths: - if "=" in lora_path: - name, path = lora_path.split("=", 1) - self.lora_paths[name] = path - else: - self.lora_paths[lora_path] = lora_path + self.check_lora_server_args() + + def check_lora_server_args(self): + # Enable LoRA if any LoRA paths are provided for backward compatibility. + if self.lora_paths: + if self.enable_lora is None: + self.enable_lora = True + logger.info( + "--enable-lora is set to True because --lora-paths is provided." + ) + elif self.enable_lora is False: + logger.warning( + "--enable-lora is set to False, any provided lora_paths will be ignored." + ) + + if self.enable_lora: + # Normalize lora_paths to a dictionary if it is a list. + if isinstance(self.lora_paths, list): + lora_paths = self.lora_paths + self.lora_paths = {} + for lora_path in lora_paths: + if "=" in lora_path: + name, path = lora_path.split("=", 1) + self.lora_paths[name] = path + else: + self.lora_paths[lora_path] = lora_path + + # Expand target modules + if self.lora_target_modules: + self.lora_target_modules = set(self.lora_target_modules) + if "all" in self.lora_target_modules: + assert ( + len(self.lora_target_modules) == 1 + ), "If 'all' is specified in --lora-target-modules, it should be the only module specified." + self.lora_target_modules = set(SUPPORTED_LORA_TARGET_MODULES) + + # Ensure sufficient information is provided for LoRA initialization. + assert self.lora_paths or ( + self.max_lora_rank and self.lora_target_modules + ), "When no initial --lora-paths is provided, you need to specify both --max-lora-rank and --lora-target-modules for LoRA initialization." def validate_disagg_tp_size(self, prefill_tp: int, decode_tp: int): larger_tp = max(decode_tp, prefill_tp) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 7123722eb..23960a8c1 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -2892,3 +2892,17 @@ def parse_module_path(module_path, function_name, create_dummy): return final_module, getattr(final_module, function_name) return final_module, None + + +# LoRA-related constants and utilities +SUPPORTED_LORA_TARGET_MODULES = [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", +] + +LORA_TARGET_ALL_MODULES = "all" diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index 941940fe0..9ec71c29b 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -507,6 +507,7 @@ class SRTRunner: sleep_on_idle=False, max_lora_rank: Optional[int] = None, lora_target_modules: Optional[List[str]] = None, + enable_lora: Optional[bool] = None, ): self.model_type = model_type self.is_generation = model_type == "generation" @@ -547,6 +548,7 @@ class SRTRunner: sleep_on_idle=sleep_on_idle, max_lora_rank=max_lora_rank, lora_target_modules=lora_target_modules, + enable_lora=enable_lora, **spec_kwargs, ) diff --git a/test/srt/models/lora/test_lora_update.py b/test/srt/models/lora/test_lora_update.py index 785b44e95..83392b924 100644 --- a/test/srt/models/lora/test_lora_update.py +++ b/test/srt/models/lora/test_lora_update.py @@ -64,8 +64,9 @@ class TestCase: base: str max_loras_per_batch: int all_adapters: List[str] - initial_adapters: List[str] op_sequence: List[Operation] + initial_adapters: Optional[List[str]] = None + enable_lora: Optional[bool] = None max_lora_rank: Optional[int] = None lora_target_modules: Optional[List] = None max_new_tokens: int = 32 @@ -171,6 +172,64 @@ BASIC_TESTS = [ ), ], ), + TestCase( + description="dynamic lora update without initial lora_paths", + base="meta-llama/Llama-3.1-8B-Instruct", + enable_lora=True, + max_lora_rank=256, + lora_target_modules=["all"], + max_loras_per_batch=4, + all_adapters=[ + "philschmid/code-llama-3-1-8b-text-to-sql-lora", + "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", + "pbevan11/llama-3.1-8b-ocr-correction", + ], + op_sequence=[ + Operation( + type=OperationType.LOAD, + data="philschmid/code-llama-3-1-8b-text-to-sql-lora", + ), + Operation( + type=OperationType.LOAD, + data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", + ), + Operation( + type=OperationType.LOAD, + data="pbevan11/llama-3.1-8b-ocr-correction", + ), + Operation( + type=OperationType.FORWARD, + data=create_batch_data( + [ + "philschmid/code-llama-3-1-8b-text-to-sql-lora", + "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", + "pbevan11/llama-3.1-8b-ocr-correction", + None, + ] + ), + ), + Operation( + type=OperationType.UNLOAD, + data="philschmid/code-llama-3-1-8b-text-to-sql-lora", + ), + Operation( + type=OperationType.FORWARD, + data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"), + expected_error="not loaded", + ), + Operation( + type=OperationType.FORWARD, + data=create_batch_data( + [ + None, + "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", + "pbevan11/llama-3.1-8b-ocr-correction", + None, + ] + ), + ), + ], + ), TestCase( description="dynamic lora update with evictions", base="meta-llama/Llama-3.1-8B-Instruct", @@ -371,7 +430,7 @@ TARGET_MODULE_TESTS = [ Operation( type=OperationType.LOAD, data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", - expected_error="updating LoRA shapes", + expected_error="incompatible", ), Operation( type=OperationType.FORWARD, @@ -431,7 +490,7 @@ MAX_LORA_RANK_TESTS = [ Operation( type=OperationType.LOAD, data="philschmid/code-llama-3-1-8b-text-to-sql-lora", - expected_error="updating LoRA shapes", + expected_error="incompatible", ), Operation( type=OperationType.FORWARD, @@ -470,7 +529,7 @@ MAX_LORA_RANK_TESTS = [ Operation( type=OperationType.LOAD, data="philschmid/code-llama-3-1-8b-text-to-sql-lora", - expected_error="updating LoRA shapes", + expected_error="incompatible", ), Operation( type=OperationType.FORWARD, @@ -521,6 +580,7 @@ class LoRAUpdateTestSessionBase: lora_paths: list[str], max_loras_per_batch: int, max_lora_rank: Optional[int], + enable_lora: Optional[bool] = None, lora_target_modules: Optional[List[str]] = None, lora_backend: str = "triton", disable_cuda_graph: bool = False, @@ -535,8 +595,9 @@ class LoRAUpdateTestSessionBase: self.lora_backend = lora_backend self.disable_cuda_graph = disable_cuda_graph self.cuda_graph_max_bs = cuda_graph_max_bs + self.enable_lora = enable_lora - self.expected_adapters = set(lora_paths) + self.expected_adapters = set(lora_paths or []) self.handle = None # Will be set in __enter__ def __enter__(self): @@ -596,6 +657,7 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase): disable_cuda_graph=self.disable_cuda_graph, cuda_graph_max_bs=self.cuda_graph_max_bs, disable_radix_cache=True, + enable_lora=self.enable_lora, ) self.handle.__enter__() return self @@ -690,8 +752,6 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase): other_args = [ "--cuda-graph-max-bs", str(self.cuda_graph_max_bs), - "--lora-paths", - *self.lora_paths, "--max-loras-per-batch", str(self.max_loras_per_batch), "--lora-backend", @@ -704,6 +764,10 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase): "--mem-fraction-static", str(MEM_FRACTION_STATIC), ] + if self.enable_lora: + other_args.append("--enable-lora") + if self.lora_paths: + other_args.extend(["--lora-paths"] + self.lora_paths) if self.disable_cuda_graph: other_args.append("--disable-cuda-graph") if self.max_lora_rank is not None: @@ -836,6 +900,7 @@ class TestLoRADynamicUpdate(CustomTestCase): initial_adapters: List[str], max_loras_per_batch: int, op_sequence: List[Operation], + enable_lora: Optional[bool] = None, max_lora_rank: Optional[int] = None, lora_target_modules: Optional[List[str]] = None, max_new_tokens: int = 32, @@ -854,6 +919,7 @@ class TestLoRADynamicUpdate(CustomTestCase): max_loras_per_batch=max_loras_per_batch, max_lora_rank=max_lora_rank, lora_target_modules=lora_target_modules, + enable_lora=enable_lora, ) as session: for op in op_sequence: op_type = op.type @@ -903,6 +969,7 @@ class TestLoRADynamicUpdate(CustomTestCase): dynamic_output = self._run_operation_sequence( mode=mode, initial_adapters=test_case.initial_adapters, + enable_lora=test_case.enable_lora, base=test_case.base, max_loras_per_batch=test_case.max_loras_per_batch, op_sequence=test_case.op_sequence, @@ -923,6 +990,7 @@ class TestLoRADynamicUpdate(CustomTestCase): static_output = self._run_operation_sequence( mode=mode, initial_adapters=test_case.all_adapters, + enable_lora=test_case.enable_lora, base=test_case.base, max_loras_per_batch=test_case.max_loras_per_batch, op_sequence=forward_ops, diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index f59aed623..d7b4739e3 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -18,7 +18,7 @@ suites = { TestFile("models/lora/test_lora_backend.py", 99), TestFile("models/lora/test_multi_lora_backend.py", 60), TestFile("models/lora/test_lora_cuda_graph.py", 250), - TestFile("models/lora/test_lora_update.py", 700), + TestFile("models/lora/test_lora_update.py", 800), TestFile("models/test_embedding_models.py", 73), # TestFile("models/test_clip_models.py", 52), TestFile("models/test_encoder_embedding_models.py", 100),