From e2ed9d049a34a49654618afdc880e0867378ba22 Mon Sep 17 00:00:00 2001
From: Lifu Huang
Date: Sun, 13 Jul 2025 18:36:01 -0700
Subject: [PATCH] Refactor dynamic LoRA update to fix incorrect handling of
 variant weight shapes (#7844)

---
 docs/backend/lora.ipynb                       | 239 ++++
 docs/backend/server_arguments.md              |   2 +
 python/sglang/srt/lora/lora_manager.py        | 138 +++-
 python/sglang/srt/lora/mem_pool.py            |  63 +-
 python/sglang/srt/lora/utils.py               |  17 +-
 .../sglang/srt/model_executor/model_runner.py |   2 +
 python/sglang/srt/server_args.py              |  24 +
 python/sglang/test/runners.py                 |   4 +
 test/srt/models/lora/test_lora_update.py      | 576 +++++++++++++-----
 test/srt/run_suite.py                         |   2 +-
 10 files changed, 840 insertions(+), 227 deletions(-)

diff --git a/docs/backend/lora.ipynb b/docs/backend/lora.ipynb
index ffe2f48f6..6c089b654 100644
--- a/docs/backend/lora.ipynb
+++ b/docs/backend/lora.ipynb
@@ -33,6 +33,10 @@
     "\n",
     "* `lora_backend`: The backend for running GEMM kernels for LoRA modules. It can be one of `triton` or `flashinfer`, and is set to `triton` by default. For better performance and stability, we recommend using the Triton LoRA backend. In the future, faster backends built upon Cutlass or CUDA kernels will be added.\n",
     "\n",
+    "* `max_lora_rank`: The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup.\n",
+    "\n",
+    "* `lora_target_modules`: The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup.\n",
+    "\n",
     "* `tp_size`: LoRA serving along with Tensor Parallelism is supported by SGLang. `tp_size` controls the number of GPUs for tensor parallelism. More details on the tensor sharding strategy can be found in the [S-LoRA](https://arxiv.org/pdf/2311.03285) paper.\n",
     "\n",
     "From the client side, the user needs to provide a list of strings as the input batch, and a list of adapter names that each input sequence corresponds to."
@@ -176,6 +180,241 @@
     "terminate_process(server_process)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Dynamic LoRA loading"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Basic Usage\n",
+    "\n",
+    "Instead of specifying all adapters at server startup via `--lora-paths`, you can also load and unload LoRA adapters dynamically via the `/load_lora_adapter` and `/unload_lora_adapter` APIs.\n",
+    "\n",
+    "(Please note that we currently still require you to specify at least one adapter in `--lora-paths` to enable the LoRA feature; this limitation will be lifted soon.)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "server_process, port = launch_server_cmd(\n",
+    "    \"\"\"\n",
+    "    python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
+    "    --lora-paths lora0=philschmid/code-llama-3-1-8b-text-to-sql-lora \\\n",
+    "    --cuda-graph-max-bs 2 \\\n",
+    "    --max-loras-per-batch 2 --lora-backend triton \\\n",
+    "    --disable-radix-cache\n",
+    "    \"\"\"\n",
+    ")\n",
+    "\n",
+    "url = f\"http://127.0.0.1:{port}\"\n",
+    "wait_for_server(url)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "response = requests.post(\n",
+    "    url + \"/load_lora_adapter\",\n",
+    "    json={\n",
+    "        \"lora_name\": \"lora1\",\n",
+    "        \"lora_path\": \"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16\",\n",
+    "    },\n",
+    ")\n",
+    "\n",
+    "if response.status_code == 200:\n",
+    "    print(\"LoRA adapter loaded successfully.\", response.json())\n",
+    "else:\n",
+    "    print(\"Failed to load LoRA adapter.\", response.json())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "response = requests.post(\n",
+    "    url + \"/generate\",\n",
+    "    json={\n",
+    "        \"text\": [\n",
+    "            \"List 3 countries and their capitals.\",\n",
+    "            \"List 3 countries and their capitals.\",\n",
+    "        ],\n",
+    "        \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
+    "        \"lora_path\": [\"lora0\", \"lora1\"],\n",
+    "    },\n",
+    ")\n",
+    "print(f\"Output from lora0: {response.json()[0]['text']}\")\n",
+    "print(f\"Output from lora1: {response.json()[1]['text']}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "response = requests.post(\n",
+    "    url + \"/unload_lora_adapter\",\n",
+    "    json={\n",
+    "        \"lora_name\": \"lora0\",\n",
+    "    },\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "response = requests.post(\n",
+    "    url + \"/load_lora_adapter\",\n",
+    "    json={\n",
+    "        \"lora_name\": \"lora2\",\n",
+    "        \"lora_path\": \"pbevan11/llama-3.1-8b-ocr-correction\",\n",
+    "    },\n",
+    ")\n",
+    "\n",
+    "if response.status_code == 200:\n",
+    "    print(\"LoRA adapter loaded successfully.\", response.json())\n",
+    "else:\n",
+    "    print(\"Failed to load LoRA adapter.\", response.json())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "response = requests.post(\n",
+    "    url + \"/generate\",\n",
+    "    json={\n",
+    "        \"text\": [\n",
+    "            \"List 3 countries and their capitals.\",\n",
+    "            \"List 3 countries and their capitals.\",\n",
+    "        ],\n",
+    "        \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
+    "        \"lora_path\": [\"lora1\", \"lora2\"],\n",
+    "    },\n",
+    ")\n",
+    "print(f\"Output from lora1: {response.json()[0]['text']}\")\n",
+    "print(f\"Output from lora2: {response.json()[1]['text']}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "terminate_process(server_process)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
"### Advanced: hosting adapters of different shapes\n", + "\n", + "In some cases, you may want to load LoRA adapters with different ranks or target modules (e.g., `q_proj`, `k_proj`) simultaneously. To ensure the server can accommodate all expected LoRA shapes, it's recommended to explicitly specify `--max-lora-rank` and/or `--lora-target-modules` at startup.\n", + "\n", + "For backward compatibility, SGLang will infer these values from `--lora-paths` if they are not explicitly provided. This means it's safe to omit them **only if** all dynamically loaded adapters share the same shape (rank and target modules) as those in the initial `--lora-paths` or are strictly \"smaller\"." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "lora0 = \"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16\" # rank - 4, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj\n", + "lora1 = \"algoprog/fact-generation-llama-3.1-8b-instruct-lora\" # rank - 64, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj\n", + "\n", + "\n", + "# The `--target-lora-modules` param below is technically not needed, as the server will infer it from lora0 which already has all the target modules specified.\n", + "# We are adding it here just to demonstrate usage.\n", + "server_process, port = launch_server_cmd(\n", + " f\"\"\"\n", + " python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n", + " --lora-paths lora0={lora0} \\\n", + " --cuda-graph-max-bs 2 \\\n", + " --max-loras-per-batch 2 --lora-backend triton \\\n", + " --disable-radix-cache\n", + " --max-lora-rank 64\n", + " --lora-target-modules q_proj k_proj v_proj o_proj down_proj up_proj gate_proj\n", + " \"\"\"\n", + ")\n", + "\n", + "url = f\"http://127.0.0.1:{port}\"\n", + "wait_for_server(url)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "response = requests.post(\n", + " url + \"/load_lora_adapter\",\n", + " json={\n", + " \"lora_name\": \"lora1\",\n", + " \"lora_path\": lora1,\n", + " },\n", + ")\n", + "\n", + "if response.status_code == 200:\n", + " print(\"LoRA adapter loaded successfully.\", response.json())\n", + "else:\n", + " print(\"Failed to load LoRA adapter.\", response.json())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "url = f\"http://127.0.0.1:{port}\"\n", + "json_data = {\n", + " \"text\": [\n", + " \"List 3 countries and their capitals.\",\n", + " \"AI is a field of computer science focused on\",\n", + " ],\n", + " \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n", + " # The first input uses lora0, and the second input uses lora1\n", + " \"lora_path\": [\"lora0\", \"lora1\"],\n", + "}\n", + "response = requests.post(\n", + " url + \"/generate\",\n", + " json=json_data,\n", + ")\n", + "print(f\"Output from lora0: {response.json()[0]['text']}\")\n", + "print(f\"Output from lora1: {response.json()[1]['text']}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "terminate_process(server_process)" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md index 4f19321c3..ad9c136c8 100644 --- a/docs/backend/server_arguments.md +++ b/docs/backend/server_arguments.md @@ -167,6 +167,8 @@ Please consult the documentation below and 
[server_args.py](https://github.com/s | `--lora-paths` | The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}. | None | | `--max-loras-per-batch` | Maximum number of adapters for a running batch, include base-only request. | 8 | | `--lora-backend` | Choose the kernel backend for multi-LoRA serving. | triton | +| `--max-lora-rank` | The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup. | None | +| `--lora-target-modules` | The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup. | None | ## Kernel backend diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index ca0b62c55..96102d1ef 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -16,7 +16,7 @@ # and "Punica: Multi-Tenant LoRA Serving" import logging -from typing import Dict, Set, Tuple +from typing import Dict, Iterable, Optional, Set, Tuple import torch @@ -53,6 +53,8 @@ class LoRAManager: lora_backend: str = "triton", tp_size: int = 1, tp_rank: int = 0, + max_lora_rank: Optional[int] = None, + target_modules: Optional[Iterable[str]] = None, ): self.base_model: torch.nn.Module = base_model self.base_hf_config: AutoConfig = base_hf_config @@ -62,6 +64,10 @@ class LoRAManager: self.device: torch.device = next(self.base_model.parameters()).device self.tp_size: int = tp_size self.tp_rank: int = tp_rank + self.max_lora_rank: Optional[int] = max_lora_rank + self.target_modules: Optional[Set[str]] = ( + set(target_modules) if target_modules else None + ) # LoRA backend for running sgemm kernels logger.info(f"Using {lora_backend} as backend of LoRA kernels.") @@ -153,7 +159,9 @@ class LoRAManager: error_message = f"LoRA adapter {lora_name} is skipped as it is already loaded. If you want to reload it, please unload it first." try: - self.configs[lora_name] = LoRAConfig(lora_path) + new_adapter = LoRAConfig(lora_path) + self.validate_new_adapter(lora_name, new_adapter) + self.configs[lora_name] = new_adapter except Exception as e: success = False error_message = ( @@ -168,6 +176,21 @@ class LoRAManager: error_message=error_message, ) + def validate_new_adapter(self, lora_name: str, lora_config: LoRAConfig): + """ + Validate if an adapter can be loaded into the current LoRA memory pool and generate error if it is incompatible. + """ + + incompatible = self.memory_pool and not self.memory_pool.can_support( + lora_config + ) + if incompatible: + raise ValueError( + f"LoRA adapter {lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration." + "We are still working on supporting dynamically updating LoRA shapes. If you expect to use adapters of different shapes, " + "You can specify expected configs via --max_lora_rank and --enable_lora_modules." + ) + def unload_lora_adapter(self, lora_name: str) -> LoRAUpdateResult: """ Unload LoRA adapters by their names. 
This will remove the adapters from the memory pool and @@ -214,7 +237,7 @@ class LoRAManager: weight_indices[i] = self.memory_pool.get_buffer_id(lora_path) if lora_path is not None: lora = self.loras[lora_path] - lora_ranks[weight_indices[i]] = lora.config.hf_config["r"] + lora_ranks[weight_indices[i]] = lora.config.r scalings[weight_indices[i]] = lora.scaling # Use pinned memory to avoid synchronizations during host-to-device transfer @@ -319,7 +342,7 @@ class LoRAManager: ) else: weight_name = get_weight_name( - module_name, self.lora_weight_names, LoRAType.LORA_A + module_name, self.memory_pool.lora_weight_names, LoRAType.LORA_A ) module.set_lora_info( self.memory_pool.get_tensor( @@ -351,58 +374,67 @@ class LoRAManager: i: {} for i in range(self.base_hf_config.num_hidden_layers) } - # Initialize memory pool - self.memory_pool = LoRAMemoryPool( - self.base_hf_config, - self.max_loras_per_batch, - self.dtype, - self.tp_size, - self.tp_rank, - ) + # The LoRA memory pool that manages the GPU buffers for active LoRA weights. + # It is initialized lazily when the first LoRA adapter is loaded. + self.memory_pool: Optional[LoRAMemoryPool] = None def update_state_from_configs(self): """ Update the internal state of the LoRAManager based on the current `self.configs`. This method should be called whenever `self.configs` is modified (e.g., when new LoRA adapters are loaded). - - This includes: - - Initializing LoRA adapters if they are not already loaded. - - Collect all LoRA weight names based on the current loaded adapters. - - Lazily monkey-patching the base model to use LoRA layers where applicable. - - Preparing the GPU buffer pool for active LoRA weights. """ - # Target module names in huggingface lora configs. - # e.g., {"k_proj", "q_proj", "v_proj", "o_proj"} - hf_target_module_names: Set[str] = set() - for config in self.configs.values(): - hf_target_module_names.update(config.target_modules) - max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()]) - # Loads / unloads LoRA adapters based on the latest configs. self.update_lora_adapters() + # Apply the latest LoRA configurations to the internal state for inferencing. + self.apply_lora_configs() - # Lazily update states for new LoRA weight name (e.g., qkv_proj) as needed. - # - # Please note that the following update operations are "monotonic" by design, meaning that we update - # multiple places to support the new weight names when the first adapter targeting such weight names - # is loaded. However, we never "rollback" the support (e.g., convert LoRA layer back to base layer) - # even if the associated adapters are unloaded later for both simplicity and practicality reasons: the - # list of LoRA weight names is expected to be extremely finite and stable. - self.update_lora_weight_names(hf_target_module_names) - self.update_lora_modules(hf_target_module_names) - self.update_memory_buffers(max_lora_dim) + def apply_lora_configs(self): + """ + Apply the LoRA configurations to the base model and internal states of the LoRAManager for inferencing. - def update_lora_weight_names(self, hf_target_names: Set[str]): + Notes: + - Currently, this method is effectively only invoked during the initialization phase of the LoRAManager as + we do not yet support dynamically updating adapter shape configs, which has a dependency on (1) FlashInfer + LoRA backend deprecation and (2) CUDA graph recapture support. We are targeting completing these work in + early CY25H2. 
+ """ + + if self.memory_pool is None: + # Infer max_lora_rank and target_modules if not explicitly specified in server args. + if self.target_modules is None: + self.target_modules = set() + for config in self.configs.values(): + self.target_modules.update(config.target_modules) + + if self.max_lora_rank is None: + self.max_lora_rank = max( + [x.hf_config["r"] for x in self.configs.values()], + default=0, + ) + + self.update_lora_weight_names() + self.update_lora_modules() + self.update_memory_buffers() + else: + # No-op if the memory pool can support the current LoRA configurations. + # TODO (lifuhuang): support reinitializing the memory pool when the maximum LoRA rank or target + # module is changed once FlashInfer backend is deprecated. + assert self.memory_pool.can_support(self.configs.values()), ( + "LoRA memory pool cannot support the current LoRA configuration. " + "This should never happen as we should have validated adapter compatibility. " + "Please create a Github issue to report.", + ) + + def update_lora_weight_names(self): """ Add new LoRA weight names if needed based on the current `self.configs`. """ # Target lora weight names for lora_a and lora_b modules respectively. - for module in hf_target_names: - lora_A, lora_B = get_normalized_lora_weight_names(module) - self.lora_weight_names[0].update(lora_A) - self.lora_weight_names[1].update(lora_B) + lora_A, lora_B = get_normalized_lora_weight_names(self.target_modules) + self.lora_weight_names[0].update(lora_A) + self.lora_weight_names[1].update(lora_B) def update_lora_adapters(self): """ @@ -434,21 +466,23 @@ class LoRAManager: # Additional checks for flashinfer backend # FIXME remove the restrictions after supporting multi-rank for flashinfer backend if self.lora_backend == "flashinfer": - lora_dims = set(x.hf_config["r"] for x in self.configs.values()) + lora_dims = set(x.r for x in self.configs.values()) scalings = set(x.scaling for x in self.loras.values()) assert ( len(lora_dims) == 1 and len(scalings) == 1 ), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. " - def update_memory_buffers(self, max_lora_dim: int): - """ - Update the LoRA memory pool buffers based on the current LoRA configurations and update - LoRA modules to use the new buffers. This method should be called after the LoRA configurations - are set or updated. 
- """ - - self.memory_pool.init_buffers( - self.lora_weight_names, self.base_model, max_lora_dim + def update_memory_buffers(self): + """(Re)initialize the LoRA memory pool based on the current configurations.""" + self.memory_pool = LoRAMemoryPool( + base_hf_config=self.base_hf_config, + max_loras_per_batch=self.max_loras_per_batch, + dtype=self.dtype, + tp_size=self.tp_size, + tp_rank=self.tp_rank, + max_lora_rank=self.max_lora_rank, + lora_weight_names=self.lora_weight_names, + base_model=self.base_model, ) def set_lora_module(self, module_name, module): @@ -456,11 +490,11 @@ class LoRAManager: replace_submodule(self.base_model, module_name, lora_module) return lora_module - def update_lora_modules(self, hf_target_names: Set[str]): + def update_lora_modules(self): # Target module names of customized layers defined in python/sglang/srt/layers # e.g., {"qkv_proj", "o_proj"} customized_target_names = get_customized_names_from_hf_names( - hf_target_names, self.base_model + self.target_modules, self.base_model ) for module_name, module in self.base_model.named_modules(): diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py index 27122ccc4..713b03650 100644 --- a/python/sglang/srt/lora/mem_pool.py +++ b/python/sglang/srt/lora/mem_pool.py @@ -1,4 +1,4 @@ -from typing import Callable, Dict, List, Optional, Set, Tuple +from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union import torch @@ -6,10 +6,12 @@ from sglang.srt.distributed import divide from sglang.srt.hf_transformers_utils import AutoConfig from sglang.srt.lora.layers import BaseLayerWithLoRA from sglang.srt.lora.lora import LoRAAdapter +from sglang.srt.lora.lora_config import LoRAConfig from sglang.srt.lora.utils import ( ROW_PARALLELISM_LINEAR_LORA_NAMES, LoRAType, get_hidden_dim, + get_normalized_lora_weight_names, get_stacked_multiply, get_weight_name, ) @@ -25,6 +27,9 @@ class LoRAMemoryPool: dtype: torch.dtype, tp_size: int, tp_rank: int, + max_lora_rank: int, + lora_weight_names: Tuple[Set[str], Set[str]], + base_model: torch.nn.Module, ): self.base_hf_config: AutoConfig = base_hf_config self.num_layer: int = base_hf_config.num_hidden_layers @@ -32,6 +37,10 @@ class LoRAMemoryPool: self.dtype: torch.dtype = dtype self.tp_size: int = tp_size self.tp_rank: int = tp_rank + self.max_lora_rank: int = max_lora_rank + + # lora weight names for LoRA A and B respectively. + self.lora_weight_names: Tuple[Set[str], Set[str]] = lora_weight_names # Both A_buffer and B_buffer maps lora weight names to its buffer space. # A_buffer contains num_layer number of row-major tensors with shape @@ -49,6 +58,31 @@ class LoRAMemoryPool: # Here we don't initialize to None since None is a valid uid self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch + self.init_buffers(base_model) + + def can_support(self, config: Union[LoRAConfig, Iterable[LoRAConfig]]) -> bool: + """ + Check if the memory pool can support the given LoRA adapters. + """ + + def _can_support(config: LoRAConfig) -> bool: + """ + Check if the memory pool can support a single LoRA adapter. 
+ """ + if config.r > self.max_lora_rank: + return False + weights_a, weights_b = get_normalized_lora_weight_names( + config.target_modules + ) + return weights_a.issubset(self.lora_weight_names[0]) and weights_b.issubset( + self.lora_weight_names[1] + ) + + if isinstance(config, LoRAConfig): + return _can_support(config) + else: + return all(_can_support(x) for x in config) + def get_lora_A_shape( self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int ) -> Tuple[int]: @@ -82,25 +116,18 @@ class LoRAMemoryPool: max_lora_dim, ) - def init_buffers( - self, - lora_weight_names: Tuple[Set[str]], - base_model: torch.nn.Module, - max_lora_dim: int, - ): - # lora_weight_names is a set of name pairs indicating each pair of lora modules to load - # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")} - self.lora_weight_names: Tuple[Set[str]] = lora_weight_names + def init_buffers(self, base_model: torch.nn.Module): device = next(base_model.parameters()).device - def update_buffer( + def init_buffer( buffer: Dict[str, List[torch.Tensor]], lora_weight_names: Set[str], get_lora_shape_fn: Callable[[str, torch.nn.Module, int], Tuple[int]], ): - new_weight_names = lora_weight_names - buffer.keys() - for module_name in new_weight_names: - lora_shape = get_lora_shape_fn(module_name, base_model, max_lora_dim) + for module_name in lora_weight_names: + lora_shape = get_lora_shape_fn( + module_name, base_model, self.max_lora_rank + ) buffer[module_name] = [ torch.empty( lora_shape, @@ -110,15 +137,15 @@ class LoRAMemoryPool: for _ in range(self.num_layer) ] - update_buffer( + init_buffer( self.A_buffer, - lora_weight_names[0], + self.lora_weight_names[0], self.get_lora_A_shape, ) - update_buffer( + init_buffer( self.B_buffer, - lora_weight_names[1], + self.lora_weight_names[1], self.get_lora_B_shape, ) diff --git a/python/sglang/srt/lora/utils.py b/python/sglang/srt/lora/utils.py index 2df4a8c14..d440fa70c 100644 --- a/python/sglang/srt/lora/utils.py +++ b/python/sglang/srt/lora/utils.py @@ -1,7 +1,7 @@ import re from dataclasses import dataclass from enum import Enum -from typing import List, Optional, Set, Tuple +from typing import Iterable, Optional, Set, Tuple import torch @@ -106,9 +106,11 @@ def get_hidden_dim( raise NotImplementedError() -def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]: +def get_normalized_lora_weight_names( + target_modules: Iterable[str], +) -> Tuple[set[str], set[str]]: """ - Mapping a target module name to names of the normalized LoRA weights. + Mapping a list of target module name to names of the normalized LoRA weights. 
Returned tuple contains (name for Lora A, name for Lora B) """ params_mapping = { @@ -120,8 +122,13 @@ def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]: "qkv_proj": (["qkv_proj"], ["q_proj", "kv_proj"]), "gate_up_proj": (["gate_up_proj"], ["gate_up_proj"]), } - stacked = params_mapping.get(name, ([name], [name])) - return stacked + + result = (set(), set()) + for name in target_modules: + lora_a, lora_b = params_mapping.get(name, ([name], [name])) + result[0].update(lora_a) + result[1].update(lora_b) + return result def get_stacked_multiply(module_name: str) -> int: diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index fe9560497..f70eccd0c 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -891,6 +891,8 @@ class ModelRunner: lora_backend=self.server_args.lora_backend, tp_size=self.tp_size, tp_rank=self.tp_rank, + max_lora_rank=self.server_args.max_lora_rank, + target_modules=self.server_args.lora_target_modules, ) result = self.lora_manager.load_lora_adapters(self.server_args.lora_paths) if result.success: diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 30191ee08..16ac09b16 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -134,6 +134,8 @@ class ServerArgs: preferred_sampling_params: Optional[str] = None # LoRA + max_lora_rank: Optional[int] = None + lora_target_modules: Optional[List[str]] = None lora_paths: Optional[Union[dict[str, str], List[str]]] = None max_loras_per_batch: int = 8 lora_backend: str = "triton" @@ -1129,6 +1131,28 @@ class ServerArgs: ) # LoRA + parser.add_argument( + "--max-lora-rank", + default=ServerArgs.max_lora_rank, + type=int, + help="The maximum rank of LoRA adapters. If not specified, it will be automatically inferred from the adapters provided in --lora-paths.", + ) + parser.add_argument( + "--lora-target-modules", + type=str, + choices=[ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ], + nargs="*", + default=None, + help="The union set of all target modules where LoRA should be applied. 
If not specified, it will be automatically inferred from the adapters provided in --lora-paths.", + ) parser.add_argument( "--lora-paths", type=str, diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index 481bf682d..64a1b34c2 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -505,6 +505,8 @@ class SRTRunner: torchao_config: Optional[str] = None, cuda_graph_max_bs: int = 4, sleep_on_idle=False, + max_lora_rank: Optional[int] = None, + lora_target_modules: Optional[List[str]] = None, ): self.model_type = model_type self.is_generation = model_type == "generation" @@ -543,6 +545,8 @@ class SRTRunner: cuda_graph_max_bs=cuda_graph_max_bs, disable_custom_all_reduce=disable_custom_all_reduce, sleep_on_idle=sleep_on_idle, + max_lora_rank=max_lora_rank, + lora_target_modules=lora_target_modules, **spec_kwargs, ) diff --git a/test/srt/models/lora/test_lora_update.py b/test/srt/models/lora/test_lora_update.py index dc96f24e7..785b44e95 100644 --- a/test/srt/models/lora/test_lora_update.py +++ b/test/srt/models/lora/test_lora_update.py @@ -16,7 +16,7 @@ import multiprocessing as mp import unittest from dataclasses import dataclass from enum import Enum -from typing import Any, List, Optional, Union +from typing import Any, Iterable, List, Optional, Union import requests import torch @@ -27,6 +27,7 @@ from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, CustomTestCase, + is_in_ci, popen_launch_server, ) @@ -45,24 +46,28 @@ class OperationType(Enum): LOAD = "load" UNLOAD = "unload" FORWARD = "forward" - EXPECT_ERROR = "expect_error" @dataclass class Operation: - # Operation type, can be LOAD, UNLOAD, FORWARD, or EXPECT_ERROR + # Operation type, can be LOAD, UNLOAD, FORWARD type: OperationType # Data associated with the operation. 
Exact type varies depending on the operation data: Optional[Any] + # If the operation is expected to fail, this is the error message to expect + expected_error: Optional[str] = None @dataclass class TestCase: + description: str base: str max_loras_per_batch: int all_adapters: List[str] initial_adapters: List[str] op_sequence: List[Operation] + max_lora_rank: Optional[int] = None + lora_target_modules: Optional[List] = None max_new_tokens: int = 32 @@ -72,9 +77,9 @@ def create_batch_data(adapters: Union[str, list]) -> List[tuple[str, str]]: return [(prompt, adapter) for prompt in PROMPTS for adapter in adapters] -TEST_CASES = [ - # basic test, no eviction +BASIC_TESTS = [ TestCase( + description="dynamic lora update with initial lora_paths", base="meta-llama/Llama-3.1-8B-Instruct", max_loras_per_batch=3, all_adapters=[ @@ -89,20 +94,16 @@ TEST_CASES = [ data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"), ), Operation( - type=OperationType.EXPECT_ERROR, - data=( - create_batch_data( - "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16" - ), - "not loaded", + type=OperationType.FORWARD, + data=create_batch_data( + "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16" ), + expected_error="not loaded", ), Operation( - type=OperationType.EXPECT_ERROR, - data=( - create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"), - "not loaded", - ), + type=OperationType.FORWARD, + data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"), + expected_error="not loaded", ), Operation( type=OperationType.LOAD, @@ -127,11 +128,9 @@ TEST_CASES = [ data="philschmid/code-llama-3-1-8b-text-to-sql-lora", ), Operation( - type=OperationType.EXPECT_ERROR, - data=( - create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"), - "not loaded", - ), + type=OperationType.FORWARD, + data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"), + expected_error="not loaded", ), Operation( type=OperationType.FORWARD, @@ -147,13 +146,11 @@ TEST_CASES = [ data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", ), Operation( - type=OperationType.EXPECT_ERROR, - data=( - create_batch_data( - "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16" - ), - "not loaded", + type=OperationType.FORWARD, + data=create_batch_data( + "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16" ), + expected_error="not loaded", ), Operation( type=OperationType.FORWARD, @@ -174,8 +171,8 @@ TEST_CASES = [ ), ], ), - # Eviction TestCase( + description="dynamic lora update with evictions", base="meta-llama/Llama-3.1-8B-Instruct", max_loras_per_batch=1, all_adapters=[ @@ -190,20 +187,16 @@ TEST_CASES = [ data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"), ), Operation( - type=OperationType.EXPECT_ERROR, - data=( - create_batch_data( - "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16" - ), - "not loaded", + type=OperationType.FORWARD, + data=create_batch_data( + "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16" ), + expected_error="not loaded", ), Operation( - type=OperationType.EXPECT_ERROR, - data=( - create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"), - "not loaded", - ), + type=OperationType.FORWARD, + data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"), + expected_error="not loaded", ), Operation( type=OperationType.LOAD, @@ -214,11 +207,9 @@ TEST_CASES = [ data="philschmid/code-llama-3-1-8b-text-to-sql-lora", ), Operation( - type=OperationType.EXPECT_ERROR, - data=( - create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"), - "not loaded", - ), + 
type=OperationType.FORWARD, + data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"), + expected_error="not loaded", ), Operation( type=OperationType.FORWARD, @@ -263,6 +254,253 @@ TEST_CASES = [ ], ), ] +TARGET_MODULE_TESTS = [ + TestCase( + description="Test explicitly specified lora-target-modules.", + base="meta-llama/Llama-3.1-8B-Instruct", + max_loras_per_batch=3, + lora_target_modules=[ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ], + max_lora_rank=64, + all_adapters=[ + "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", # target_modules = q, k, v, o, gate, up, down + "algoprog/fact-generation-llama-3.1-8b-instruct-lora", # target_modules = q, k, v, o, gate + ], + initial_adapters=["algoprog/fact-generation-llama-3.1-8b-instruct-lora"], + op_sequence=[ + Operation( + type=OperationType.FORWARD, + data=create_batch_data( + "algoprog/fact-generation-llama-3.1-8b-instruct-lora" + ), + ), + Operation( + type=OperationType.FORWARD, + data=create_batch_data( + "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16" + ), + expected_error="not loaded", + ), + Operation( + type=OperationType.LOAD, + data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", + ), + Operation( + type=OperationType.FORWARD, + data=create_batch_data( + [ + "algoprog/fact-generation-llama-3.1-8b-instruct-lora", + "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", + None, + ] + ), + ), + ], + ), + TestCase( + description="Test inferred lora-target-modules - start with larger adapter", + base="meta-llama/Llama-3.1-8B-Instruct", + max_loras_per_batch=3, + max_lora_rank=64, + all_adapters=[ + "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", # target_modules = q, k, v, o, gate, up, down + "algoprog/fact-generation-llama-3.1-8b-instruct-lora", # target_modules = q, k, v, o, gate + ], + initial_adapters=["Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"], + op_sequence=[ + Operation( + type=OperationType.FORWARD, + data=create_batch_data( + "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16" + ), + ), + Operation( + type=OperationType.FORWARD, + data=create_batch_data( + "algoprog/fact-generation-llama-3.1-8b-instruct-lora" + ), + expected_error="not loaded", + ), + Operation( + type=OperationType.LOAD, + data="algoprog/fact-generation-llama-3.1-8b-instruct-lora", + ), + Operation( + type=OperationType.FORWARD, + data=create_batch_data( + [ + "algoprog/fact-generation-llama-3.1-8b-instruct-lora", + "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", + None, + ] + ), + ), + ], + ), + TestCase( + description="Test inferred lora-target-modules - start with smaller adapter", + base="meta-llama/Llama-3.1-8B-Instruct", + max_loras_per_batch=3, + max_lora_rank=64, + all_adapters=[ + "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", # target_modules = q, k, v, o, gate, up, down + "algoprog/fact-generation-llama-3.1-8b-instruct-lora", # target_modules = q, k, v, o, gate + ], + initial_adapters=["algoprog/fact-generation-llama-3.1-8b-instruct-lora"], + op_sequence=[ + Operation( + type=OperationType.FORWARD, + data=create_batch_data( + "algoprog/fact-generation-llama-3.1-8b-instruct-lora" + ), + ), + Operation( + type=OperationType.FORWARD, + data=create_batch_data( + "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16" + ), + expected_error="not loaded", + ), + Operation( + type=OperationType.LOAD, + data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", + expected_error="updating LoRA shapes", + ), + Operation( + type=OperationType.FORWARD, 
+ data=create_batch_data( + [ + "algoprog/fact-generation-llama-3.1-8b-instruct-lora", + None, + ] + ), + ), + ], + ), +] +MAX_LORA_RANK_TESTS = [ + TestCase( + description="Test explicitly specified max-lora-rank.", + base="meta-llama/Llama-3.1-8B-Instruct", + max_loras_per_batch=3, + max_lora_rank=32, + all_adapters=[ + "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", # r = 4 + "pbevan11/llama-3.1-8b-ocr-correction", # r = 32 + "philschmid/code-llama-3-1-8b-text-to-sql-lora", # r = 256 + ], + initial_adapters=["Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"], + op_sequence=[ + Operation( + type=OperationType.FORWARD, + data=create_batch_data( + "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16" + ), + ), + Operation( + type=OperationType.FORWARD, + data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"), + expected_error="not loaded", + ), + Operation( + type=OperationType.FORWARD, + data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"), + expected_error="not loaded", + ), + Operation( + type=OperationType.LOAD, + data="pbevan11/llama-3.1-8b-ocr-correction", + ), + Operation( + type=OperationType.FORWARD, + data=create_batch_data( + [ + "pbevan11/llama-3.1-8b-ocr-correction", + "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", + None, + ] + ), + ), + Operation( + type=OperationType.LOAD, + data="philschmid/code-llama-3-1-8b-text-to-sql-lora", + expected_error="updating LoRA shapes", + ), + Operation( + type=OperationType.FORWARD, + data=create_batch_data( + "philschmid/code-llama-3-1-8b-text-to-sql-lora", + ), + expected_error="not loaded", + ), + Operation( + type=OperationType.FORWARD, + data=create_batch_data( + [ + "pbevan11/llama-3.1-8b-ocr-correction", + "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", + None, + ] + ), + ), + ], + ), + TestCase( + description="test implicitly inferred max-lora-rank", + base="meta-llama/Llama-3.1-8B-Instruct", + max_loras_per_batch=3, + all_adapters=[ + "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", # r = 4 + "pbevan11/llama-3.1-8b-ocr-correction", # r = 32 + "philschmid/code-llama-3-1-8b-text-to-sql-lora", # r = 256 + ], + initial_adapters=["pbevan11/llama-3.1-8b-ocr-correction"], + op_sequence=[ + Operation( + type=OperationType.FORWARD, + data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"), + ), + Operation( + type=OperationType.LOAD, + data="philschmid/code-llama-3-1-8b-text-to-sql-lora", + expected_error="updating LoRA shapes", + ), + Operation( + type=OperationType.FORWARD, + data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"), + expected_error="not loaded", + ), + Operation( + type=OperationType.LOAD, + data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", + ), + Operation( + type=OperationType.FORWARD, + data=create_batch_data( + "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", + ), + ), + Operation( + type=OperationType.FORWARD, + data=create_batch_data( + [ + "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", + "pbevan11/llama-3.1-8b-ocr-correction", + None, + ] + ), + ), + ], + ), +] +ALL_TESTS = BASIC_TESTS + TARGET_MODULE_TESTS + MAX_LORA_RANK_TESTS class LoRAUpdateTestSessionMode(Enum): @@ -281,7 +519,9 @@ class LoRAUpdateTestSessionBase: testcase: Optional[TestCase], model_path: str, lora_paths: list[str], - max_loras_per_batch: int = 1, + max_loras_per_batch: int, + max_lora_rank: Optional[int], + lora_target_modules: Optional[List[str]] = None, lora_backend: str = "triton", disable_cuda_graph: bool = False, cuda_graph_max_bs: int = 4, @@ 
-289,6 +529,8 @@ class LoRAUpdateTestSessionBase: self.testcase = testcase self.model_path = model_path self.lora_paths = lora_paths + self.max_lora_rank = max_lora_rank + self.lora_target_modules = lora_target_modules self.max_loras_per_batch = max_loras_per_batch self.lora_backend = lora_backend self.disable_cuda_graph = disable_cuda_graph @@ -304,7 +546,12 @@ class LoRAUpdateTestSessionBase: # Don't suppress exceptions by default return False - def load_lora_adapter(self, lora_name: str, lora_path: Optional[str] = None): + def load_lora_adapter( + self, + lora_name: str, + lora_path: Optional[str] = None, + expected_error: Optional[str] = None, + ): """ Load a LoRA adapter by name and path. """ @@ -321,6 +568,7 @@ class LoRAUpdateTestSessionBase: prompts: List[str], lora_paths: List[str], max_new_tokens: int = 32, + expected_error: Optional[str] = None, ): """ Perform a batch forward pass with the current set of loaded LoRA adapters. @@ -339,6 +587,8 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase): model_path=self.model_path, model_type="generation", lora_paths=self.lora_paths, + max_lora_rank=self.max_lora_rank, + lora_target_modules=self.lora_target_modules, lora_backend=self.lora_backend, torch_dtype=torch.float16, mem_fraction_static=MEM_FRACTION_STATIC, @@ -357,24 +607,32 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase): # don't suppress exceptions return False - def load_lora_adapter(self, lora_name: str, lora_path: Optional[str] = None): + def load_lora_adapter( + self, + lora_name: str, + lora_path: Optional[str] = None, + expected_error: Optional[str] = None, + ): """ Load a LoRA adapter by name and path. """ if lora_path is None: lora_path = lora_name - self.expected_adapters.add(lora_name) - response = self.handle.load_lora_adapter( lora_name=lora_name, lora_path=lora_path, ) - self.testcase.assertTrue(response.success) - loaded_adapters = set(response.loaded_adapters) - - print(f"loaded_adapters: {loaded_adapters}") - self.testcase.assertEqual(loaded_adapters, self.expected_adapters) + if expected_error: + self.testcase.assertFalse(response.success) + self.testcase.assertIn(expected_error, response.error_message) + print(f"Received error as expected: {response.error_message}") + else: + self.expected_adapters.add(lora_name) + self.testcase.assertTrue(response.success) + loaded_adapters = set(response.loaded_adapters) + print(f"loaded_adapters: {loaded_adapters}") + self.testcase.assertEqual(loaded_adapters, self.expected_adapters) def unload_lora_adapter(self, lora_name: str): """ @@ -396,7 +654,7 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase): prompts: List[str], lora_paths: List[str], max_new_tokens: int = 32, - expected_error: str = None, + expected_error: Optional[str] = None, ): """ Perform a batch forward pass with the current set of loaded LoRA adapters. 
@@ -448,6 +706,10 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase): ] if self.disable_cuda_graph: other_args.append("--disable-cuda-graph") + if self.max_lora_rank is not None: + other_args.extend(["--max-lora-rank", str(self.max_lora_rank)]) + if self.lora_target_modules is not None: + other_args.extend(["--lora-target-modules"] + self.lora_target_modules) # launch external server self.handle = popen_launch_server( @@ -464,24 +726,32 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase): # don't suppress exceptions return False - def load_lora_adapter(self, lora_name: str, lora_path: Optional[str] = None): + def load_lora_adapter( + self, + lora_name: str, + lora_path: Optional[str] = None, + expected_error: Optional[str] = None, + ): """ Load a LoRA adapter by name and path. """ if lora_path is None: lora_path = lora_name - self.expected_adapters.add(lora_name) - response = requests.post( DEFAULT_URL_FOR_TEST + "/load_lora_adapter", json={"lora_name": lora_name, "lora_path": lora_path}, ) - self.testcase.assertTrue(response.ok) - loaded_adapters = set(response.json()["loaded_adapters"]) - - print(f"loaded_adapters: {loaded_adapters}") - self.testcase.assertEqual(loaded_adapters, self.expected_adapters) + if expected_error: + self.testcase.assertEqual(response.status_code, 400) + self.testcase.assertIn(expected_error, response.text) + print(f"Received error as expected: {response.text}") + else: + self.expected_adapters.add(lora_name) + self.testcase.assertTrue(response.ok) + loaded_adapters = set(response.json()["loaded_adapters"]) + print(f"loaded_adapters: {loaded_adapters}") + self.testcase.assertEqual(loaded_adapters, self.expected_adapters) def unload_lora_adapter(self, lora_name: str): """ @@ -504,7 +774,7 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase): prompts: List[str], lora_paths: List[str], max_new_tokens: int = 32, - expected_error: str = None, + expected_error: Optional[str] = None, ): """ Perform a batch forward pass with the current set of loaded LoRA adapters. 
@@ -537,30 +807,14 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase): # Factory function to create the appropriate LoRA test session based on mode def LoRAUpdateTestSession( - *, testcase: Optional[TestCase], mode: LoRAUpdateTestSessionMode, - model_path: str, - lora_paths: list[str], - max_loras_per_batch: int = 1, - lora_backend: str = "triton", - disable_cuda_graph: bool = False, - cuda_graph_max_bs: int = 4, + **kwargs: Any, ): - common_kwargs = { - "testcase": testcase, - "model_path": model_path, - "lora_paths": lora_paths, - "max_loras_per_batch": max_loras_per_batch, - "lora_backend": lora_backend, - "disable_cuda_graph": disable_cuda_graph, - "cuda_graph_max_bs": cuda_graph_max_bs, - } - if mode == LoRAUpdateTestSessionMode.ENGINE: - return LoRAUpdateEngineTestSession(**common_kwargs) + return LoRAUpdateEngineTestSession(testcase=testcase, **kwargs) elif mode == LoRAUpdateTestSessionMode.SERVER: - return LoRAUpdateServerTestSession(**common_kwargs) + return LoRAUpdateServerTestSession(testcase=testcase, **kwargs) else: raise ValueError(f"Unrecognized mode: {mode!r}") @@ -582,6 +836,8 @@ class TestLoRADynamicUpdate(CustomTestCase): initial_adapters: List[str], max_loras_per_batch: int, op_sequence: List[Operation], + max_lora_rank: Optional[int] = None, + lora_target_modules: Optional[List[str]] = None, max_new_tokens: int = 32, ) -> List[tuple]: """ @@ -596,10 +852,13 @@ class TestLoRADynamicUpdate(CustomTestCase): model_path=base, lora_paths=initial_adapters, max_loras_per_batch=max_loras_per_batch, + max_lora_rank=max_lora_rank, + lora_target_modules=lora_target_modules, ) as session: for op in op_sequence: op_type = op.type data = op.data + expected_error = op.expected_error print("-" * 100) print( f"Running operation: {op_type} --- data: {data} --- mode: {mode} ---" @@ -608,6 +867,7 @@ class TestLoRADynamicUpdate(CustomTestCase): result = session.load_lora_adapter( lora_name=data, lora_path=data, + expected_error=expected_error, ) elif op_type == OperationType.UNLOAD: result = session.unload_lora_adapter( @@ -615,91 +875,105 @@ class TestLoRADynamicUpdate(CustomTestCase): ) elif op_type == OperationType.FORWARD: prompts, adapters = zip(*data) - result = session.forward( - prompts=list(prompts), - lora_paths=list(adapters), - max_new_tokens=max_new_tokens, - ) - forward_outputs.append(result) - elif op_type == OperationType.EXPECT_ERROR: - input_data, expected_error = data - prompts, adapters = zip(*input_data) result = session.forward( prompts=list(prompts), lora_paths=list(adapters), max_new_tokens=max_new_tokens, expected_error=expected_error, ) + if not expected_error: + forward_outputs.append(result) return forward_outputs - def test_dynamic_adapter_updates(self): - for case_idx, test_case in enumerate(TEST_CASES, start=1): - for mode in [ - LoRAUpdateTestSessionMode.ENGINE, - LoRAUpdateTestSessionMode.SERVER, - ]: - print("=" * 100) - print(f"Starting test case {case_idx} in {mode.value} mode.") - print("=" * 100) + def _run_dynamic_adapter_updates( + self, mode: LoRAUpdateTestSessionMode, test_cases: Iterable[TestCase] + ): + for case_idx, test_case in enumerate(test_cases, start=1): + print("=" * 100) + print( + f"Starting test case {case_idx} in {mode.value} mode. 
Test description: {test_case.description}" + ) + print("=" * 100) - print( - f"--- Running dynamic update pass with {len(test_case.op_sequence)} operations ---" - ) - # Test dynamic loading of adapters - # TODO (lifuhuang): currently at least one LoRA path is required during initialization to enable lora, - # we should fix this in the future https://github.com/sgl-project/sglang/issues/7463. - dynamic_output = self._run_operation_sequence( - mode=mode, - initial_adapters=test_case.initial_adapters, - base=test_case.base, - max_loras_per_batch=test_case.max_loras_per_batch, - op_sequence=test_case.op_sequence, - max_new_tokens=test_case.max_new_tokens, - ) + print( + f"--- Running dynamic update pass with {len(test_case.op_sequence)} operations ---" + ) + # Test dynamic loading of adapters + dynamic_output = self._run_operation_sequence( + mode=mode, + initial_adapters=test_case.initial_adapters, + base=test_case.base, + max_loras_per_batch=test_case.max_loras_per_batch, + op_sequence=test_case.op_sequence, + max_new_tokens=test_case.max_new_tokens, + max_lora_rank=test_case.max_lora_rank, + lora_target_modules=test_case.lora_target_modules, + ) - # static loading - forward_ops = [ - x for x in test_case.op_sequence if x.type == OperationType.FORWARD - ] + # static loading + forward_ops = [ + x + for x in test_case.op_sequence + if x.type == OperationType.FORWARD and x.expected_error is None + ] - print("=" * 100) - print( - f"\n--- Running static pass with {len(forward_ops)} operations ---" - ) - static_output = self._run_operation_sequence( - mode=mode, - initial_adapters=test_case.all_adapters, - base=test_case.base, - max_loras_per_batch=test_case.max_loras_per_batch, - op_sequence=forward_ops, - max_new_tokens=test_case.max_new_tokens, - ) + print("=" * 100) + print(f"\n--- Running static pass with {len(forward_ops)} operations ---") + static_output = self._run_operation_sequence( + mode=mode, + initial_adapters=test_case.all_adapters, + base=test_case.base, + max_loras_per_batch=test_case.max_loras_per_batch, + op_sequence=forward_ops, + max_new_tokens=test_case.max_new_tokens, + ) - print(f"Dynamic output: {dynamic_output}") - print(f"Static output: {static_output}") - print("=" * 100) + print(f"Dynamic output: {dynamic_output}") + print(f"Static output: {static_output}") + print("=" * 100) + self.assertEqual( + len(dynamic_output), + len(static_output), + f"Dynamic output length {len(dynamic_output)} does not match static output length {len(static_output)}", + ) + for i, (dynamic, static) in enumerate( + zip(dynamic_output, static_output), start=1 + ): self.assertEqual( - len(dynamic_output), - len(static_output), - f"Dynamic output length {len(dynamic_output)} does not match static output length {len(static_output)}", + len(dynamic), + len(static), + f"Output length mismatch at batch {i}:\n- Dynamic={len(dynamic)}\n- Static={len(static)}", ) - for i, (dynamic, static) in enumerate( - zip(dynamic_output, static_output), start=1 - ): + for j, (d_out, s_out) in enumerate(zip(dynamic, static), start=1): + d_out = d_out.strip() + s_out = s_out.strip() self.assertEqual( - len(dynamic), - len(static), - f"Output length mismatch at batch {i}:\n- Dynamic={len(dynamic)}\n- Static={len(static)}", + d_out, + s_out, + f"Output mismatch at batch {i}, prompt {j}:\n- Dynamic: '{d_out}'\n- Static: '{s_out}'", ) - for j, (d_out, s_out) in enumerate(zip(dynamic, static), start=1): - d_out = d_out.strip() - s_out = s_out.strip() - self.assertEqual( - d_out, - s_out, - f"Output mismatch at batch {i}, 
prompt {j}:\n- Dynamic: '{d_out}'\n- Static: '{s_out}'", - ) + + def test_dynamic_lora_update_engine(self): + """ + Test dynamic LoRA updates in engine mode. + """ + test_cases = ALL_TESTS + self._run_dynamic_adapter_updates( + mode=LoRAUpdateTestSessionMode.ENGINE, + test_cases=test_cases, + ) + + def test_dynamic_lora_update_server(self): + """ + Test dynamic LoRA updates in server mode. + """ + # In CI, we only run the first test case to save time, as the engine test should be mostly sufficient for ensuring correctness. + test_cases = BASIC_TESTS if is_in_ci() else ALL_TESTS + + self._run_dynamic_adapter_updates( + mode=LoRAUpdateTestSessionMode.SERVER, test_cases=test_cases + ) if __name__ == "__main__": diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 971651501..059955f33 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -17,7 +17,7 @@ suites = { TestFile("models/lora/test_lora_backend.py", 99), TestFile("models/lora/test_multi_lora_backend.py", 60), TestFile("models/lora/test_lora_cuda_graph.py", 250), - TestFile("models/lora/test_lora_update.py", 400), + TestFile("models/lora/test_lora_update.py", 700), TestFile("models/test_embedding_models.py", 73), # TestFile("models/test_clip_models.py", 52), TestFile("models/test_encoder_embedding_models.py", 100),
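
The core of the new validation path is `LoRAMemoryPool.can_support`: an adapter is accepted only if its rank fits within the preallocated `max_lora_rank` and its (normalized) target modules are a subset of the modules the pool buffers were built for. Below is a minimal, self-contained sketch of that rule; `AdapterShape` and `pool_can_support` are illustrative stand-ins rather than the real `LoRAConfig`/`LoRAMemoryPool` classes, and module-name normalization (e.g., folding q/k/v into a stacked `qkv_proj` buffer) is omitted:

```python
from dataclasses import dataclass
from typing import Set, Tuple


@dataclass
class AdapterShape:
    r: int                    # LoRA rank of the adapter
    target_modules: Set[str]  # e.g., {"q_proj", "k_proj"}


def pool_can_support(
    adapter: AdapterShape,
    max_lora_rank: int,
    lora_weight_names: Tuple[Set[str], Set[str]],
) -> bool:
    # An adapter fits only if its rank is within the preallocated maximum...
    if adapter.r > max_lora_rank:
        return False
    # ...and every module it targets already has buffer slots for both the
    # LoRA A and LoRA B weights.
    return adapter.target_modules.issubset(
        lora_weight_names[0]
    ) and adapter.target_modules.issubset(lora_weight_names[1])


# Pool sized for rank <= 64 over the q/k/v/o projections:
names = {"q_proj", "k_proj", "v_proj", "o_proj"}
pool = (set(names), set(names))
print(pool_can_support(AdapterShape(4, {"q_proj", "k_proj"}), 64, pool))  # True
print(pool_can_support(AdapterShape(256, {"q_proj"}), 64, pool))          # False: rank exceeds pool
print(pool_can_support(AdapterShape(4, {"gate_proj"}), 64, pool))         # False: module not preallocated
```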
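
When `--max-lora-rank` and `--lora-target-modules` are omitted, `apply_lora_configs` falls back to inferring the pool shape from the adapters given in `--lora-paths`: the union of their target modules and the maximum of their ranks. A small sketch of that fallback, using hypothetical adapter shapes (again with an illustrative `AdapterShape` stand-in for `LoRAConfig`):

```python
from dataclasses import dataclass
from typing import Dict, Set


@dataclass
class AdapterShape:  # illustrative stand-in for sglang's LoRAConfig
    r: int
    target_modules: Set[str]


# Example shapes for the adapters passed via --lora-paths at startup.
configs: Dict[str, AdapterShape] = {
    "lora0": AdapterShape(4, {"q_proj", "k_proj", "v_proj", "o_proj"}),
    "lora1": AdapterShape(64, {"q_proj", "gate_proj"}),
}

# Union of target modules across all initial adapters...
target_modules: Set[str] = set()
for cfg in configs.values():
    target_modules |= cfg.target_modules

# ...and the maximum rank, defaulting to 0 when no adapters are given.
max_lora_rank = max((cfg.r for cfg in configs.values()), default=0)

print(sorted(target_modules))  # ['gate_proj', 'k_proj', 'o_proj', 'q_proj', 'v_proj']
print(max_lora_rank)           # 64
```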
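
From a client's perspective, an incompatible dynamic load surfaces as a failed `/load_lora_adapter` call (HTTP 400 in server mode, as asserted by the tests above). A hypothetical client-side sketch; the server URL and adapter path are placeholders:

```python
import requests

url = "http://127.0.0.1:30000"  # assumed server address
resp = requests.post(
    url + "/load_lora_adapter",
    json={
        "lora_name": "lora1",
        "lora_path": "philschmid/code-llama-3-1-8b-text-to-sql-lora",  # example adapter
    },
)
if resp.status_code == 400 and "incompatible" in resp.text:
    # The pool was sized for smaller adapters; restart the server with
    # --max-lora-rank / --lora-target-modules large enough for this adapter.
    print("Adapter rejected:", resp.text)
else:
    resp.raise_for_status()
    print("Loaded adapters:", resp.json()["loaded_adapters"])
```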