[FEATURE] Add OpenAI-Compatible LoRA Adapter Selection (#11570)
This commit is contained in:
@@ -59,6 +59,17 @@
|
||||
"### Serving Single Adaptor"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"**Note:** SGLang supports LoRA adapters through two APIs:\n",
|
||||
"\n",
|
||||
"1. **OpenAI-Compatible API** (`/v1/chat/completions`, `/v1/completions`): Use the `model:adapter-name` syntax. See [OpenAI API with LoRA](../basic_usage/openai_api_completions.ipynb#Using-LoRA-Adapters) for examples.\n",
|
||||
"\n",
|
||||
"2. **Native API** (`/generate`): Pass `lora_path` in the request body (shown below)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -379,6 +390,15 @@
|
||||
"print(f\"Output from lora1 (updated): \\n{response.json()[1]['text']}\\n\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### OpenAI-compatible API usage\n",
|
||||
"\n",
|
||||
"You can use LoRA adapters via the OpenAI-compatible APIs by specifying the adapter in the `model` field using the `base-model:adapter-name` syntax (for example, `qwen/qwen2.5-0.5b-instruct:adapter_a`). For more details and examples, see the “Using LoRA Adapters” section in the OpenAI API documentation: [openai_api_completions.ipynb](../basic_usage/openai_api_completions.ipynb).\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
|
||||
@@ -361,6 +361,50 @@
|
||||
"For OpenAI compatible structured outputs API, refer to [Structured Outputs](../advanced_features/structured_outputs.ipynb) for more details.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Using LoRA Adapters\n",
|
||||
"\n",
|
||||
"SGLang supports LoRA (Low-Rank Adaptation) adapters with OpenAI-compatible APIs. You can specify which adapter to use directly in the `model` parameter using the `base-model:adapter-name` syntax.\n",
|
||||
"\n",
|
||||
"**Server Setup:**\n",
|
||||
"```bash\n",
|
||||
"python -m sglang.launch_server \\\n",
|
||||
" --model-path qwen/qwen2.5-0.5b-instruct \\\n",
|
||||
" --enable-lora \\\n",
|
||||
" --lora-paths adapter_a=/path/to/adapter_a adapter_b=/path/to/adapter_b\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"For more details on LoRA serving configuration, see the [LoRA documentation](../advanced_features/lora.ipynb).\n",
|
||||
"\n",
|
||||
"**API Call:**\n",
|
||||
"\n",
|
||||
"(Recommended) Use the `model:adapter` syntax to specify which adapter to use:\n",
|
||||
"```python\n",
|
||||
"response = client.chat.completions.create(\n",
|
||||
" model=\"qwen/qwen2.5-0.5b-instruct:adapter_a\", # ← base-model:adapter-name\n",
|
||||
" messages=[{\"role\": \"user\", \"content\": \"Convert to SQL: show all users\"}],\n",
|
||||
" max_tokens=50,\n",
|
||||
")\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"**Backward Compatible: Using `extra_body`**\n",
|
||||
"\n",
|
||||
"The old `extra_body` method is still supported for backward compatibility:\n",
|
||||
"```python\n",
|
||||
"# Backward compatible method\n",
|
||||
"response = client.chat.completions.create(\n",
|
||||
" model=\"qwen/qwen2.5-0.5b-instruct\",\n",
|
||||
" messages=[{\"role\": \"user\", \"content\": \"Convert to SQL: show all users\"}],\n",
|
||||
" extra_body={\"lora_path\": \"adapter_a\"}, # ← old method\n",
|
||||
" max_tokens=50,\n",
|
||||
")\n",
|
||||
"```\n",
|
||||
"**Note:** When both `model:adapter` and `extra_body[\"lora_path\"]` are specified, the `model:adapter` syntax takes precedence."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
|
||||
@@ -1,37 +1,67 @@
|
||||
# launch server
|
||||
# python -m sglang.launch_server --model mistralai/Mistral-7B-Instruct-v0.3 --lora-paths /home/ying/test_lora lora1=/home/ying/test_lora_1 lora2=/home/ying/test_lora_2 --disable-radix --disable-cuda-graph --max-loras-per-batch 4
|
||||
"""
|
||||
OpenAI-compatible LoRA adapter usage with SGLang.
|
||||
|
||||
# send requests
|
||||
# lora_path[i] specifies the LoRA used for text[i], so make sure they have the same length
|
||||
# use None to specify base-only prompt, e.x. "lora_path": [None, "/home/ying/test_lora"]
|
||||
import json
|
||||
Server Setup:
|
||||
python -m sglang.launch_server \\
|
||||
--model meta-llama/Llama-3.1-8B-Instruct \\
|
||||
--enable-lora \\
|
||||
--lora-paths sql=/path/to/sql python=/path/to/python
|
||||
"""
|
||||
|
||||
import requests
|
||||
import openai
|
||||
|
||||
url = "http://127.0.0.1:30000"
|
||||
json_data = {
|
||||
"text": [
|
||||
"prompt 1",
|
||||
"prompt 2",
|
||||
"prompt 3",
|
||||
"prompt 4",
|
||||
"prompt 5",
|
||||
"prompt 6",
|
||||
"prompt 7",
|
||||
],
|
||||
"sampling_params": {"max_new_tokens": 32},
|
||||
"lora_path": [
|
||||
"/home/ying/test_lora",
|
||||
"lora1",
|
||||
"lora2",
|
||||
"lora1",
|
||||
"lora2",
|
||||
None,
|
||||
None,
|
||||
],
|
||||
}
|
||||
response = requests.post(
|
||||
url + "/generate",
|
||||
json=json_data,
|
||||
)
|
||||
print(json.dumps(response.json()))
|
||||
client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
|
||||
|
||||
|
||||
def main():
|
||||
print("SGLang OpenAI-Compatible LoRA Examples\n")
|
||||
|
||||
# Example 1: NEW - Adapter in model parameter (OpenAI-compatible)
|
||||
print("1. Chat with LoRA adapter in model parameter:")
|
||||
response = client.chat.completions.create(
|
||||
model="meta-llama/Llama-3.1-8B-Instruct:sql", # ← adapter:name syntax
|
||||
messages=[{"role": "user", "content": "Convert to SQL: show all users"}],
|
||||
max_tokens=50,
|
||||
)
|
||||
print(f" Response: {response.choices[0].message.content}\n")
|
||||
|
||||
# Example 2: Completions API with adapter
|
||||
print("2. Completion with LoRA adapter:")
|
||||
response = client.completions.create(
|
||||
model="meta-llama/Llama-3.1-8B-Instruct:python",
|
||||
prompt="def fibonacci(n):",
|
||||
max_tokens=50,
|
||||
)
|
||||
print(f" Response: {response.choices[0].text}\n")
|
||||
|
||||
# Example 3: OLD - Backward compatible with explicit lora_path
|
||||
print("3. Backward compatible (explicit lora_path):")
|
||||
response = client.chat.completions.create(
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
messages=[{"role": "user", "content": "Convert to SQL: show all users"}],
|
||||
extra_body={"lora_path": "sql"},
|
||||
max_tokens=50,
|
||||
)
|
||||
print(f" Response: {response.choices[0].message.content}\n")
|
||||
|
||||
# Example 4: Base model (no adapter)
|
||||
print("4. Base model without adapter:")
|
||||
response = client.chat.completions.create(
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
messages=[{"role": "user", "content": "Hello!"}],
|
||||
max_tokens=30,
|
||||
)
|
||||
print(f" Response: {response.choices[0].message.content}\n")
|
||||
|
||||
print("All examples completed!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
print(
|
||||
"\nEnsure server is running:\n"
|
||||
" python -m sglang.launch_server --model ... --enable-lora --lora-paths ..."
|
||||
)
|
||||
|
||||
@@ -204,7 +204,10 @@ class BatchResponse(BaseModel):
|
||||
class CompletionRequest(BaseModel):
|
||||
# Ordered by official OpenAI API documentation
|
||||
# https://platform.openai.com/docs/api-reference/completions/create
|
||||
model: str = DEFAULT_MODEL_NAME
|
||||
model: str = Field(
|
||||
default=DEFAULT_MODEL_NAME,
|
||||
description="Model name. Supports LoRA adapters via 'base-model:adapter-name' syntax.",
|
||||
)
|
||||
prompt: Union[List[int], List[List[int]], str, List[str]]
|
||||
best_of: Optional[int] = None
|
||||
echo: bool = False
|
||||
@@ -441,7 +444,10 @@ class ChatCompletionRequest(BaseModel):
|
||||
# Ordered by official OpenAI API documentation
|
||||
# https://platform.openai.com/docs/api-reference/chat/create
|
||||
messages: List[ChatCompletionMessageParam]
|
||||
model: str = DEFAULT_MODEL_NAME
|
||||
model: str = Field(
|
||||
default=DEFAULT_MODEL_NAME,
|
||||
description="Model name. Supports LoRA adapters via 'base-model:adapter-name' syntax.",
|
||||
)
|
||||
frequency_penalty: float = 0.0
|
||||
logit_bias: Optional[Dict[str, float]] = None
|
||||
logprobs: bool = False
|
||||
@@ -1099,7 +1105,7 @@ class ResponsesResponse(BaseModel):
|
||||
Union[
|
||||
ResponseOutputItem, ResponseReasoningItem, ResponseFunctionToolCall
|
||||
]
|
||||
]
|
||||
],
|
||||
) -> bool:
|
||||
if not items:
|
||||
return False
|
||||
|
||||
@@ -4,7 +4,7 @@ import json
|
||||
import logging
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
|
||||
|
||||
import orjson
|
||||
from fastapi import HTTPException, Request
|
||||
@@ -35,6 +35,52 @@ class OpenAIServingBase(ABC):
|
||||
else None
|
||||
)
|
||||
|
||||
def _parse_model_parameter(self, model: str) -> Tuple[str, Optional[str]]:
|
||||
"""Parse 'base-model:adapter-name' syntax to extract LoRA adapter.
|
||||
|
||||
Returns (base_model, adapter_name) or (model, None) if no colon present.
|
||||
"""
|
||||
if ":" not in model:
|
||||
return model, None
|
||||
|
||||
# Split on first colon only to handle model paths with multiple colons
|
||||
parts = model.split(":", 1)
|
||||
base_model = parts[0].strip()
|
||||
adapter_name = parts[1].strip() or None
|
||||
|
||||
return base_model, adapter_name
|
||||
|
||||
def _resolve_lora_path(
|
||||
self,
|
||||
request_model: str,
|
||||
explicit_lora_path: Optional[Union[str, List[Optional[str]]]],
|
||||
) -> Optional[Union[str, List[Optional[str]]]]:
|
||||
"""Resolve LoRA adapter with priority: model parameter > explicit lora_path.
|
||||
|
||||
Returns adapter name or None. Supports both single values and lists (batches).
|
||||
"""
|
||||
_, adapter_from_model = self._parse_model_parameter(request_model)
|
||||
|
||||
# Model parameter adapter takes precedence
|
||||
if adapter_from_model is not None:
|
||||
return adapter_from_model
|
||||
|
||||
# Fall back to explicit lora_path
|
||||
return explicit_lora_path
|
||||
|
||||
def _validate_lora_enabled(self, adapter_name: str) -> None:
|
||||
"""Check that LoRA is enabled before attempting to use an adapter.
|
||||
|
||||
Raises ValueError with actionable guidance if --enable-lora flag is missing.
|
||||
Adapter existence is validated later by TokenizerManager.lora_registry.
|
||||
"""
|
||||
if not self.tokenizer_manager.server_args.enable_lora:
|
||||
raise ValueError(
|
||||
f"LoRA adapter '{adapter_name}' was requested, but LoRA is not enabled. "
|
||||
"Please launch the server with --enable-lora flag and preload adapters "
|
||||
"using --lora-paths or /load_lora_adapter endpoint."
|
||||
)
|
||||
|
||||
async def handle_request(
|
||||
self, request: OpenAIServingRequest, raw_request: Request
|
||||
) -> Union[Any, StreamingResponse, ErrorResponse]:
|
||||
|
||||
@@ -164,6 +164,17 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
# Extract custom labels from raw request headers
|
||||
custom_labels = self.extract_custom_labels(raw_request)
|
||||
|
||||
# Resolve LoRA adapter from model parameter or explicit lora_path
|
||||
lora_path = self._resolve_lora_path(request.model, request.lora_path)
|
||||
if lora_path:
|
||||
first_adapter = (
|
||||
lora_path
|
||||
if isinstance(lora_path, str)
|
||||
else next((a for a in lora_path if a), None)
|
||||
)
|
||||
if first_adapter:
|
||||
self._validate_lora_enabled(first_adapter)
|
||||
|
||||
adapted_request = GenerateReqInput(
|
||||
**prompt_kwargs,
|
||||
image_data=processed_messages.image_data,
|
||||
@@ -176,7 +187,7 @@ class OpenAIServingChat(OpenAIServingBase):
|
||||
stream=request.stream,
|
||||
return_text_in_logprobs=True,
|
||||
modalities=processed_messages.modalities,
|
||||
lora_path=request.lora_path,
|
||||
lora_path=lora_path,
|
||||
bootstrap_host=request.bootstrap_host,
|
||||
bootstrap_port=request.bootstrap_port,
|
||||
bootstrap_room=request.bootstrap_room,
|
||||
|
||||
@@ -93,6 +93,17 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
||||
# Extract custom labels from raw request headers
|
||||
custom_labels = self.extract_custom_labels(raw_request)
|
||||
|
||||
# Resolve LoRA adapter from model parameter or explicit lora_path
|
||||
lora_path = self._resolve_lora_path(request.model, request.lora_path)
|
||||
if lora_path:
|
||||
first_adapter = (
|
||||
lora_path
|
||||
if isinstance(lora_path, str)
|
||||
else next((a for a in lora_path if a), None)
|
||||
)
|
||||
if first_adapter:
|
||||
self._validate_lora_enabled(first_adapter)
|
||||
|
||||
adapted_request = GenerateReqInput(
|
||||
**prompt_kwargs,
|
||||
sampling_params=sampling_params,
|
||||
@@ -101,7 +112,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
||||
logprob_start_len=logprob_start_len,
|
||||
return_text_in_logprobs=True,
|
||||
stream=request.stream,
|
||||
lora_path=request.lora_path,
|
||||
lora_path=lora_path,
|
||||
bootstrap_host=request.bootstrap_host,
|
||||
bootstrap_port=request.bootstrap_port,
|
||||
bootstrap_room=request.bootstrap_room,
|
||||
|
||||
327
test/srt/lora/test_lora_openai_api.py
Normal file
327
test/srt/lora/test_lora_openai_api.py
Normal file
@@ -0,0 +1,327 @@
|
||||
"""
|
||||
Unit tests for OpenAI-compatible LoRA API support.
|
||||
|
||||
Tests the model parameter parsing and LoRA adapter resolution logic
|
||||
that enables OpenAI-compatible LoRA adapter selection.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, Mock
|
||||
|
||||
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
|
||||
|
||||
class MockTokenizerManager:
    """Minimal stand-in for TokenizerManager exposing only ``server_args``."""

    def __init__(self, enable_lora=False):
        # Only the attributes read by OpenAIServingBase are configured.
        args = MagicMock(spec=ServerArgs)
        args.enable_lora = enable_lora
        args.tokenizer_metrics_allowed_custom_labels = None
        self.server_args = args
|
||||
|
||||
|
||||
class ConcreteServingBase(OpenAIServingBase):
    """Trivial subclass that satisfies OpenAIServingBase's abstract API.

    The abstract hooks are stubbed out as no-ops so the concrete helper
    methods on the base class can be exercised directly in unit tests.
    """

    def _request_id_prefix(self) -> str:
        return "test-"

    def _convert_to_internal_request(self, request, raw_request=None):
        return None

    def _validate_request(self, request):
        return None
|
||||
|
||||
|
||||
class TestParseModelParameter(unittest.TestCase):
    """Unit tests for OpenAIServingBase._parse_model_parameter."""

    def setUp(self):
        self.tokenizer_manager = MockTokenizerManager(enable_lora=True)
        self.serving = ConcreteServingBase(self.tokenizer_manager)

    def _parse(self, model):
        # Small indirection so each test reads as input -> expected pair.
        return self.serving._parse_model_parameter(model)

    def test_model_without_adapter(self):
        """A model string without a colon yields no adapter."""
        self.assertEqual(self._parse("llama-3.1-8B"), ("llama-3.1-8B", None))

    def test_model_with_adapter(self):
        """A single colon separates base model from adapter."""
        self.assertEqual(
            self._parse("llama-3.1-8B:sql-expert"), ("llama-3.1-8B", "sql-expert")
        )

    def test_model_with_path_and_adapter(self):
        """Slashes in the org/model prefix do not confuse the split."""
        self.assertEqual(
            self._parse("meta-llama/Llama-3.1-8B-Instruct:adapter-name"),
            ("meta-llama/Llama-3.1-8B-Instruct", "adapter-name"),
        )

    def test_model_with_multiple_colons(self):
        """Only the first colon splits; the rest stays in the adapter."""
        self.assertEqual(
            self._parse("model:adapter:extra"), ("model", "adapter:extra")
        )

    def test_model_with_whitespace(self):
        """Surrounding whitespace is stripped from both parts."""
        self.assertEqual(
            self._parse(" model-name : adapter-name "),
            ("model-name", "adapter-name"),
        )

    def test_model_with_empty_adapter(self):
        """A trailing colon with nothing after it means no adapter."""
        self.assertEqual(self._parse("model-name:"), ("model-name", None))

    def test_model_with_only_spaces_after_colon(self):
        """Whitespace-only adapter text also means no adapter."""
        self.assertEqual(self._parse("model-name:   "), ("model-name", None))
|
||||
|
||||
|
||||
class TestResolveLoraPath(unittest.TestCase):
    """Unit tests for OpenAIServingBase._resolve_lora_path."""

    def setUp(self):
        self.tokenizer_manager = MockTokenizerManager(enable_lora=True)
        self.serving = ConcreteServingBase(self.tokenizer_manager)

    def _resolve(self, model, explicit):
        return self.serving._resolve_lora_path(model, explicit)

    def test_no_adapter_specified(self):
        """Neither source names an adapter -> None."""
        self.assertIsNone(self._resolve("model-name", None))

    def test_adapter_in_model_only(self):
        """Adapter embedded in the model parameter is used."""
        self.assertEqual(self._resolve("model:sql-expert", None), "sql-expert")

    def test_adapter_in_explicit_only(self):
        """Explicit lora_path is used when the model has no adapter."""
        self.assertEqual(
            self._resolve("model-name", "python-expert"), "python-expert"
        )

    def test_model_parameter_takes_precedence(self):
        """model:adapter wins over an explicit lora_path."""
        self.assertEqual(
            self._resolve("model:sql-expert", "python-expert"), "sql-expert"
        )

    def test_with_list_explicit_lora_path(self):
        """An explicit list is returned as-is when the model has no adapter."""
        adapters = ["adapter1", "adapter2", None]
        self.assertEqual(self._resolve("model-name", adapters), adapters)

    def test_model_adapter_overrides_list(self):
        """model:adapter wins even over an explicit list of adapters."""
        self.assertEqual(
            self._resolve("model:sql-expert", ["adapter1", "adapter2"]),
            "sql-expert",
        )

    def test_complex_model_name_with_adapter(self):
        """Org-prefixed, versioned model names still resolve correctly."""
        self.assertEqual(
            self._resolve("org/model-v2.1:adapter-name", "other-adapter"),
            "adapter-name",
        )
|
||||
|
||||
|
||||
class TestValidateLoraEnabled(unittest.TestCase):
    """Unit tests for OpenAIServingBase._validate_lora_enabled."""

    @staticmethod
    def _make_serving(enable_lora):
        # Each test builds its own serving object with the desired flag.
        return ConcreteServingBase(MockTokenizerManager(enable_lora=enable_lora))

    def test_validation_passes_when_lora_enabled(self):
        """No exception is raised when the server has LoRA enabled."""
        serving = self._make_serving(True)
        try:
            serving._validate_lora_enabled("sql-expert")
        except ValueError:
            self.fail("_validate_lora_enabled raised ValueError unexpectedly")

    def test_validation_fails_when_lora_disabled(self):
        """The error names the adapter and points at the missing flag."""
        serving = self._make_serving(False)
        with self.assertRaises(ValueError) as ctx:
            serving._validate_lora_enabled("sql-expert")
        message = str(ctx.exception)
        self.assertIn("sql-expert", message)
        self.assertIn("--enable-lora", message)
        self.assertIn("not enabled", message)

    def test_validation_error_mentions_adapter_name(self):
        """The requested adapter name appears in the error text."""
        serving = self._make_serving(False)
        with self.assertRaises(ValueError) as ctx:
            serving._validate_lora_enabled("my-custom-adapter")
        self.assertIn("my-custom-adapter", str(ctx.exception))
|
||||
|
||||
|
||||
class TestIntegrationScenarios(unittest.TestCase):
    """End-to-end resolution + validation flows for common usage patterns."""

    def setUp(self):
        self.tokenizer_manager = MockTokenizerManager(enable_lora=True)
        self.serving = ConcreteServingBase(self.tokenizer_manager)

    def test_openai_compatible_usage(self):
        """Adapter embedded in the model parameter resolves and validates."""
        resolved = self.serving._resolve_lora_path(
            "meta-llama/Llama-3.1-8B:sql-expert", None
        )
        self.assertEqual(resolved, "sql-expert")
        self.serving._validate_lora_enabled(resolved)  # must not raise

    def test_backward_compatible_usage(self):
        """Explicit lora_path still works when the model has no adapter."""
        resolved = self.serving._resolve_lora_path(
            "meta-llama/Llama-3.1-8B", "sql-expert"
        )
        self.assertEqual(resolved, "sql-expert")
        self.serving._validate_lora_enabled(resolved)  # must not raise

    def test_base_model_usage(self):
        """No adapter anywhere resolves to None; nothing to validate."""
        resolved = self.serving._resolve_lora_path("meta-llama/Llama-3.1-8B", None)
        self.assertIsNone(resolved)

    def test_batch_request_scenario(self):
        """A per-prompt adapter list passes through resolution untouched."""
        adapters = ["sql-expert", "python-expert", None]
        resolved = self.serving._resolve_lora_path(
            "meta-llama/Llama-3.1-8B", adapters
        )
        self.assertEqual(resolved, adapters)
        # Callers validate the first non-None adapter in the list.
        if isinstance(resolved, list) and resolved[0]:
            self.serving._validate_lora_enabled(resolved[0])

    def test_adapter_in_model_overrides_batch_list(self):
        """model:adapter wins even over an explicit list of adapters."""
        resolved = self.serving._resolve_lora_path(
            "meta-llama/Llama-3.1-8B:preferred-adapter", ["adapter1", "adapter2"]
        )
        self.assertEqual(resolved, "preferred-adapter")

    def test_error_when_lora_not_enabled(self):
        """Requesting an adapter on a non-LoRA server raises a helpful error."""
        serving = ConcreteServingBase(MockTokenizerManager(enable_lora=False))
        resolved = serving._resolve_lora_path(
            "meta-llama/Llama-3.1-8B:sql-expert", None
        )
        with self.assertRaises(ValueError) as ctx:
            serving._validate_lora_enabled(resolved)
        message = str(ctx.exception)
        self.assertIn("--enable-lora", message)
        self.assertIn("sql-expert", message)
|
||||
|
||||
|
||||
class TestEdgeCases(unittest.TestCase):
    """Edge cases and error conditions for parsing, resolution, validation."""

    def setUp(self):
        self.tokenizer_manager = MockTokenizerManager(enable_lora=True)
        self.serving = ConcreteServingBase(self.tokenizer_manager)

    def test_empty_string_model(self):
        """An empty model string parses to ('', None)."""
        self.assertEqual(self.serving._parse_model_parameter(""), ("", None))

    def test_only_colon(self):
        """A lone colon yields an empty base model and no adapter."""
        self.assertEqual(self.serving._parse_model_parameter(":"), ("", None))

    def test_empty_list_lora_path(self):
        """An empty adapter list is returned unchanged (falsy, unvalidated)."""
        self.assertEqual(self.serving._resolve_lora_path("model-name", []), [])

    def test_list_with_none_first(self):
        """A list whose first entry is None survives resolution intact."""
        resolved = self.serving._resolve_lora_path("model-name", [None, "adapter2"])
        self.assertEqual(resolved, [None, "adapter2"])
        # In actual usage, validation would pick "adapter2".

    def test_list_all_none(self):
        """A list of only None entries survives resolution intact."""
        resolved = self.serving._resolve_lora_path("model-name", [None, None])
        self.assertEqual(resolved, [None, None])
        # In actual usage, no validation would occur (no non-None adapters).

    def test_unicode_in_adapter_name(self):
        """Non-ASCII adapter names are preserved verbatim."""
        self.assertEqual(
            self.serving._parse_model_parameter("model:adapter-名前"),
            ("model", "adapter-名前"),
        )

    def test_special_characters_in_adapter(self):
        """Dots, dashes, and underscores in adapter names are preserved."""
        self.assertEqual(
            self.serving._parse_model_parameter("model:adapter_v2.1-final"),
            ("model", "adapter_v2.1-final"),
        )

    def test_none_as_explicit_lora_path(self):
        """None as explicit lora_path defers to the model parameter."""
        self.assertEqual(
            self.serving._resolve_lora_path("model:adapter", None), "adapter"
        )

    def test_empty_string_as_explicit_lora_path(self):
        """An empty-string lora_path is passed through unchanged."""
        self.assertEqual(self.serving._resolve_lora_path("model-name", ""), "")

    def test_validation_with_empty_adapter_name(self):
        """Validation still raises for an empty adapter name when LoRA is off."""
        serving = ConcreteServingBase(MockTokenizerManager(enable_lora=False))
        with self.assertRaises(ValueError):
            serving._validate_lora_enabled("")
|
||||
|
||||
|
||||
# Allow running this test module directly (e.g. `python test_lora_openai_api.py`).
if __name__ == "__main__":
    unittest.main()
|
||||
278
test/srt/openai_server/features/test_lora_openai_compatible.py
Normal file
278
test/srt/openai_server/features/test_lora_openai_compatible.py
Normal file
@@ -0,0 +1,278 @@
|
||||
"""
|
||||
End-to-end tests for OpenAI-compatible LoRA adapter usage.
|
||||
|
||||
Tests the model:adapter syntax and backward compatibility with explicit lora_path.
|
||||
|
||||
Usage:
|
||||
python3 -m unittest openai_server.features.test_lora_openai_compatible.TestLoRAOpenAICompatible.test_model_adapter_syntax
|
||||
python3 -m unittest openai_server.features.test_lora_openai_compatible.TestLoRAOpenAICompatible.test_explicit_lora_path
|
||||
python3 -m unittest openai_server.features.test_lora_openai_compatible.TestLoRAOpenAICompatible.test_priority_model_over_explicit
|
||||
python3 -m unittest openai_server.features.test_lora_openai_compatible.TestLoRAOpenAICompatible.test_base_model_no_adapter
|
||||
python3 -m unittest openai_server.features.test_lora_openai_compatible.TestLoRAOpenAICompatible.test_completions_api_with_adapter
|
||||
python3 -m unittest openai_server.features.test_lora_openai_compatible.TestLoRAOpenAICompatible.test_streaming_with_adapter
|
||||
python3 -m unittest openai_server.features.test_lora_openai_compatible.TestLoRADisabledError.test_lora_disabled_error
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
import openai
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
def get_real_lora_adapter() -> str:
    """Return the Hugging Face hub id of a real LoRA adapter for e2e tests."""
    adapter_repo = "codelion/Llama-3.2-1B-Instruct-tool-calling-lora"
    return adapter_repo
|
||||
|
||||
|
||||
def setup_class(cls, enable_lora=True):
    """Setup test class with LoRA-enabled server.

    Attaches ``model``, ``base_url``, ``lora_adapter_path``, ``process``
    and ``client`` attributes to *cls*. When ``enable_lora`` is True the
    server is launched with ``--enable-lora`` and the adapter registered
    under the name ``tool_calling``.
    """
    cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
    cls.base_url = DEFAULT_URL_FOR_TEST

    # Use a real LoRA adapter pulled from the Hugging Face hub.
    cls.lora_adapter_path = get_real_lora_adapter()

    other_args = [
        "--max-running-requests",
        "10",
        "--disable-radix-cache",  # Disable cache for cleaner tests
    ]

    if enable_lora:
        other_args.extend(
            [
                "--enable-lora",
                "--lora-paths",
                # name=path form registers the adapter as "tool_calling"
                f"tool_calling={cls.lora_adapter_path}",
            ]
        )

    # NOTE(review): assumes popen_launch_server blocks until the server is
    # ready or the timeout elapses — confirm against test_utils.
    cls.process = popen_launch_server(
        cls.model,
        cls.base_url,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        other_args=other_args,
    )
    cls.client = openai.Client(api_key="EMPTY", base_url=f"{cls.base_url}/v1")
|
||||
|
||||
|
||||
class TestLoRAOpenAICompatible(CustomTestCase):
|
||||
"""Test OpenAI-compatible LoRA adapter usage."""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
setup_class(cls, enable_lora=True)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_model_adapter_syntax(self):
|
||||
"""Test the new model:adapter syntax works correctly."""
|
||||
response = self.client.chat.completions.create(
|
||||
# ← New OpenAI-compatible syntax
|
||||
model=f"{self.model}:tool_calling",
|
||||
messages=[{"role": "user", "content": "What tools do you have available?"}],
|
||||
max_tokens=50,
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
self.assertIsNotNone(response.choices[0].message.content)
|
||||
self.assertGreater(len(response.choices[0].message.content), 0)
|
||||
print(f"Model adapter syntax response: {response.choices[0].message.content}")
|
||||
|
||||
def test_explicit_lora_path(self):
|
||||
"""Test backward compatibility with explicit lora_path via extra_body."""
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[{"role": "user", "content": "What tools do you have available?"}],
|
||||
# ← Legacy explicit method
|
||||
extra_body={"lora_path": "tool_calling"},
|
||||
max_tokens=50,
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
self.assertIsNotNone(response.choices[0].message.content)
|
||||
self.assertGreater(len(response.choices[0].message.content), 0)
|
||||
print(f"Explicit lora_path response: {response.choices[0].message.content}")
|
||||
|
||||
def test_priority_model_over_explicit(self):
|
||||
"""Test that model:adapter syntax takes precedence over explicit lora_path."""
|
||||
# This test verifies the priority logic in _resolve_lora_path
|
||||
response = self.client.chat.completions.create(
|
||||
# ← Model specifies tool_calling adapter
|
||||
model=f"{self.model}:tool_calling",
|
||||
messages=[{"role": "user", "content": "What tools do you have available?"}],
|
||||
# ← Both specify same adapter
|
||||
extra_body={"lora_path": "tool_calling"},
|
||||
max_tokens=50,
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
# Should use tool_calling adapter (model parameter takes precedence)
|
||||
self.assertIsNotNone(response.choices[0].message.content)
|
||||
self.assertGreater(len(response.choices[0].message.content), 0)
|
||||
print(f"Priority test response: {response.choices[0].message.content}")
|
||||
|
||||
def test_base_model_no_adapter(self):
    """Plain base-model request: no adapter in the model field, no lora_path."""
    completion = self.client.chat.completions.create(
        model=self.model,
        messages=[{"role": "user", "content": "Hello, how are you?"}],
        max_tokens=30,
        temperature=0,
    )

    reply = completion.choices[0].message.content
    self.assertIsNotNone(reply)
    self.assertGreater(len(reply), 0)
    print(f"Base model response: {reply}")
|
||||
|
||||
def test_completions_api_with_adapter(self):
    """The (non-chat) completions endpoint must also honor model:adapter syntax."""
    completion = self.client.completions.create(
        model=f"{self.model}:tool_calling",
        prompt="What tools do you have available?",
        max_tokens=50,
        temperature=0,
    )

    text = completion.choices[0].text
    self.assertIsNotNone(text)
    self.assertGreater(len(text), 0)
    print(f"Completions API response: {text}")
|
||||
|
||||
def test_streaming_with_adapter(self):
    """Test streaming with LoRA adapter.

    Streams a chat completion selected via the ``model:adapter`` syntax and
    asserts that a non-empty reply is assembled from the deltas.
    """
    stream = self.client.chat.completions.create(
        model=f"{self.model}:tool_calling",
        messages=[{"role": "user", "content": "What tools do you have available?"}],
        max_tokens=50,
        temperature=0,
        stream=True,
    )

    collected_content = ""
    for chunk in stream:
        # Guard against chunks with an empty `choices` list (e.g. a trailing
        # usage-only chunk) before indexing into it; the original
        # `chunk.choices[0]` would raise IndexError on such chunks.
        if chunk.choices and chunk.choices[0].delta.content:
            collected_content += chunk.choices[0].delta.content

    self.assertGreater(len(collected_content), 0)
    print(f"Streaming response: {collected_content}")
|
||||
|
||||
def test_multiple_adapters(self):
    """Issue sequential requests with and without an adapter on one server."""
    # First: route through the tool_calling adapter.
    tool_response = self.client.chat.completions.create(
        model=f"{self.model}:tool_calling",
        messages=[{"role": "user", "content": "What tools do you have available?"}],
        max_tokens=30,
        temperature=0,
    )

    # Then: the bare base model, no adapter at all.
    base_response = self.client.chat.completions.create(
        model=self.model,
        messages=[{"role": "user", "content": "Hello, how are you?"}],
        max_tokens=30,
        temperature=0,
    )

    self.assertIsNotNone(tool_response.choices[0].message.content)
    self.assertIsNotNone(base_response.choices[0].message.content)
    print(
        f"Tool calling adapter response: {tool_response.choices[0].message.content}"
    )
    print(f"Base model response: {base_response.choices[0].message.content}")
|
||||
|
||||
|
||||
class TestLoRADisabledError(CustomTestCase):
    """Error-path coverage: adapter syntax against a server started without LoRA."""

    @classmethod
    def setUpClass(cls):
        # Deliberately launch the server with LoRA support switched off.
        setup_class(cls, enable_lora=False)

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_lora_disabled_error(self):
        """Requesting an adapter while LoRA is disabled must raise a clear APIError."""
        with self.assertRaises(openai.APIError) as context:
            self.client.chat.completions.create(
                model=f"{self.model}:tool_calling",
                messages=[
                    {"role": "user", "content": "What tools do you have available?"}
                ],
                max_tokens=50,
            )

        # The error text should guide the user toward the missing configuration.
        message = str(context.exception)
        self.assertIn("LoRA", message)
        self.assertIn("not enabled", message)
        print(f"Expected error message: {message}")
|
||||
|
||||
|
||||
class TestLoRAEdgeCases(CustomTestCase):
    """Edge cases for LoRA adapter selection in the OpenAI-compatible API."""

    @classmethod
    def setUpClass(cls):
        setup_class(cls, enable_lora=True)

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_model_with_colon_no_adapter(self):
        """A trailing colon with an empty adapter name falls back to the base model."""
        completion = self.client.chat.completions.create(
            model=f"{self.model}:",  # colon present, adapter name empty
            messages=[{"role": "user", "content": "Hello!"}],
            max_tokens=30,
            temperature=0,
        )

        reply = completion.choices[0].message.content
        self.assertIsNotNone(reply)
        print(f"Model with colon response: {reply}")

    def test_explicit_lora_path_none(self):
        """An explicit ``lora_path: None`` behaves like omitting it entirely."""
        completion = self.client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": "Hello!"}],
            extra_body={"lora_path": None},
            max_tokens=30,
            temperature=0,
        )

        reply = completion.choices[0].message.content
        self.assertIsNotNone(reply)
        print(
            f"Explicit None lora_path response: {reply}"
        )

    def test_invalid_adapter_name(self):
        """Naming an adapter the server never loaded must raise an APIError."""
        with self.assertRaises(openai.APIError) as context:
            self.client.chat.completions.create(
                model=f"{self.model}:nonexistent",
                messages=[{"role": "user", "content": "Hello!"}],
                max_tokens=30,
            )

        message = str(context.exception)
        print(f"Invalid adapter error: {message}")
|
||||
|
||||
|
||||
# Allow running this test module directly (outside the suite runner).
if __name__ == "__main__":
    unittest.main()
|
||||
@@ -26,6 +26,7 @@ suites = {
|
||||
TestFile("lora/test_lora_eviction.py", 200),
|
||||
TestFile("lora/test_lora_qwen3.py", 97),
|
||||
TestFile("lora/test_lora_radix_cache.py", 100),
|
||||
TestFile("lora/test_lora_openai_api.py", 30),
|
||||
TestFile("lora/test_lora_update.py", 400),
|
||||
TestFile("lora/test_multi_lora_backend.py", 60),
|
||||
TestFile("models/test_compressed_tensors_models.py", 42),
|
||||
@@ -51,6 +52,7 @@ suites = {
|
||||
TestFile("openai_server/features/test_openai_server_ebnf.py", 95),
|
||||
TestFile("openai_server/features/test_openai_server_hidden_states.py", 240),
|
||||
TestFile("openai_server/features/test_reasoning_content.py", 89),
|
||||
TestFile("openai_server/features/test_lora_openai_compatible.py", 120),
|
||||
TestFile("openai_server/function_call/test_openai_function_calling.py", 60),
|
||||
TestFile("openai_server/function_call/test_tool_choice.py", 226),
|
||||
TestFile("openai_server/validation/test_large_max_new_tokens.py", 41),
|
||||
|
||||
Reference in New Issue
Block a user