From 852c0578fd3377a980cf2296751424a56052155e Mon Sep 17 00:00:00 2001 From: Neelabh Sinha Date: Tue, 21 Oct 2025 00:44:33 -0700 Subject: [PATCH] [FEATURE] Add OpenAI-Compatible LoRA Adapter Selection (#11570) --- docs/advanced_features/lora.ipynb | 20 ++ docs/basic_usage/openai_api_completions.ipynb | 44 +++ examples/runtime/lora.py | 98 ++++-- .../sglang/srt/entrypoints/openai/protocol.py | 12 +- .../srt/entrypoints/openai/serving_base.py | 48 ++- .../srt/entrypoints/openai/serving_chat.py | 13 +- .../entrypoints/openai/serving_completions.py | 13 +- test/srt/lora/test_lora_openai_api.py | 327 ++++++++++++++++++ .../features/test_lora_openai_compatible.py | 278 +++++++++++++++ test/srt/run_suite.py | 2 + 10 files changed, 815 insertions(+), 40 deletions(-) create mode 100644 test/srt/lora/test_lora_openai_api.py create mode 100644 test/srt/openai_server/features/test_lora_openai_compatible.py diff --git a/docs/advanced_features/lora.ipynb b/docs/advanced_features/lora.ipynb index ff2c32fd6..da25e9882 100644 --- a/docs/advanced_features/lora.ipynb +++ b/docs/advanced_features/lora.ipynb @@ -59,6 +59,17 @@ "### Serving Single Adaptor" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Note:** SGLang supports LoRA adapters through two APIs:\n", + "\n", + "1. **OpenAI-Compatible API** (`/v1/chat/completions`, `/v1/completions`): Use the `model:adapter-name` syntax. See [OpenAI API with LoRA](../basic_usage/openai_api_completions.ipynb#Using-LoRA-Adapters) for examples.\n", + "\n", + "2. **Native API** (`/generate`): Pass `lora_path` in the request body (shown below)." + ] + }, { "cell_type": "code", "execution_count": null, @@ -379,6 +390,15 @@ "print(f\"Output from lora1 (updated): \\n{response.json()[1]['text']}\\n\")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### OpenAI-compatible API usage\n", + "\n", + "You can use LoRA adapters via the OpenAI-compatible APIs by specifying the adapter in the `model` field using the `base-model:adapter-name` syntax (for example, `qwen/qwen2.5-0.5b-instruct:adapter_a`). For more details and examples, see the “Using LoRA Adapters” section in the OpenAI API documentation: [openai_api_completions.ipynb](../basic_usage/openai_api_completions.ipynb).\n" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/docs/basic_usage/openai_api_completions.ipynb b/docs/basic_usage/openai_api_completions.ipynb index 6b967709f..13b04a5d3 100644 --- a/docs/basic_usage/openai_api_completions.ipynb +++ b/docs/basic_usage/openai_api_completions.ipynb @@ -361,6 +361,50 @@ "For OpenAI compatible structured outputs API, refer to [Structured Outputs](../advanced_features/structured_outputs.ipynb) for more details.\n" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Using LoRA Adapters\n", + "\n", + "SGLang supports LoRA (Low-Rank Adaptation) adapters with OpenAI-compatible APIs. 
You can specify which adapter to use directly in the `model` parameter using the `base-model:adapter-name` syntax.\n", + "\n", + "**Server Setup:**\n", + "```bash\n", + "python -m sglang.launch_server \\\n", + " --model-path qwen/qwen2.5-0.5b-instruct \\\n", + " --enable-lora \\\n", + " --lora-paths adapter_a=/path/to/adapter_a adapter_b=/path/to/adapter_b\n", + "```\n", + "\n", + "For more details on LoRA serving configuration, see the [LoRA documentation](../advanced_features/lora.ipynb).\n", + "\n", + "**API Call:**\n", + "\n", + "(Recommended) Use the `model:adapter` syntax to specify which adapter to use:\n", + "```python\n", + "response = client.chat.completions.create(\n", + " model=\"qwen/qwen2.5-0.5b-instruct:adapter_a\", # ← base-model:adapter-name\n", + " messages=[{\"role\": \"user\", \"content\": \"Convert to SQL: show all users\"}],\n", + " max_tokens=50,\n", + ")\n", + "```\n", + "\n", + "**Backward Compatible: Using `extra_body`**\n", + "\n", + "The old `extra_body` method is still supported for backward compatibility:\n", + "```python\n", + "# Backward compatible method\n", + "response = client.chat.completions.create(\n", + " model=\"qwen/qwen2.5-0.5b-instruct\",\n", + " messages=[{\"role\": \"user\", \"content\": \"Convert to SQL: show all users\"}],\n", + " extra_body={\"lora_path\": \"adapter_a\"}, # ← old method\n", + " max_tokens=50,\n", + ")\n", + "```\n", + "**Note:** When both `model:adapter` and `extra_body[\"lora_path\"]` are specified, the `model:adapter` syntax takes precedence." + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/examples/runtime/lora.py b/examples/runtime/lora.py index bf3fc2d9e..181dc2315 100644 --- a/examples/runtime/lora.py +++ b/examples/runtime/lora.py @@ -1,37 +1,67 @@ -# launch server -# python -m sglang.launch_server --model mistralai/Mistral-7B-Instruct-v0.3 --lora-paths /home/ying/test_lora lora1=/home/ying/test_lora_1 lora2=/home/ying/test_lora_2 --disable-radix --disable-cuda-graph --max-loras-per-batch 4 +""" +OpenAI-compatible LoRA adapter usage with SGLang. -# send requests -# lora_path[i] specifies the LoRA used for text[i], so make sure they have the same length -# use None to specify base-only prompt, e.x. "lora_path": [None, "/home/ying/test_lora"] -import json +Server Setup: + python -m sglang.launch_server \\ + --model meta-llama/Llama-3.1-8B-Instruct \\ + --enable-lora \\ + --lora-paths sql=/path/to/sql python=/path/to/python +""" -import requests +import openai -url = "http://127.0.0.1:30000" -json_data = { - "text": [ - "prompt 1", - "prompt 2", - "prompt 3", - "prompt 4", - "prompt 5", - "prompt 6", - "prompt 7", - ], - "sampling_params": {"max_new_tokens": 32}, - "lora_path": [ - "/home/ying/test_lora", - "lora1", - "lora2", - "lora1", - "lora2", - None, - None, - ], -} -response = requests.post( - url + "/generate", - json=json_data, -) -print(json.dumps(response.json())) +client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY") + + +def main(): + print("SGLang OpenAI-Compatible LoRA Examples\n") + + # Example 1: NEW - Adapter in model parameter (OpenAI-compatible) + print("1. Chat with LoRA adapter in model parameter:") + response = client.chat.completions.create( + model="meta-llama/Llama-3.1-8B-Instruct:sql", # ← adapter:name syntax + messages=[{"role": "user", "content": "Convert to SQL: show all users"}], + max_tokens=50, + ) + print(f" Response: {response.choices[0].message.content}\n") + + # Example 2: Completions API with adapter + print("2. 
Completion with LoRA adapter:") + response = client.completions.create( + model="meta-llama/Llama-3.1-8B-Instruct:python", + prompt="def fibonacci(n):", + max_tokens=50, + ) + print(f" Response: {response.choices[0].text}\n") + + # Example 3: OLD - Backward compatible with explicit lora_path + print("3. Backward compatible (explicit lora_path):") + response = client.chat.completions.create( + model="meta-llama/Llama-3.1-8B-Instruct", + messages=[{"role": "user", "content": "Convert to SQL: show all users"}], + extra_body={"lora_path": "sql"}, + max_tokens=50, + ) + print(f" Response: {response.choices[0].message.content}\n") + + # Example 4: Base model (no adapter) + print("4. Base model without adapter:") + response = client.chat.completions.create( + model="meta-llama/Llama-3.1-8B-Instruct", + messages=[{"role": "user", "content": "Hello!"}], + max_tokens=30, + ) + print(f" Response: {response.choices[0].message.content}\n") + + print("All examples completed!") + + +if __name__ == "__main__": + try: + main() + except Exception as e: + print(f"Error: {e}") + print( + "\nEnsure server is running:\n" + " python -m sglang.launch_server --model ... --enable-lora --lora-paths ..." + ) diff --git a/python/sglang/srt/entrypoints/openai/protocol.py b/python/sglang/srt/entrypoints/openai/protocol.py index 50c42a1ff..638e97978 100644 --- a/python/sglang/srt/entrypoints/openai/protocol.py +++ b/python/sglang/srt/entrypoints/openai/protocol.py @@ -204,7 +204,10 @@ class BatchResponse(BaseModel): class CompletionRequest(BaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/completions/create - model: str = DEFAULT_MODEL_NAME + model: str = Field( + default=DEFAULT_MODEL_NAME, + description="Model name. Supports LoRA adapters via 'base-model:adapter-name' syntax.", + ) prompt: Union[List[int], List[List[int]], str, List[str]] best_of: Optional[int] = None echo: bool = False @@ -441,7 +444,10 @@ class ChatCompletionRequest(BaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/chat/create messages: List[ChatCompletionMessageParam] - model: str = DEFAULT_MODEL_NAME + model: str = Field( + default=DEFAULT_MODEL_NAME, + description="Model name. Supports LoRA adapters via 'base-model:adapter-name' syntax.", + ) frequency_penalty: float = 0.0 logit_bias: Optional[Dict[str, float]] = None logprobs: bool = False @@ -1099,7 +1105,7 @@ class ResponsesResponse(BaseModel): Union[ ResponseOutputItem, ResponseReasoningItem, ResponseFunctionToolCall ] - ] + ], ) -> bool: if not items: return False diff --git a/python/sglang/srt/entrypoints/openai/serving_base.py b/python/sglang/srt/entrypoints/openai/serving_base.py index d42a942f3..aa2d0838e 100644 --- a/python/sglang/srt/entrypoints/openai/serving_base.py +++ b/python/sglang/srt/entrypoints/openai/serving_base.py @@ -4,7 +4,7 @@ import json import logging import uuid from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union import orjson from fastapi import HTTPException, Request @@ -35,6 +35,52 @@ class OpenAIServingBase(ABC): else None ) + def _parse_model_parameter(self, model: str) -> Tuple[str, Optional[str]]: + """Parse 'base-model:adapter-name' syntax to extract LoRA adapter. + + Returns (base_model, adapter_name) or (model, None) if no colon present. 
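+
+        Examples (behavior as exercised by the unit tests in this patch):
+            "meta-llama/Llama-3.1-8B:sql"  -> ("meta-llama/Llama-3.1-8B", "sql")
+            "model:adapter:extra"          -> ("model", "adapter:extra")
+            "model-name"                   -> ("model-name", None)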
+ """ + if ":" not in model: + return model, None + + # Split on first colon only to handle model paths with multiple colons + parts = model.split(":", 1) + base_model = parts[0].strip() + adapter_name = parts[1].strip() or None + + return base_model, adapter_name + + def _resolve_lora_path( + self, + request_model: str, + explicit_lora_path: Optional[Union[str, List[Optional[str]]]], + ) -> Optional[Union[str, List[Optional[str]]]]: + """Resolve LoRA adapter with priority: model parameter > explicit lora_path. + + Returns adapter name or None. Supports both single values and lists (batches). + """ + _, adapter_from_model = self._parse_model_parameter(request_model) + + # Model parameter adapter takes precedence + if adapter_from_model is not None: + return adapter_from_model + + # Fall back to explicit lora_path + return explicit_lora_path + + def _validate_lora_enabled(self, adapter_name: str) -> None: + """Check that LoRA is enabled before attempting to use an adapter. + + Raises ValueError with actionable guidance if --enable-lora flag is missing. + Adapter existence is validated later by TokenizerManager.lora_registry. + """ + if not self.tokenizer_manager.server_args.enable_lora: + raise ValueError( + f"LoRA adapter '{adapter_name}' was requested, but LoRA is not enabled. " + "Please launch the server with --enable-lora flag and preload adapters " + "using --lora-paths or /load_lora_adapter endpoint." + ) + async def handle_request( self, request: OpenAIServingRequest, raw_request: Request ) -> Union[Any, StreamingResponse, ErrorResponse]: diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index 719fa2814..9529d3dbd 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -164,6 +164,17 @@ class OpenAIServingChat(OpenAIServingBase): # Extract custom labels from raw request headers custom_labels = self.extract_custom_labels(raw_request) + # Resolve LoRA adapter from model parameter or explicit lora_path + lora_path = self._resolve_lora_path(request.model, request.lora_path) + if lora_path: + first_adapter = ( + lora_path + if isinstance(lora_path, str) + else next((a for a in lora_path if a), None) + ) + if first_adapter: + self._validate_lora_enabled(first_adapter) + adapted_request = GenerateReqInput( **prompt_kwargs, image_data=processed_messages.image_data, @@ -176,7 +187,7 @@ class OpenAIServingChat(OpenAIServingBase): stream=request.stream, return_text_in_logprobs=True, modalities=processed_messages.modalities, - lora_path=request.lora_path, + lora_path=lora_path, bootstrap_host=request.bootstrap_host, bootstrap_port=request.bootstrap_port, bootstrap_room=request.bootstrap_room, diff --git a/python/sglang/srt/entrypoints/openai/serving_completions.py b/python/sglang/srt/entrypoints/openai/serving_completions.py index aaf3b097c..b6c8d7432 100644 --- a/python/sglang/srt/entrypoints/openai/serving_completions.py +++ b/python/sglang/srt/entrypoints/openai/serving_completions.py @@ -93,6 +93,17 @@ class OpenAIServingCompletion(OpenAIServingBase): # Extract custom labels from raw request headers custom_labels = self.extract_custom_labels(raw_request) + # Resolve LoRA adapter from model parameter or explicit lora_path + lora_path = self._resolve_lora_path(request.model, request.lora_path) + if lora_path: + first_adapter = ( + lora_path + if isinstance(lora_path, str) + else next((a for a in lora_path if a), None) + ) + if first_adapter: + 
self._validate_lora_enabled(first_adapter) + adapted_request = GenerateReqInput( **prompt_kwargs, sampling_params=sampling_params, @@ -101,7 +112,7 @@ class OpenAIServingCompletion(OpenAIServingBase): logprob_start_len=logprob_start_len, return_text_in_logprobs=True, stream=request.stream, - lora_path=request.lora_path, + lora_path=lora_path, bootstrap_host=request.bootstrap_host, bootstrap_port=request.bootstrap_port, bootstrap_room=request.bootstrap_room, diff --git a/test/srt/lora/test_lora_openai_api.py b/test/srt/lora/test_lora_openai_api.py new file mode 100644 index 000000000..4f5ac5303 --- /dev/null +++ b/test/srt/lora/test_lora_openai_api.py @@ -0,0 +1,327 @@ +""" +Unit tests for OpenAI-compatible LoRA API support. + +Tests the model parameter parsing and LoRA adapter resolution logic +that enables OpenAI-compatible LoRA adapter selection. +""" + +import unittest +from unittest.mock import MagicMock, Mock + +from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase +from sglang.srt.server_args import ServerArgs + + +class MockTokenizerManager: + """Mock TokenizerManager for testing.""" + + def __init__(self, enable_lora=False): + self.server_args = MagicMock(spec=ServerArgs) + self.server_args.enable_lora = enable_lora + self.server_args.tokenizer_metrics_allowed_custom_labels = None + + +class ConcreteServingBase(OpenAIServingBase): + """Concrete implementation for testing abstract base class.""" + + def _request_id_prefix(self) -> str: + return "test-" + + def _convert_to_internal_request(self, request, raw_request=None): + pass + + def _validate_request(self, request): + pass + + +class TestParseModelParameter(unittest.TestCase): + """Test _parse_model_parameter method.""" + + def setUp(self): + self.tokenizer_manager = MockTokenizerManager(enable_lora=True) + self.serving = ConcreteServingBase(self.tokenizer_manager) + + def test_model_without_adapter(self): + """Test parsing model without adapter returns None for adapter.""" + base_model, adapter = self.serving._parse_model_parameter("llama-3.1-8B") + self.assertEqual(base_model, "llama-3.1-8B") + self.assertIsNone(adapter) + + def test_model_with_adapter(self): + """Test parsing model with adapter extracts both parts.""" + base_model, adapter = self.serving._parse_model_parameter( + "llama-3.1-8B:sql-expert" + ) + self.assertEqual(base_model, "llama-3.1-8B") + self.assertEqual(adapter, "sql-expert") + + def test_model_with_path_and_adapter(self): + """Test parsing model path with slashes and adapter.""" + base_model, adapter = self.serving._parse_model_parameter( + "meta-llama/Llama-3.1-8B-Instruct:adapter-name" + ) + self.assertEqual(base_model, "meta-llama/Llama-3.1-8B-Instruct") + self.assertEqual(adapter, "adapter-name") + + def test_model_with_multiple_colons(self): + """Test that only first colon is used for splitting.""" + base_model, adapter = self.serving._parse_model_parameter("model:adapter:extra") + self.assertEqual(base_model, "model") + self.assertEqual(adapter, "adapter:extra") + + def test_model_with_whitespace(self): + """Test that whitespace is stripped from both parts.""" + base_model, adapter = self.serving._parse_model_parameter( + " model-name : adapter-name " + ) + self.assertEqual(base_model, "model-name") + self.assertEqual(adapter, "adapter-name") + + def test_model_with_empty_adapter(self): + """Test model ending with colon returns None for adapter.""" + base_model, adapter = self.serving._parse_model_parameter("model-name:") + self.assertEqual(base_model, "model-name") + 
self.assertIsNone(adapter) + + def test_model_with_only_spaces_after_colon(self): + """Test model with only whitespace after colon returns None for adapter.""" + base_model, adapter = self.serving._parse_model_parameter("model-name: ") + self.assertEqual(base_model, "model-name") + self.assertIsNone(adapter) + + +class TestResolveLoraPath(unittest.TestCase): + """Test _resolve_lora_path method.""" + + def setUp(self): + self.tokenizer_manager = MockTokenizerManager(enable_lora=True) + self.serving = ConcreteServingBase(self.tokenizer_manager) + + def test_no_adapter_specified(self): + """Test when neither model nor explicit lora_path has adapter.""" + result = self.serving._resolve_lora_path("model-name", None) + self.assertIsNone(result) + + def test_adapter_in_model_only(self): + """Test adapter from model parameter when no explicit path.""" + result = self.serving._resolve_lora_path("model:sql-expert", None) + self.assertEqual(result, "sql-expert") + + def test_adapter_in_explicit_only(self): + """Test adapter from explicit lora_path when not in model.""" + result = self.serving._resolve_lora_path("model-name", "python-expert") + self.assertEqual(result, "python-expert") + + def test_model_parameter_takes_precedence(self): + """Test model parameter adapter takes precedence over explicit.""" + result = self.serving._resolve_lora_path("model:sql-expert", "python-expert") + self.assertEqual(result, "sql-expert") + + def test_with_list_explicit_lora_path(self): + """Test that explicit list is returned when no model adapter.""" + explicit = ["adapter1", "adapter2", None] + result = self.serving._resolve_lora_path("model-name", explicit) + self.assertEqual(result, explicit) + + def test_model_adapter_overrides_list(self): + """Test model adapter overrides even when explicit is a list.""" + result = self.serving._resolve_lora_path( + "model:sql-expert", ["adapter1", "adapter2"] + ) + self.assertEqual(result, "sql-expert") + + def test_complex_model_name_with_adapter(self): + """Test resolution with complex model name.""" + result = self.serving._resolve_lora_path( + "org/model-v2.1:adapter-name", "other-adapter" + ) + self.assertEqual(result, "adapter-name") + + +class TestValidateLoraEnabled(unittest.TestCase): + """Test _validate_lora_enabled method.""" + + def test_validation_passes_when_lora_enabled(self): + """Test validation passes when LoRA is enabled.""" + tokenizer_manager = MockTokenizerManager(enable_lora=True) + serving = ConcreteServingBase(tokenizer_manager) + + # Should not raise + try: + serving._validate_lora_enabled("sql-expert") + except ValueError: + self.fail("_validate_lora_enabled raised ValueError unexpectedly") + + def test_validation_fails_when_lora_disabled(self): + """Test validation fails with helpful message when LoRA is disabled.""" + tokenizer_manager = MockTokenizerManager(enable_lora=False) + serving = ConcreteServingBase(tokenizer_manager) + + with self.assertRaises(ValueError) as context: + serving._validate_lora_enabled("sql-expert") + + error_message = str(context.exception) + self.assertIn("sql-expert", error_message) + self.assertIn("--enable-lora", error_message) + self.assertIn("not enabled", error_message) + + def test_validation_error_mentions_adapter_name(self): + """Test that error message includes the requested adapter name.""" + tokenizer_manager = MockTokenizerManager(enable_lora=False) + serving = ConcreteServingBase(tokenizer_manager) + + with self.assertRaises(ValueError) as context: + serving._validate_lora_enabled("my-custom-adapter") + + 
self.assertIn("my-custom-adapter", str(context.exception)) + + +class TestIntegrationScenarios(unittest.TestCase): + """Integration tests for common usage scenarios.""" + + def setUp(self): + self.tokenizer_manager = MockTokenizerManager(enable_lora=True) + self.serving = ConcreteServingBase(self.tokenizer_manager) + + def test_openai_compatible_usage(self): + """Test typical OpenAI-compatible usage pattern.""" + # User specifies adapter in model parameter + model = "meta-llama/Llama-3.1-8B:sql-expert" + explicit_lora = None + + lora_path = self.serving._resolve_lora_path(model, explicit_lora) + self.assertEqual(lora_path, "sql-expert") + + # Validation should pass + self.serving._validate_lora_enabled(lora_path) + + def test_backward_compatible_usage(self): + """Test backward-compatible usage with explicit lora_path.""" + model = "meta-llama/Llama-3.1-8B" + explicit_lora = "sql-expert" + + lora_path = self.serving._resolve_lora_path(model, explicit_lora) + self.assertEqual(lora_path, "sql-expert") + + # Validation should pass + self.serving._validate_lora_enabled(lora_path) + + def test_base_model_usage(self): + """Test using base model without any adapter.""" + model = "meta-llama/Llama-3.1-8B" + explicit_lora = None + + lora_path = self.serving._resolve_lora_path(model, explicit_lora) + self.assertIsNone(lora_path) + + # No validation needed when no adapter + + def test_batch_request_scenario(self): + """Test batch request with list of adapters.""" + model = "meta-llama/Llama-3.1-8B" # No adapter in model + explicit_lora = ["sql-expert", "python-expert", None] + + lora_path = self.serving._resolve_lora_path(model, explicit_lora) + self.assertEqual(lora_path, explicit_lora) + + # Validate first adapter in list + if isinstance(lora_path, list) and lora_path[0]: + self.serving._validate_lora_enabled(lora_path[0]) + + def test_adapter_in_model_overrides_batch_list(self): + """Test that adapter in model parameter overrides batch list.""" + model = "meta-llama/Llama-3.1-8B:preferred-adapter" + explicit_lora = ["adapter1", "adapter2"] + + lora_path = self.serving._resolve_lora_path(model, explicit_lora) + self.assertEqual(lora_path, "preferred-adapter") + + def test_error_when_lora_not_enabled(self): + """Test comprehensive error flow when LoRA is not enabled.""" + # Setup server without LoRA enabled + tokenizer_manager = MockTokenizerManager(enable_lora=False) + serving = ConcreteServingBase(tokenizer_manager) + + # User tries to use adapter + model = "meta-llama/Llama-3.1-8B:sql-expert" + lora_path = serving._resolve_lora_path(model, None) + + # Should get helpful error + with self.assertRaises(ValueError) as context: + serving._validate_lora_enabled(lora_path) + + error = str(context.exception) + self.assertIn("--enable-lora", error) + self.assertIn("sql-expert", error) + + +class TestEdgeCases(unittest.TestCase): + """Test edge cases and error conditions.""" + + def setUp(self): + self.tokenizer_manager = MockTokenizerManager(enable_lora=True) + self.serving = ConcreteServingBase(self.tokenizer_manager) + + def test_empty_string_model(self): + """Test handling of empty string model.""" + base, adapter = self.serving._parse_model_parameter("") + self.assertEqual(base, "") + self.assertIsNone(adapter) + + def test_only_colon(self): + """Test model parameter that is just a colon.""" + base, adapter = self.serving._parse_model_parameter(":") + self.assertEqual(base, "") + self.assertIsNone(adapter) + + def test_empty_list_lora_path(self): + """Test validation with empty list doesn't crash.""" 
+ lora_path = self.serving._resolve_lora_path("model-name", []) + # Empty list is falsy, so validation won't be called + self.assertEqual(lora_path, []) + + def test_list_with_none_first(self): + """Test validation finds first non-None adapter in list.""" + lora_path = self.serving._resolve_lora_path("model-name", [None, "adapter2"]) + self.assertEqual(lora_path, [None, "adapter2"]) + # In actual usage, validation would find "adapter2" + + def test_list_all_none(self): + """Test validation with list of all None values.""" + lora_path = self.serving._resolve_lora_path("model-name", [None, None]) + self.assertEqual(lora_path, [None, None]) + # In actual usage, no validation would occur (no non-None adapters) + + def test_unicode_in_adapter_name(self): + """Test Unicode characters in adapter name.""" + base, adapter = self.serving._parse_model_parameter("model:adapter-名前") + self.assertEqual(base, "model") + self.assertEqual(adapter, "adapter-名前") + + def test_special_characters_in_adapter(self): + """Test special characters in adapter name.""" + base, adapter = self.serving._parse_model_parameter("model:adapter_v2.1-final") + self.assertEqual(base, "model") + self.assertEqual(adapter, "adapter_v2.1-final") + + def test_none_as_explicit_lora_path(self): + """Test None as explicit lora_path is handled correctly.""" + result = self.serving._resolve_lora_path("model:adapter", None) + self.assertEqual(result, "adapter") + + def test_empty_string_as_explicit_lora_path(self): + """Test empty string as explicit lora_path.""" + result = self.serving._resolve_lora_path("model-name", "") + self.assertEqual(result, "") + + def test_validation_with_empty_adapter_name(self): + """Test validation with empty adapter name still raises error.""" + tokenizer_manager = MockTokenizerManager(enable_lora=False) + serving = ConcreteServingBase(tokenizer_manager) + + with self.assertRaises(ValueError): + serving._validate_lora_enabled("") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/openai_server/features/test_lora_openai_compatible.py b/test/srt/openai_server/features/test_lora_openai_compatible.py new file mode 100644 index 000000000..e38b62e77 --- /dev/null +++ b/test/srt/openai_server/features/test_lora_openai_compatible.py @@ -0,0 +1,278 @@ +""" +End-to-end tests for OpenAI-compatible LoRA adapter usage. + +Tests the model:adapter syntax and backward compatibility with explicit lora_path. 
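+
+These tests launch a real server and pull a real LoRA adapter from Hugging Face,
+so network access is required on the first run.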
+ +Usage: + python3 -m unittest openai_server.features.test_lora_openai_compatible.TestLoRAOpenAICompatible.test_model_adapter_syntax + python3 -m unittest openai_server.features.test_lora_openai_compatible.TestLoRAOpenAICompatible.test_explicit_lora_path + python3 -m unittest openai_server.features.test_lora_openai_compatible.TestLoRAOpenAICompatible.test_priority_model_over_explicit + python3 -m unittest openai_server.features.test_lora_openai_compatible.TestLoRAOpenAICompatible.test_base_model_no_adapter + python3 -m unittest openai_server.features.test_lora_openai_compatible.TestLoRAOpenAICompatible.test_completions_api_with_adapter + python3 -m unittest openai_server.features.test_lora_openai_compatible.TestLoRAOpenAICompatible.test_streaming_with_adapter + python3 -m unittest openai_server.features.test_lora_openai_compatible.TestLoRADisabledError.test_lora_disabled_error +""" + +import unittest + +import openai + +from sglang.srt.utils import kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + + +def get_real_lora_adapter() -> str: + """Use a real LoRA adapter from Hugging Face.""" + return "codelion/Llama-3.2-1B-Instruct-tool-calling-lora" + + +def setup_class(cls, enable_lora=True): + """Setup test class with LoRA-enabled server.""" + cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + + # Use real LoRA adapter + cls.lora_adapter_path = get_real_lora_adapter() + + other_args = [ + "--max-running-requests", + "10", + "--disable-radix-cache", # Disable cache for cleaner tests + ] + + if enable_lora: + other_args.extend( + [ + "--enable-lora", + "--lora-paths", + f"tool_calling={cls.lora_adapter_path}", + ] + ) + + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) + cls.client = openai.Client(api_key="EMPTY", base_url=f"{cls.base_url}/v1") + + +class TestLoRAOpenAICompatible(CustomTestCase): + """Test OpenAI-compatible LoRA adapter usage.""" + + @classmethod + def setUpClass(cls): + setup_class(cls, enable_lora=True) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_model_adapter_syntax(self): + """Test the new model:adapter syntax works correctly.""" + response = self.client.chat.completions.create( + # ← New OpenAI-compatible syntax + model=f"{self.model}:tool_calling", + messages=[{"role": "user", "content": "What tools do you have available?"}], + max_tokens=50, + temperature=0, + ) + + self.assertIsNotNone(response.choices[0].message.content) + self.assertGreater(len(response.choices[0].message.content), 0) + print(f"Model adapter syntax response: {response.choices[0].message.content}") + + def test_explicit_lora_path(self): + """Test backward compatibility with explicit lora_path via extra_body.""" + response = self.client.chat.completions.create( + model=self.model, + messages=[{"role": "user", "content": "What tools do you have available?"}], + # ← Legacy explicit method + extra_body={"lora_path": "tool_calling"}, + max_tokens=50, + temperature=0, + ) + + self.assertIsNotNone(response.choices[0].message.content) + self.assertGreater(len(response.choices[0].message.content), 0) + print(f"Explicit lora_path response: {response.choices[0].message.content}") + + def test_priority_model_over_explicit(self): + """Test that model:adapter syntax takes precedence over explicit 
lora_path.""" + # This test verifies the priority logic in _resolve_lora_path + response = self.client.chat.completions.create( + # ← Model specifies tool_calling adapter + model=f"{self.model}:tool_calling", + messages=[{"role": "user", "content": "What tools do you have available?"}], + # ← Both specify same adapter + extra_body={"lora_path": "tool_calling"}, + max_tokens=50, + temperature=0, + ) + + # Should use tool_calling adapter (model parameter takes precedence) + self.assertIsNotNone(response.choices[0].message.content) + self.assertGreater(len(response.choices[0].message.content), 0) + print(f"Priority test response: {response.choices[0].message.content}") + + def test_base_model_no_adapter(self): + """Test using base model without any adapter.""" + response = self.client.chat.completions.create( + model=self.model, # ← No adapter specified + messages=[{"role": "user", "content": "Hello, how are you?"}], + max_tokens=30, + temperature=0, + ) + + self.assertIsNotNone(response.choices[0].message.content) + self.assertGreater(len(response.choices[0].message.content), 0) + print(f"Base model response: {response.choices[0].message.content}") + + def test_completions_api_with_adapter(self): + """Test completions API with LoRA adapter.""" + response = self.client.completions.create( + model=f"{self.model}:tool_calling", # ← Using model:adapter syntax + prompt="What tools do you have available?", + max_tokens=50, + temperature=0, + ) + + self.assertIsNotNone(response.choices[0].text) + self.assertGreater(len(response.choices[0].text), 0) + print(f"Completions API response: {response.choices[0].text}") + + def test_streaming_with_adapter(self): + """Test streaming with LoRA adapter.""" + stream = self.client.chat.completions.create( + model=f"{self.model}:tool_calling", + messages=[{"role": "user", "content": "What tools do you have available?"}], + max_tokens=50, + temperature=0, + stream=True, + ) + + collected_content = "" + for chunk in stream: + if chunk.choices[0].delta.content: + collected_content += chunk.choices[0].delta.content + + self.assertGreater(len(collected_content), 0) + print(f"Streaming response: {collected_content}") + + def test_multiple_adapters(self): + """Test using different adapters in sequence.""" + # Test tool_calling adapter + tool_response = self.client.chat.completions.create( + model=f"{self.model}:tool_calling", + messages=[{"role": "user", "content": "What tools do you have available?"}], + max_tokens=30, + temperature=0, + ) + + # Test base model without adapter + base_response = self.client.chat.completions.create( + model=self.model, + messages=[{"role": "user", "content": "Hello, how are you?"}], + max_tokens=30, + temperature=0, + ) + + self.assertIsNotNone(tool_response.choices[0].message.content) + self.assertIsNotNone(base_response.choices[0].message.content) + print( + f"Tool calling adapter response: {tool_response.choices[0].message.content}" + ) + print(f"Base model response: {base_response.choices[0].message.content}") + + +class TestLoRADisabledError(CustomTestCase): + """Test error handling when LoRA is disabled.""" + + @classmethod + def setUpClass(cls): + setup_class(cls, enable_lora=False) # ← LoRA disabled + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_lora_disabled_error(self): + """Test that using LoRA adapter when LoRA is disabled raises appropriate error.""" + with self.assertRaises(openai.APIError) as context: + self.client.chat.completions.create( + model=f"{self.model}:tool_calling", 
# ← Trying to use adapter + messages=[ + {"role": "user", "content": "What tools do you have available?"} + ], + max_tokens=50, + ) + + # Verify the error message contains helpful guidance + error_message = str(context.exception) + self.assertIn("LoRA", error_message) + self.assertIn("not enabled", error_message) + print(f"Expected error message: {error_message}") + + +class TestLoRAEdgeCases(CustomTestCase): + """Test edge cases for LoRA adapter usage.""" + + @classmethod + def setUpClass(cls): + setup_class(cls, enable_lora=True) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_model_with_colon_no_adapter(self): + """Test model parameter ending with colon (empty adapter).""" + response = self.client.chat.completions.create( + model=f"{self.model}:", # ← Model ends with colon + messages=[{"role": "user", "content": "Hello!"}], + max_tokens=30, + temperature=0, + ) + + # Should work as base model (no adapter) + self.assertIsNotNone(response.choices[0].message.content) + print(f"Model with colon response: {response.choices[0].message.content}") + + def test_explicit_lora_path_none(self): + """Test explicit lora_path set to None.""" + response = self.client.chat.completions.create( + model=self.model, + messages=[{"role": "user", "content": "Hello!"}], + extra_body={"lora_path": None}, # ← Explicitly None + max_tokens=30, + temperature=0, + ) + + # Should work as base model + self.assertIsNotNone(response.choices[0].message.content) + print( + f"Explicit None lora_path response: {response.choices[0].message.content}" + ) + + def test_invalid_adapter_name(self): + """Test using non-existent adapter name.""" + with self.assertRaises(openai.APIError) as context: + self.client.chat.completions.create( + model=f"{self.model}:nonexistent", # ← Non-existent adapter + messages=[{"role": "user", "content": "Hello!"}], + max_tokens=30, + ) + + error_message = str(context.exception) + print(f"Invalid adapter error: {error_message}") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 29f332800..80b03fe27 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -26,6 +26,7 @@ suites = { TestFile("lora/test_lora_eviction.py", 200), TestFile("lora/test_lora_qwen3.py", 97), TestFile("lora/test_lora_radix_cache.py", 100), + TestFile("lora/test_lora_openai_api.py", 30), TestFile("lora/test_lora_update.py", 400), TestFile("lora/test_multi_lora_backend.py", 60), TestFile("models/test_compressed_tensors_models.py", 42), @@ -51,6 +52,7 @@ suites = { TestFile("openai_server/features/test_openai_server_ebnf.py", 95), TestFile("openai_server/features/test_openai_server_hidden_states.py", 240), TestFile("openai_server/features/test_reasoning_content.py", 89), + TestFile("openai_server/features/test_lora_openai_compatible.py", 120), TestFile("openai_server/function_call/test_openai_function_calling.py", 60), TestFile("openai_server/function_call/test_tool_choice.py", 226), TestFile("openai_server/validation/test_large_max_new_tokens.py", 41),