# File: sglang/test/srt/lora/test_lora_openai_api.py

"""
Unit tests for OpenAI-compatible LoRA API support.
Tests the model parameter parsing and LoRA adapter resolution logic
that enables OpenAI-compatible LoRA adapter selection.
"""
import unittest
from unittest.mock import MagicMock, Mock
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
from sglang.srt.server_args import ServerArgs
class MockTokenizerManager:
    """Stand-in for ``TokenizerManager`` that exposes only ``server_args``.

    Args:
        enable_lora: Value assigned to ``server_args.enable_lora``.
    """

    def __init__(self, enable_lora=False):
        # spec=ServerArgs makes reading an attribute that ServerArgs does not
        # define raise AttributeError instead of returning a child mock.
        args = MagicMock(spec=ServerArgs)
        args.enable_lora = enable_lora
        # Set explicitly to None; presumably read when OpenAIServingBase is
        # constructed — TODO confirm against serving_base.
        args.tokenizer_metrics_allowed_custom_labels = None
        self.server_args = args
class ConcreteServingBase(OpenAIServingBase):
    """Minimal concrete subclass so ``OpenAIServingBase`` can be instantiated.

    The hook methods below are stubs. The tests in this module only call
    ``_parse_model_parameter``, ``_resolve_lora_path`` and
    ``_validate_lora_enabled``, which are inherited from the base class.
    """

    def _request_id_prefix(self) -> str:
        prefix = "test-"
        return prefix

    def _convert_to_internal_request(self, request, raw_request=None):
        # Stub: request conversion is not exercised by these tests.
        return None

    def _validate_request(self, request):
        # Stub: request validation is not exercised by these tests.
        return None
class TestParseModelParameter(unittest.TestCase):
    """Tests for ``_parse_model_parameter`` splitting ``base[:adapter]``."""

    def setUp(self):
        self.tokenizer_manager = MockTokenizerManager(enable_lora=True)
        self.serving = ConcreteServingBase(self.tokenizer_manager)

    def _check(self, raw_model, expected_base, expected_adapter):
        """Parse ``raw_model`` and assert the base name, then the adapter.

        An ``expected_adapter`` of ``None`` is checked with ``assertIsNone``.
        """
        parsed_base, parsed_adapter = self.serving._parse_model_parameter(raw_model)
        self.assertEqual(parsed_base, expected_base)
        if expected_adapter is None:
            self.assertIsNone(parsed_adapter)
        else:
            self.assertEqual(parsed_adapter, expected_adapter)

    def test_model_without_adapter(self):
        """A model name with no colon yields a None adapter."""
        self._check("llama-3.1-8B", "llama-3.1-8B", None)

    def test_model_with_adapter(self):
        """A ``model:adapter`` string yields both parts."""
        self._check("llama-3.1-8B:sql-expert", "llama-3.1-8B", "sql-expert")

    def test_model_with_path_and_adapter(self):
        """Slashes in the model path do not interfere with the split."""
        self._check(
            "meta-llama/Llama-3.1-8B-Instruct:adapter-name",
            "meta-llama/Llama-3.1-8B-Instruct",
            "adapter-name",
        )

    def test_model_with_multiple_colons(self):
        """Only the first colon separates base model from adapter."""
        self._check("model:adapter:extra", "model", "adapter:extra")

    def test_model_with_whitespace(self):
        """Surrounding whitespace is stripped from both parts."""
        self._check(" model-name : adapter-name ", "model-name", "adapter-name")

    def test_model_with_empty_adapter(self):
        """A trailing colon with nothing after it yields a None adapter."""
        self._check("model-name:", "model-name", None)

    def test_model_with_only_spaces_after_colon(self):
        """Whitespace-only text after the colon yields a None adapter."""
        self._check("model-name: ", "model-name", None)
class TestResolveLoraPath(unittest.TestCase):
    """Tests for ``_resolve_lora_path``: model-suffix adapter vs. explicit lora_path."""

    def setUp(self):
        self.tokenizer_manager = MockTokenizerManager(enable_lora=True)
        self.serving = ConcreteServingBase(self.tokenizer_manager)

    def _resolve(self, model, explicit_lora_path):
        """Shorthand for calling the method under test."""
        return self.serving._resolve_lora_path(model, explicit_lora_path)

    def test_no_adapter_specified(self):
        """No adapter anywhere resolves to None."""
        self.assertIsNone(self._resolve("model-name", None))

    def test_adapter_in_model_only(self):
        """An adapter suffix on the model is used when lora_path is None."""
        self.assertEqual(self._resolve("model:sql-expert", None), "sql-expert")

    def test_adapter_in_explicit_only(self):
        """An explicit lora_path is used when the model has no suffix."""
        self.assertEqual(
            self._resolve("model-name", "python-expert"), "python-expert"
        )

    def test_model_parameter_takes_precedence(self):
        """The model-suffix adapter wins over an explicit lora_path."""
        self.assertEqual(
            self._resolve("model:sql-expert", "python-expert"), "sql-expert"
        )

    def test_with_list_explicit_lora_path(self):
        """An explicit list is passed through when the model has no suffix."""
        batch_paths = ["adapter1", "adapter2", None]
        self.assertEqual(self._resolve("model-name", batch_paths), batch_paths)

    def test_model_adapter_overrides_list(self):
        """The model-suffix adapter replaces an explicit list."""
        self.assertEqual(
            self._resolve("model:sql-expert", ["adapter1", "adapter2"]),
            "sql-expert",
        )

    def test_complex_model_name_with_adapter(self):
        """Org/version-style model names still resolve the suffix adapter."""
        self.assertEqual(
            self._resolve("org/model-v2.1:adapter-name", "other-adapter"),
            "adapter-name",
        )
class TestValidateLoraEnabled(unittest.TestCase):
    """Tests for ``_validate_lora_enabled``."""

    @staticmethod
    def _make_serving(enable_lora):
        """Build a serving instance whose mocked server has LoRA on or off."""
        return ConcreteServingBase(MockTokenizerManager(enable_lora=enable_lora))

    def test_validation_passes_when_lora_enabled(self):
        """No ValueError is raised when the server has LoRA enabled."""
        serving = self._make_serving(enable_lora=True)
        try:
            serving._validate_lora_enabled("sql-expert")
        except ValueError:
            # Surface as a test failure rather than an unexpected error.
            self.fail("_validate_lora_enabled raised ValueError unexpectedly")

    def test_validation_fails_when_lora_disabled(self):
        """The ValueError names the adapter and points at --enable-lora."""
        serving = self._make_serving(enable_lora=False)
        with self.assertRaises(ValueError) as raised:
            serving._validate_lora_enabled("sql-expert")
        message = str(raised.exception)
        for fragment in ("sql-expert", "--enable-lora", "not enabled"):
            self.assertIn(fragment, message)

    def test_validation_error_mentions_adapter_name(self):
        """The ValueError message echoes the requested adapter name."""
        serving = self._make_serving(enable_lora=False)
        with self.assertRaises(ValueError) as raised:
            serving._validate_lora_enabled("my-custom-adapter")
        self.assertIn("my-custom-adapter", str(raised.exception))
class TestIntegrationScenarios(unittest.TestCase):
    """End-to-end flows combining ``_resolve_lora_path`` and ``_validate_lora_enabled``."""

    def setUp(self):
        self.tokenizer_manager = MockTokenizerManager(enable_lora=True)
        self.serving = ConcreteServingBase(self.tokenizer_manager)

    def test_openai_compatible_usage(self):
        """Adapter given as a ``model:adapter`` suffix (OpenAI-compatible style)."""
        model = "meta-llama/Llama-3.1-8B:sql-expert"
        explicit_lora = None
        lora_path = self.serving._resolve_lora_path(model, explicit_lora)
        self.assertEqual(lora_path, "sql-expert")
        # LoRA is enabled in setUp, so validation must not raise.
        self.serving._validate_lora_enabled(lora_path)

    def test_backward_compatible_usage(self):
        """Adapter given through the explicit ``lora_path`` field."""
        model = "meta-llama/Llama-3.1-8B"
        explicit_lora = "sql-expert"
        lora_path = self.serving._resolve_lora_path(model, explicit_lora)
        self.assertEqual(lora_path, "sql-expert")
        # LoRA is enabled in setUp, so validation must not raise.
        self.serving._validate_lora_enabled(lora_path)

    def test_base_model_usage(self):
        """No adapter anywhere resolves to None (plain base-model request)."""
        model = "meta-llama/Llama-3.1-8B"
        explicit_lora = None
        lora_path = self.serving._resolve_lora_path(model, explicit_lora)
        self.assertIsNone(lora_path)

    def test_batch_request_scenario(self):
        """A per-item adapter list is passed through when the model has no suffix."""
        model = "meta-llama/Llama-3.1-8B"  # No adapter in model
        explicit_lora = ["sql-expert", "python-expert", None]
        lora_path = self.serving._resolve_lora_path(model, explicit_lora)
        self.assertEqual(lora_path, explicit_lora)
        # The assertion above guarantees lora_path[0] == "sql-expert", so
        # validate unconditionally; a guard here could silently skip the check.
        self.serving._validate_lora_enabled(lora_path[0])

    def test_adapter_in_model_overrides_batch_list(self):
        """A model-suffix adapter replaces an explicit batch list."""
        model = "meta-llama/Llama-3.1-8B:preferred-adapter"
        explicit_lora = ["adapter1", "adapter2"]
        lora_path = self.serving._resolve_lora_path(model, explicit_lora)
        self.assertEqual(lora_path, "preferred-adapter")

    def test_error_when_lora_not_enabled(self):
        """Requesting an adapter on a server without LoRA gives a helpful error."""
        tokenizer_manager = MockTokenizerManager(enable_lora=False)
        serving = ConcreteServingBase(tokenizer_manager)
        model = "meta-llama/Llama-3.1-8B:sql-expert"
        lora_path = serving._resolve_lora_path(model, None)
        # Pin the resolution step so a resolver regression is reported here
        # instead of as a confusing message mismatch during validation.
        self.assertEqual(lora_path, "sql-expert")
        with self.assertRaises(ValueError) as context:
            serving._validate_lora_enabled(lora_path)
        error = str(context.exception)
        self.assertIn("--enable-lora", error)
        self.assertIn("sql-expert", error)
class TestEdgeCases(unittest.TestCase):
    """Edge cases for parsing, resolution and validation."""

    def setUp(self):
        self.tokenizer_manager = MockTokenizerManager(enable_lora=True)
        self.serving = ConcreteServingBase(self.tokenizer_manager)

    def test_empty_string_model(self):
        """An empty model string parses to an empty base and no adapter."""
        base, adapter = self.serving._parse_model_parameter("")
        self.assertEqual(base, "")
        self.assertIsNone(adapter)

    def test_only_colon(self):
        """A bare colon parses to an empty base and no adapter."""
        base, adapter = self.serving._parse_model_parameter(":")
        self.assertEqual(base, "")
        self.assertIsNone(adapter)

    def test_empty_list_lora_path(self):
        """An empty explicit list is returned unchanged without raising.

        Only resolution is exercised here; no adapter name is validated.
        """
        lora_path = self.serving._resolve_lora_path("model-name", [])
        self.assertEqual(lora_path, [])

    def test_list_with_none_first(self):
        """An explicit list with a leading None is returned unchanged.

        Which entry (if any) a caller later validates is decided outside
        ``_resolve_lora_path`` and is not covered by this test.
        """
        lora_path = self.serving._resolve_lora_path("model-name", [None, "adapter2"])
        self.assertEqual(lora_path, [None, "adapter2"])

    def test_list_all_none(self):
        """An explicit list containing only None is returned unchanged.

        Only resolution is exercised here; no adapter name is validated.
        """
        lora_path = self.serving._resolve_lora_path("model-name", [None, None])
        self.assertEqual(lora_path, [None, None])

    def test_unicode_in_adapter_name(self):
        """Non-ASCII characters in the adapter name are preserved."""
        base, adapter = self.serving._parse_model_parameter("model:adapter-名前")
        self.assertEqual(base, "model")
        self.assertEqual(adapter, "adapter-名前")

    def test_special_characters_in_adapter(self):
        """Underscores, dots and dashes in the adapter name are preserved."""
        base, adapter = self.serving._parse_model_parameter("model:adapter_v2.1-final")
        self.assertEqual(base, "model")
        self.assertEqual(adapter, "adapter_v2.1-final")

    def test_none_as_explicit_lora_path(self):
        """A None explicit lora_path falls back to the model-suffix adapter."""
        result = self.serving._resolve_lora_path("model:adapter", None)
        self.assertEqual(result, "adapter")

    def test_empty_string_as_explicit_lora_path(self):
        """An empty-string explicit lora_path is returned as-is."""
        result = self.serving._resolve_lora_path("model-name", "")
        self.assertEqual(result, "")

    def test_validation_with_empty_adapter_name(self):
        """With LoRA disabled, even an empty adapter name raises ValueError."""
        tokenizer_manager = MockTokenizerManager(enable_lora=False)
        serving = ConcreteServingBase(tokenizer_manager)
        with self.assertRaises(ValueError):
            serving._validate_lora_enabled("")
if __name__ == "__main__":
    # Run every TestCase found in this module when executed as a script;
    # module="__main__" is unittest.main's documented default.
    unittest.main(module="__main__")