diff --git a/docs/backend/lora.ipynb b/docs/backend/lora.ipynb index 4967b9c75..733f75178 100644 --- a/docs/backend/lora.ipynb +++ b/docs/backend/lora.ipynb @@ -381,6 +381,78 @@ "print(f\"Output from lora1 (updated): \\n{response.json()[1]['text']}\\n\")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### LoRA GPU Pinning" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Another advanced option is to specify adapters as `pinned` during loading. When an adapter is pinned, it is permanently assigned to one of the available GPU pool slots (as configured by `--max-loras-per-batch`) and will not be evicted from GPU memory during runtime. Instead, it remains resident until it is explicitly unloaded.\n", + "\n", + "This can improve performance in scenarios where the same adapter is frequently used across requests, by avoiding repeated memory transfers and reinitialization overhead. However, since GPU pool slots are limited, pinning adapters reduces the flexibility of the system to dynamically load other adapters on demand. If too many adapters are pinned, it may lead to degraded performance, or in the most extreme case (`Number of pinned adapters == max-loras-per-batch`), halt all unpinned requests. Therefore, currently SGLang limits maximal number of pinned adapters to `max-loras-per-batch - 1` to prevent unexpected starvations. \n", + "\n", + "In the example below, we unload `lora1` and reload it as a `pinned` adapter:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "response = requests.post(\n", + " url + \"/unload_lora_adapter\",\n", + " json={\n", + " \"lora_name\": \"lora1\",\n", + " },\n", + ")\n", + "\n", + "response = requests.post(\n", + " url + \"/load_lora_adapter\",\n", + " json={\n", + " \"lora_name\": \"lora1\",\n", + " \"lora_path\": lora1,\n", + " \"pinned\": True, # Pin the adapter to GPU\n", + " },\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Verify that the result is identical as before:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "url = f\"http://127.0.0.1:{port}\"\n", + "json_data = {\n", + " \"text\": [\n", + " \"List 3 countries and their capitals.\",\n", + " \"List 3 countries and their capitals.\",\n", + " ],\n", + " \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n", + " # The first input uses lora0, and the second input uses lora1\n", + " \"lora_path\": [\"lora0\", \"lora1\"],\n", + "}\n", + "response = requests.post(\n", + " url + \"/generate\",\n", + " json=json_data,\n", + ")\n", + "print(f\"Output from lora0: \\n{response.json()[0]['text']}\\n\")\n", + "print(f\"Output from lora1 (pinned): \\n{response.json()[1]['text']}\\n\")" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 73f0f76d0..c09a128b5 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -492,12 +492,13 @@ class Engine(EngineBase): self.tokenizer_manager.get_weights_by_name(obj, None) ) - def load_lora_adapter(self, lora_name: str, lora_path: str): + def load_lora_adapter(self, lora_name: str, lora_path: str, pinned: bool = False): """Load a new LoRA adapter without re-launching the engine.""" obj = LoadLoRAAdapterReqInput( lora_name=lora_name, lora_path=lora_path, + pinned=pinned, ) loop = asyncio.get_event_loop() diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index e4fe1d0d1..e9fdd0a11 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -144,6 +144,7 @@ class LoRAManager: # keep metadata for displayed messages self.lora_refs[lora_ref.lora_id] = lora_ref + self.num_pinned_loras += int(lora_ref.pinned) except Exception as e: return self.create_lora_update_result( success=False, @@ -157,13 +158,22 @@ class LoRAManager: Validate if an adapter can be loaded into the current LoRA memory pool and generate error if it is incompatible. """ + # Check if the LoRA adapter shape is compatible with the current LoRA memory pool configuration. memory_pool = getattr(self, "memory_pool", None) incompatible = memory_pool and not memory_pool.can_support(lora_config) if incompatible: raise ValueError( - f"LoRA adapter {lora_ref.lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. " - "Please ensure that the LoRA adapter's rank is within the configured `--max_lora_rank` and that the target modules are " - "included in `--enable_lora_modules`." + f"LoRA adapter {lora_ref.lora_name} with rank {lora_config.r} is incompatible with the current " + "LoRA memory pool configuration. Please ensure that the LoRA adapter's rank is within the configured " + "`--max-lora-rank` and that the target modules are included in `--lora-target-modules`." + ) + + # Ensure pinned LoRA adapters does not exceed maximal limit or cause starvation. + if lora_ref.pinned and self.num_pinned_loras >= self.max_loras_per_batch - 1: + raise ValueError( + f"Failed to load LoRA adapter {lora_ref.lora_name} as a pinned adapter. It is not allowed to pin all slots " + "in the LoRA memory pool to avoid starvation for unpinned adapters and base models. Please increase your " + "`--max-loras-per-batch` or load it as unpinned LoRA adapters." ) def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult: @@ -172,15 +182,17 @@ class LoRAManager: delete the corresponding LoRA modules. """ - adapter = self.configs.get(lora_ref.lora_id, None) + adapter = self.configs.get(lora_ref.lora_id) + lora_ref = self.lora_refs.get(lora_ref.lora_id) assert ( - adapter is not None + adapter is not None and lora_ref is not None ), f"LoRA adapter with ID {lora_ref.lora_id} is not loaded. This should have been verified before request is sent to the backend." try: del self.configs[lora_ref.lora_id] del self.loras[lora_ref.lora_id] del self.lora_refs[lora_ref.lora_id] + self.num_pinned_loras -= int(lora_ref.pinned) except Exception as e: return self.create_lora_update_result( success=False, @@ -189,11 +201,49 @@ class LoRAManager: return self.create_lora_update_result(success=True) + def validate_lora_batch(self, lora_ids: set[str]) -> bool: + """ + Validate if the LoRA IDs in the batch can be loaded into the current LoRA memory pool. + """ + if len(lora_ids) > self.max_loras_per_batch: + return False + + # skip pinned LoRA check if no pinned LoRA adapters are loaded. + if self.num_pinned_loras == 0: + return True + + # counting the number of pinned LoRA adapters in the batch. + pinned_loras_in_batch = 0 + for lora_id in lora_ids: + if lora_id is not None: + lora_ref = self.lora_refs.get(lora_id) + assert ( + lora_ref is not None + ), f"LoRA ID {lora_id} not found in lora_refs." + pinned_loras_in_batch += int(lora_ref.pinned) + + assert pinned_loras_in_batch <= self.num_pinned_loras, ( + f"Number of pinned LoRA adapters in the batch ({pinned_loras_in_batch}) exceeds the total number of pinned adapters " + f"({self.num_pinned_loras}). This indicates a bug in the LoRA loading logic." + ) + + required_slots = len(lora_ids) - pinned_loras_in_batch + mem_pool_vacancy = self.memory_pool.max_loras_per_batch - self.num_pinned_loras + + return required_slots <= mem_pool_vacancy + def prepare_lora_batch(self, forward_batch: ForwardBatch): + # Load active loras into lora memory pool cur_uids = set(forward_batch.lora_ids) + assert len(cur_uids) <= self.max_loras_per_batch - self.memory_pool.prepare_lora_batch(cur_uids, self.loras, self.lora_modules) + self.memory_pool.prepare_lora_batch( + cur_uids=cur_uids, + lora_adapters=self.loras, + lora_modules=self.lora_modules, + lora_refs=self.lora_refs.copy(), # copy snapshot of current lora_refs to avoid mutation during the batch preparation. + ) # set up batch info shared by all lora modules bs = forward_batch.batch_size @@ -366,6 +416,9 @@ class LoRAManager: # Mapping from LoRA ID to LoRARef object. self.lora_refs: Dict[str, LoRARef] = {} + # Count of pinned LoRA adapters. + self.num_pinned_loras: int = 0 + if lora_paths: for lora_ref in lora_paths.values(): result = self.load_lora_adapter(lora_ref) @@ -399,7 +452,7 @@ class LoRAManager: self.max_lora_rank = max_lora_rank else: self.max_lora_rank = max( - [x.hf_config["r"] for x in self.configs.values()], + [x.r for x in self.configs.values()], default=0, ) diff --git a/python/sglang/srt/lora/lora_registry.py b/python/sglang/srt/lora/lora_registry.py index bb2fc5659..082f9a2d3 100644 --- a/python/sglang/srt/lora/lora_registry.py +++ b/python/sglang/srt/lora/lora_registry.py @@ -28,14 +28,15 @@ class LoRARef: """ Reference record for a LoRA model. - This object guarantees a unique ``lora_id`` and may include ``lora_name`` and ``lora_path``. The ID - eliminates conflicts from reused LoRA names or paths and can be used to generate deterministic cache + This object guarantees a unique ``lora_id`` and may include ``lora_name``, ``lora_path``, and ``pinned``. + The ID eliminates conflicts from reused LoRA names or paths and can be used to generate deterministic cache keys (e.g., radix cache). """ lora_id: str = field(default_factory=lambda: uuid4().hex) lora_name: Optional[str] = None lora_path: Optional[str] = None + pinned: Optional[bool] = None def __post_init__(self): if self.lora_id is None: diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py index ae856246d..cc00c7212 100644 --- a/python/sglang/srt/lora/mem_pool.py +++ b/python/sglang/srt/lora/mem_pool.py @@ -1,3 +1,4 @@ +import logging from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union import torch @@ -7,6 +8,7 @@ from sglang.srt.hf_transformers_utils import AutoConfig from sglang.srt.lora.layers import BaseLayerWithLoRA from sglang.srt.lora.lora import LoRAAdapter from sglang.srt.lora.lora_config import LoRAConfig +from sglang.srt.lora.lora_registry import LoRARef from sglang.srt.lora.utils import ( ROW_PARALLELISM_LINEAR_LORA_NAMES, LoRAType, @@ -16,6 +18,28 @@ from sglang.srt.lora.utils import ( get_weight_name, ) +logger = logging.getLogger(__name__) + + +class EmptySlot: + """ + Singleton class to represent an empty slot in the memory pool. + This is used to improve readability by not using special str as a placeholder. + """ + + __slots__ = () + + def __repr__(self): + return "|EMPTY|" + + def __new__(cls): + if not hasattr(cls, "_instance"): + cls._instance = super().__new__(cls) + return cls._instance + + +EMPTY_SLOT = EmptySlot() + class LoRAMemoryPool: """Class for memory pool management of lora modules""" @@ -54,9 +78,11 @@ class LoRAMemoryPool: self.uid_to_buffer_id: Dict[Optional[str], int] = {} # Buffer idx -> lora uid in memory pool - # All uids are initialized as empty strings for empty buffer slots + # All uids are initialized as `EmptySlot` for empty buffer slots # Here we don't initialize to None since None is a valid uid - self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch + self.buffer_id_to_uid: List[Union[str, None, EmptySlot]] = [ + EMPTY_SLOT + ] * self.max_loras_per_batch self.init_buffers(base_model) @@ -154,17 +180,29 @@ class LoRAMemoryPool: cur_uids: Set[Optional[str]], lora_adapters: Dict[str, LoRAAdapter], lora_modules: List[Dict[str, BaseLayerWithLoRA]], + lora_refs: Dict[str, LoRARef], ): def get_available_buffer_slot(): for buffer_id in range(self.max_loras_per_batch): # Prioritize empty slots - if self.buffer_id_to_uid[buffer_id] == "": + if self.buffer_id_to_uid[buffer_id] == EMPTY_SLOT: return buffer_id for buffer_id in range(self.max_loras_per_batch): + uid = self.buffer_id_to_uid[buffer_id] + # Evict unneeded lora - if self.buffer_id_to_uid[buffer_id] not in cur_uids: - self.uid_to_buffer_id.pop(self.buffer_id_to_uid[buffer_id]) + if uid not in cur_uids: + # Skip pinned LoRAs + # TODO (lifuhuang): we might consider supporting pinning base model (uid == None) in the future. + if uid is not None: + lora_ref = lora_refs.get(uid) + if lora_ref is not None and lora_ref.pinned: + continue + + self.uid_to_buffer_id.pop(uid) + logger.debug(f"Evicting LoRA {uid} from buffer slot {buffer_id}.") + self.buffer_id_to_uid[buffer_id] = EMPTY_SLOT return buffer_id raise ValueError( diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 1a0cbeadb..546128212 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -1082,6 +1082,8 @@ class LoadLoRAAdapterReqInput: lora_name: str # The path of loading. lora_path: str + # Whether to pin the LoRA adapter in memory. + pinned: bool = False # The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`. lora_id: Optional[str] = None @@ -1090,6 +1092,7 @@ class LoadLoRAAdapterReqInput: lora_id=self.lora_id, lora_name=self.lora_name, lora_path=self.lora_path, + pinned=self.pinned, ) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 6f6dee027..6fd6ffe64 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -1538,14 +1538,11 @@ class Scheduler( # Get requests from the waiting queue to a new prefill batch for req in self.waiting_queue: - if ( - self.enable_lora - and len( - lora_set - | set([req.lora_id for req in adder.can_run_list]) - | set([req.lora_id]) - ) - > self.max_loras_per_batch + + if self.enable_lora and not self.tp_worker.can_run_lora_batch( + lora_set + | set([req.lora_id for req in adder.can_run_list]) + | set([req.lora_id]) ): self.running_batch.batch_is_full = True break diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 498f0daef..50ac39f88 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -1129,6 +1129,7 @@ class TokenizerManager: new_adapter = LoRARef( lora_name=obj.lora_name, lora_path=obj.lora_path, + pinned=obj.pinned, ) # Trigger the actual loading operation at the backend processes. @@ -1186,7 +1187,7 @@ class TokenizerManager: return result except ValueError as e: - return UnloadLoRAAdapterReqOutput(success=False, rror_message=str(e)) + return UnloadLoRAAdapterReqOutput(success=False, error_message=str(e)) async def get_weights_by_name( self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 0b2900b37..77dac1ea6 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -311,3 +311,6 @@ class TpModelWorker: def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput): result = self.model_runner.unload_lora_adapter(recv_req.to_ref()) return result + + def can_run_lora_batch(self, lora_ids: list[str]) -> bool: + return self.model_runner.lora_manager.validate_lora_batch(lora_ids) diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index 76498514d..674a94195 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -288,6 +288,9 @@ class TpModelWorkerClient: def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput): return self.worker.unload_lora_adapter(recv_req) + def can_run_lora_batch(self, lora_ids: list[str]) -> bool: + return self.worker.can_run_lora_batch(lora_ids) + def __delete__(self): self.input_queue.put((None, None)) self.copy_queue.put((None, None, None)) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 605214a98..8f8774f2a 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -2067,21 +2067,23 @@ class ServerArgs: if self.enable_lora: # Normalize lora_paths to a dictionary if it is a list. + # TODO (lifuhuang): support specifying pinned adapters in server_args. if isinstance(self.lora_paths, list): lora_paths = self.lora_paths self.lora_paths = {} for lora_path in lora_paths: if "=" in lora_path: name, path = lora_path.split("=", 1) - self.lora_paths[name] = LoRARef(lora_name=name, lora_path=path) + self.lora_paths[name] = LoRARef( + lora_name=name, lora_path=path, pinned=False + ) else: self.lora_paths[lora_path] = LoRARef( - lora_name=lora_path, - lora_path=lora_path, + lora_name=lora_path, lora_path=lora_path, pinned=False ) elif isinstance(self.lora_paths, dict): self.lora_paths = { - k: LoRARef(lora_name=k, lora_path=v) + k: LoRARef(lora_name=k, lora_path=v, pinned=False) for k, v in self.lora_paths.items() } elif self.lora_paths is None: diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index ee49584a0..ba1519951 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -568,8 +568,8 @@ class SRTRunner: else: self.tokenizer = None - def load_lora_adapter(self, lora_name: str, lora_path: str): - return self.engine.load_lora_adapter(lora_name, lora_path) + def load_lora_adapter(self, lora_name: str, lora_path: str, pinned: bool = False): + return self.engine.load_lora_adapter(lora_name, lora_path, pinned) def unload_lora_adapter(self, lora_name: str): return self.engine.unload_lora_adapter(lora_name) diff --git a/test/srt/models/lora/test_lora_update.py b/test/srt/models/lora/test_lora_update.py index ef5a4c845..5024a9c5d 100644 --- a/test/srt/models/lora/test_lora_update.py +++ b/test/srt/models/lora/test_lora_update.py @@ -231,88 +231,6 @@ BASIC_TESTS = [ ), ], ), - TestCase( - description="dynamic lora update with evictions", - base="meta-llama/Llama-3.1-8B-Instruct", - max_loras_per_batch=1, - all_adapters=[ - "philschmid/code-llama-3-1-8b-text-to-sql-lora", - "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", - "pbevan11/llama-3.1-8b-ocr-correction", - ], - initial_adapters=["philschmid/code-llama-3-1-8b-text-to-sql-lora"], - op_sequence=[ - Operation( - type=OperationType.FORWARD, - data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"), - ), - Operation( - type=OperationType.FORWARD, - data=create_batch_data( - "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16" - ), - expected_error="not loaded", - ), - Operation( - type=OperationType.FORWARD, - data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"), - expected_error="not loaded", - ), - Operation( - type=OperationType.LOAD, - data="pbevan11/llama-3.1-8b-ocr-correction", - ), - Operation( - type=OperationType.UNLOAD, - data="philschmid/code-llama-3-1-8b-text-to-sql-lora", - ), - Operation( - type=OperationType.FORWARD, - data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"), - expected_error="not loaded", - ), - Operation( - type=OperationType.FORWARD, - data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"), - ), - Operation( - type=OperationType.LOAD, - data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", - ), - Operation( - type=OperationType.LOAD, - data="philschmid/code-llama-3-1-8b-text-to-sql-lora", - ), - Operation( - type=OperationType.FORWARD, - data=create_batch_data( - "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16" - ), - ), - Operation( - type=OperationType.FORWARD, - data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"), - ), - Operation( - type=OperationType.FORWARD, - data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"), - ), - Operation( - type=OperationType.FORWARD, - data=create_batch_data( - "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16" - ), - ), - Operation( - type=OperationType.FORWARD, - data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"), - ), - Operation( - type=OperationType.FORWARD, - data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"), - ), - ], - ), ] TARGET_MODULE_TESTS = [ TestCase( @@ -593,9 +511,135 @@ MAX_LOADED_LORAS_TESTS = [ ], ), ] +EVICTION_TESTS = [ + TestCase( + description="dynamic lora update with evictions", + base="meta-llama/Llama-3.1-8B-Instruct", + max_loras_per_batch=2, + all_adapters=[ + "lora1=philschmid/code-llama-3-1-8b-text-to-sql-lora", + "lora2=Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", + "lora3=pbevan11/llama-3.1-8b-ocr-correction", + ], + enable_lora=True, + max_lora_rank=256, + lora_target_modules=["all"], + op_sequence=[ + Operation( + type=OperationType.LOAD, + data={ + "lora_name": "lora1", + "lora_path": "philschmid/code-llama-3-1-8b-text-to-sql-lora", + "pinned": True, + }, + ), + Operation( + type=OperationType.LOAD, + data={ + "lora_name": "lora2", + "lora_path": "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", + "pinned": True, + }, + expected_error="starvation", + ), + Operation( + type=OperationType.LOAD, + data={ + "lora_name": "lora2", + "lora_path": "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", + "pinned": False, + }, + ), + Operation( + type=OperationType.LOAD, + data={ + "lora_name": "lora3", + "lora_path": "pbevan11/llama-3.1-8b-ocr-correction", + "pinned": False, + }, + ), + Operation( + type=OperationType.UNLOAD, + data="lora1", + ), + Operation( + type=OperationType.UNLOAD, + data="lora3", + ), + Operation( + type=OperationType.LOAD, + data={ + "lora_name": "lora3", + "lora_path": "pbevan11/llama-3.1-8b-ocr-correction", + "pinned": True, + }, + ), + Operation( + type=OperationType.LOAD, + data={ + "lora_name": "lora1", + "lora_path": "philschmid/code-llama-3-1-8b-text-to-sql-lora", + "pinned": True, + }, + expected_error="starvation", + ), + Operation( + type=OperationType.LOAD, + data={ + "lora_name": "lora1", + "lora_path": "philschmid/code-llama-3-1-8b-text-to-sql-lora", + "pinned": False, + }, + ), + # pinned: lora3 + # unpinned: lora1, lora2 + Operation( + type=OperationType.FORWARD, + data=create_batch_data( + [ + "lora1", + "lora2", + ] + ), + ), + Operation( + type=OperationType.FORWARD, + data=create_batch_data( + [ + "lora1", + "lora3", + ] + ), + ), + Operation( + type=OperationType.FORWARD, + data=create_batch_data( + [ + "lora1", + "lora2", + ] + ), + ), + Operation( + type=OperationType.FORWARD, + data=create_batch_data( + [ + "lora1", + "lora2", + None, + ] + ), + ), + ], + ), +] ALL_TESTS = ( - BASIC_TESTS + TARGET_MODULE_TESTS + MAX_LORA_RANK_TESTS + MAX_LOADED_LORAS_TESTS + BASIC_TESTS + + TARGET_MODULE_TESTS + + MAX_LORA_RANK_TESTS + + MAX_LOADED_LORAS_TESTS + + EVICTION_TESTS ) @@ -714,6 +758,7 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase): lora_name: str, lora_path: Optional[str] = None, expected_error: Optional[str] = None, + pinned: bool = False, ): """ Load a LoRA adapter by name and path. @@ -724,17 +769,31 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase): response = self.handle.load_lora_adapter( lora_name=lora_name, lora_path=lora_path, + pinned=pinned, ) if expected_error: - self.testcase.assertFalse(response.success) - self.testcase.assertIn(expected_error, response.error_message) + self.testcase.assertFalse( + response.success, f"Expected failure for {lora_name}, but got success." + ) + self.testcase.assertIn( + expected_error, + response.error_message, + f"Expected error message to contain '{expected_error}', but got '{response.error_message}'", + ) print(f"Received error as expected: {response.error_message}") else: self.expected_adapters.add(lora_name) - self.testcase.assertTrue(response.success) + self.testcase.assertTrue( + response.success, + f"Failed to load LoRA adapter {lora_name}: {response.error_message}", + ) loaded_adapters = set(response.loaded_adapters) print(f"loaded_adapters: {loaded_adapters}") - self.testcase.assertEqual(loaded_adapters, self.expected_adapters) + self.testcase.assertEqual( + loaded_adapters, + self.expected_adapters, + f"Expected loaded adapters to be {self.expected_adapters}, but got {loaded_adapters}", + ) def unload_lora_adapter(self, lora_name: str): """ @@ -745,11 +804,18 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase): response = self.handle.unload_lora_adapter( lora_name=lora_name, ) - self.testcase.assertTrue(response.success) + self.testcase.assertTrue( + response.success, + f"Failed to unload LoRA adapter {lora_name}: {response.error_message}", + ) loaded_adapters = set(response.loaded_adapters) print(f"loaded_adapters: {loaded_adapters}") - self.testcase.assertEqual(loaded_adapters, self.expected_adapters) + self.testcase.assertEqual( + loaded_adapters, + self.expected_adapters, + f"Expected loaded adapters to be {self.expected_adapters}, but got {loaded_adapters}", + ) def forward( self, @@ -770,13 +836,21 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase): except ValueError as e: if expected_error: error_message = str(e) - self.testcase.assertIn(expected_error, error_message) + self.testcase.assertIn( + expected_error, + error_message, + f"Expected error message to contain '{expected_error}', but got '{error_message}'", + ) print(f"Received error as expected: {error_message}") return error_message raise e - self.testcase.assertEqual(len(response.output_strs), len(prompts)) + self.testcase.assertEqual( + len(response.output_strs), + len(prompts), + f"Expected {len(prompts)} outputs, but got {len(response.output_strs)}", + ) output = response.output_strs print(f"output_strs: {output}") @@ -837,6 +911,7 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase): lora_name: str, lora_path: Optional[str] = None, expected_error: Optional[str] = None, + pinned: bool = False, ): """ Load a LoRA adapter by name and path. @@ -846,18 +921,32 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase): response = requests.post( DEFAULT_URL_FOR_TEST + "/load_lora_adapter", - json={"lora_name": lora_name, "lora_path": lora_path}, + json={"lora_name": lora_name, "lora_path": lora_path, "pinned": pinned}, ) if expected_error: - self.testcase.assertEqual(response.status_code, 400) - self.testcase.assertIn(expected_error, response.text) + self.testcase.assertEqual( + response.status_code, + 400, + f"Expected error for {lora_name}, but got success.", + ) + self.testcase.assertIn( + expected_error, + response.text, + f"Expected error message to contain '{expected_error}', but got '{response.text}'", + ) print(f"Received error as expected: {response.text}") else: self.expected_adapters.add(lora_name) - self.testcase.assertTrue(response.ok) + self.testcase.assertTrue( + response.ok, f"Failed to load LoRA adapter {lora_name}: {response.text}" + ) loaded_adapters = set(response.json()["loaded_adapters"]) print(f"loaded_adapters: {loaded_adapters}") - self.testcase.assertEqual(loaded_adapters, self.expected_adapters) + self.testcase.assertEqual( + loaded_adapters, + self.expected_adapters, + f"Expected loaded adapters to be {self.expected_adapters}, but got {loaded_adapters}", + ) def unload_lora_adapter(self, lora_name: str): """ @@ -869,11 +958,17 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase): DEFAULT_URL_FOR_TEST + "/unload_lora_adapter", json={"lora_name": lora_name}, ) - self.testcase.assertTrue(response.ok) + self.testcase.assertTrue( + response.ok, f"Failed to unload LoRA adapter {lora_name}: {response.text}" + ) loaded_adapters = set(response.json()["loaded_adapters"]) print(f"loaded_adapters: {loaded_adapters}") - self.testcase.assertEqual(loaded_adapters, self.expected_adapters) + self.testcase.assertEqual( + loaded_adapters, + self.expected_adapters, + f"Expected loaded adapters to be {self.expected_adapters}, but got {loaded_adapters}", + ) def forward( self, @@ -898,15 +993,29 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase): }, ) if expected_error: - self.testcase.assertEqual(response.status_code, 400) - self.testcase.assertIn(expected_error, response.text) + self.testcase.assertEqual( + response.status_code, + 400, + f"Expected error for forward pass, but got success: {response.text}", + ) + self.testcase.assertIn( + expected_error, + response.text, + f"Expected error message to contain '{expected_error}', but got '{response.text}'", + ) output = response.text print(f"Received error as expected: {response.text}") return output else: - self.testcase.assertTrue(response.ok) + self.testcase.assertTrue( + response.ok, f"Failed to generate text: {response.text}" + ) output = [r["text"] for r in response.json()] - self.testcase.assertEqual(len(output), len(prompts)) + self.testcase.assertEqual( + len(output), + len(prompts), + f"Expected {len(prompts)} outputs, but got {len(output)}", + ) print(f"output_strs: {output}") return output @@ -974,10 +1083,18 @@ class TestLoRADynamicUpdate(CustomTestCase): f"Running operation: {op_type} --- data: {data} --- mode: {mode} ---" ) if op_type == OperationType.LOAD: + if isinstance(data, str): + adapter_info = { + "lora_name": data, + "lora_path": data, + "pinned": False, + } + else: + adapter_info = data + result = session.load_lora_adapter( - lora_name=data, - lora_path=data, expected_error=expected_error, + **adapter_info, ) elif op_type == OperationType.UNLOAD: result = session.unload_lora_adapter(