diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index cf100ecce..8e1eb758c 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -240,6 +240,12 @@ class TokenizerManager: revision=server_args.revision, ) + # Initialize loaded loRA adapters with the initial lora paths in the server_args. + # This list will be updated when new LoRA adapters are loaded or unloaded dynamically. + self.loaded_lora_adapters: Dict[str, str] = dict( + self.server_args.lora_paths or {} + ) + # Store states self.no_create_loop = False self.rid_to_state: Dict[str, ReqState] = {} @@ -549,6 +555,8 @@ class TokenizerManager: "The server is not configured to enable custom logit processor. " "Please set `--enable-custom-logits-processor` to enable this feature." ) + if self.server_args.lora_paths and obj.lora_path: + self._validate_lora_adapters(obj) def _validate_input_ids_in_vocab( self, input_ids: List[int], vocab_size: int @@ -662,6 +670,21 @@ class TokenizerManager: "Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`." ) + def _validate_lora_adapters(self, obj: GenerateReqInput): + """Validate that the requested LoRA adapters are loaded.""" + requested_adapters = ( + set(obj.lora_path) if isinstance(obj.lora_path, list) else {obj.lora_path} + ) + loaded_adapters = ( + self.loaded_lora_adapters.keys() if self.loaded_lora_adapters else set() + ) + unloaded_adapters = requested_adapters - loaded_adapters + if unloaded_adapters: + raise ValueError( + f"The following requested LoRA adapters are not loaded: {unloaded_adapters}\n" + f"Loaded adapters: {loaded_adapters}." + ) + def _send_one_request( self, obj: Union[GenerateReqInput, EmbeddingReqInput], @@ -988,6 +1011,7 @@ class TokenizerManager: async with self.model_update_lock.writer_lock: result = (await self.update_lora_adapter_communicator(obj))[0] + self.loaded_lora_adapters = result.loaded_adapters return result async def unload_lora_adapter( @@ -1009,6 +1033,7 @@ class TokenizerManager: async with self.model_update_lock.writer_lock: result = (await self.update_lora_adapter_communicator(obj))[0] + self.loaded_lora_adapters = result.loaded_adapters return result async def get_weights_by_name( diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 050099a03..73a5845a0 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -20,7 +20,7 @@ import logging import os import random import tempfile -from typing import List, Literal, Optional +from typing import List, Literal, Optional, Union from sglang.srt.hf_transformers_utils import check_gguf_file, get_config from sglang.srt.reasoning_parser import ReasoningParser @@ -131,7 +131,7 @@ class ServerArgs: preferred_sampling_params: Optional[str] = None # LoRA - lora_paths: Optional[List[str]] = None + lora_paths: Optional[Union[dict[str, str], List[str]]] = None max_loras_per_batch: int = 8 lora_backend: str = "triton" diff --git a/test/srt/models/lora/test_lora_update.py b/test/srt/models/lora/test_lora_update.py index 587789cf1..4a85758b5 100644 --- a/test/srt/models/lora/test_lora_update.py +++ b/test/srt/models/lora/test_lora_update.py @@ -16,7 +16,7 @@ import multiprocessing as mp import unittest from dataclasses import dataclass from enum import Enum -from typing import List, Optional, Union +from typing import Any, List, Optional, Union import requests import torch @@ -42,14 +42,16 @@ PROMPTS = [ class OperationType(Enum): LOAD = "load" UNLOAD = "unload" - NOOP = "noop" FORWARD = "forward" + EXPECT_ERROR = "expect_error" @dataclass class Operation: + # Operation type, can be LOAD, UNLOAD, FORWARD, or EXPECT_ERROR type: OperationType - data: Optional[str] + # Data associated with the operation. Exact type varies depending on the operation + data: Optional[Any] @dataclass @@ -62,7 +64,7 @@ class TestCase: max_new_tokens: int = 32 -def create_batch_data(adapters: Union[str, list]) -> dict: +def create_batch_data(adapters: Union[str, list]) -> List[tuple[str, str]]: if not isinstance(adapters, list): adapters = [adapters] return [(prompt, adapter) for prompt in PROMPTS for adapter in adapters] @@ -80,6 +82,26 @@ TEST_CASES = [ ], initial_adapters=["philschmid/code-llama-3-1-8b-text-to-sql-lora"], op_sequence=[ + Operation( + type=OperationType.FORWARD, + data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"), + ), + Operation( + type=OperationType.EXPECT_ERROR, + data=( + create_batch_data( + "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16" + ), + "not loaded", + ), + ), + Operation( + type=OperationType.EXPECT_ERROR, + data=( + create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"), + "not loaded", + ), + ), Operation( type=OperationType.LOAD, data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", @@ -102,6 +124,13 @@ TEST_CASES = [ type=OperationType.UNLOAD, data="philschmid/code-llama-3-1-8b-text-to-sql-lora", ), + Operation( + type=OperationType.EXPECT_ERROR, + data=( + create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"), + "not loaded", + ), + ), Operation( type=OperationType.FORWARD, data=create_batch_data( @@ -115,6 +144,15 @@ TEST_CASES = [ type=OperationType.UNLOAD, data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", ), + Operation( + type=OperationType.EXPECT_ERROR, + data=( + create_batch_data( + "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16" + ), + "not loaded", + ), + ), Operation( type=OperationType.FORWARD, data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"), @@ -149,6 +187,22 @@ TEST_CASES = [ type=OperationType.FORWARD, data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"), ), + Operation( + type=OperationType.EXPECT_ERROR, + data=( + create_batch_data( + "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16" + ), + "not loaded", + ), + ), + Operation( + type=OperationType.EXPECT_ERROR, + data=( + create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"), + "not loaded", + ), + ), Operation( type=OperationType.LOAD, data="pbevan11/llama-3.1-8b-ocr-correction", @@ -157,6 +211,13 @@ TEST_CASES = [ type=OperationType.UNLOAD, data="philschmid/code-llama-3-1-8b-text-to-sql-lora", ), + Operation( + type=OperationType.EXPECT_ERROR, + data=( + create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"), + "not loaded", + ), + ), Operation( type=OperationType.FORWARD, data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"), @@ -332,19 +393,31 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase): prompts: List[str], lora_paths: List[str], max_new_tokens: int = 32, + expected_error: str = None, ): """ Perform a batch forward pass with the current set of loaded LoRA adapters. """ - response = self.handle.batch_forward( - prompts=prompts, - lora_paths=lora_paths, - max_new_tokens=max_new_tokens, - ) - output_strs = response.output_strs + try: + response = self.handle.batch_forward( + prompts=prompts, + lora_paths=lora_paths, + max_new_tokens=max_new_tokens, + ) + except ValueError as e: + if expected_error: + error_message = str(e) + self.testcase.assertIn(expected_error, error_message) + print(f"Received error as expected: {error_message}") + return error_message - print(f"output_strs: {output_strs}") - return output_strs + raise e + + self.testcase.assertEqual(len(response.output_strs), len(prompts)) + output = response.output_strs + print(f"output_strs: {output}") + + return output class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase): @@ -426,6 +499,7 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase): prompts: List[str], lora_paths: List[str], max_new_tokens: int = 32, + expected_error: str = None, ): """ Perform a batch forward pass with the current set of loaded LoRA adapters. @@ -442,11 +516,18 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase): }, }, ) - self.testcase.assertTrue(response.ok) - output_strs = [r["text"] for r in response.json()] - - print(f"output_strs: {output_strs}") - return output_strs + if expected_error: + self.testcase.assertEqual(response.status_code, 400) + self.testcase.assertIn(expected_error, response.text) + output = response.text + print(f"Received error as expected: {response.text}") + return output + else: + self.testcase.assertTrue(response.ok) + output = [r["text"] for r in response.json()] + self.testcase.assertEqual(len(output), len(prompts)) + print(f"output_strs: {output}") + return output # Factory function to create the appropriate LoRA test session based on mode @@ -535,14 +616,23 @@ class TestLoRADynamicUpdate(CustomTestCase): max_new_tokens=max_new_tokens, ) forward_outputs.append(result) + elif op_type == OperationType.EXPECT_ERROR: + input_data, expected_error = data + prompts, adapters = zip(*input_data) + result = session.forward( + prompts=list(prompts), + lora_paths=list(adapters), + max_new_tokens=max_new_tokens, + expected_error=expected_error, + ) return forward_outputs def test_dynamic_adapter_updates(self): for case_idx, test_case in enumerate(TEST_CASES, start=1): for mode in [ - LoRAUpdateTestSessionMode.SERVER, LoRAUpdateTestSessionMode.ENGINE, + LoRAUpdateTestSessionMode.SERVER, ]: print("=" * 100) print(f"Starting test case {case_idx} in {mode.value} mode.")