328 lines
13 KiB
Python
328 lines
13 KiB
Python
"""
|
|
Unit tests for OpenAI-compatible LoRA API support.
|
|
|
|
Tests the model parameter parsing and LoRA adapter resolution logic
|
|
that enables OpenAI-compatible LoRA adapter selection.
|
|
"""
|
|
|
|
import unittest
|
|
from unittest.mock import MagicMock, Mock
|
|
|
|
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
|
|
from sglang.srt.server_args import ServerArgs
|
|
|
|
|
|
class MockTokenizerManager:
    """Lightweight TokenizerManager stand-in used by the tests in this module.

    Only ``server_args`` is provided: a ``MagicMock`` constrained to the
    ``ServerArgs`` spec, with the two fields these tests configure set
    explicitly.
    """

    def __init__(self, enable_lora=False):
        # Build the spec'd mock first, then publish it on the instance.
        args = MagicMock(spec=ServerArgs)
        args.enable_lora = enable_lora
        args.tokenizer_metrics_allowed_custom_labels = None
        self.server_args = args
|
|
|
|
|
|
class ConcreteServingBase(OpenAIServingBase):
    """Instantiable subclass of ``OpenAIServingBase`` for use in tests.

    The overrides below are deliberately trivial; the tests in this module
    only exercise the LoRA helper methods inherited from the base class.
    """

    def _request_id_prefix(self) -> str:
        """Return the fixed request-id prefix used by this test subclass."""
        return "test-"

    def _convert_to_internal_request(self, request, raw_request=None):
        """No-op override; always returns ``None``."""
        return None

    def _validate_request(self, request):
        """No-op override; always returns ``None``."""
        return None
|
|
|
|
|
|
class TestParseModelParameter(unittest.TestCase):
    """Checks on ``_parse_model_parameter``.

    Each test feeds one ``model`` string to the method and asserts the
    expected ``(base_model, adapter)`` pair.
    """

    def setUp(self):
        self.tokenizer_manager = MockTokenizerManager(enable_lora=True)
        self.serving = ConcreteServingBase(self.tokenizer_manager)

    def _parse(self, model):
        """Shorthand for calling the method under test."""
        return self.serving._parse_model_parameter(model)

    def test_model_without_adapter(self):
        """A model string without a colon yields ``None`` as the adapter."""
        base, lora = self._parse("llama-3.1-8B")
        self.assertEqual(base, "llama-3.1-8B")
        self.assertIsNone(lora)

    def test_model_with_adapter(self):
        """``base:adapter`` is expected to split into its two parts."""
        base, lora = self._parse("llama-3.1-8B:sql-expert")
        self.assertEqual(base, "llama-3.1-8B")
        self.assertEqual(lora, "sql-expert")

    def test_model_with_path_and_adapter(self):
        """Slashes in the base model name must survive the split."""
        base, lora = self._parse("meta-llama/Llama-3.1-8B-Instruct:adapter-name")
        self.assertEqual(base, "meta-llama/Llama-3.1-8B-Instruct")
        self.assertEqual(lora, "adapter-name")

    def test_model_with_multiple_colons(self):
        """Only the first colon separates base model from adapter."""
        base, lora = self._parse("model:adapter:extra")
        self.assertEqual(base, "model")
        self.assertEqual(lora, "adapter:extra")

    def test_model_with_whitespace(self):
        """Surrounding whitespace is expected to be stripped from both parts."""
        base, lora = self._parse(" model-name : adapter-name ")
        self.assertEqual(base, "model-name")
        self.assertEqual(lora, "adapter-name")

    def test_model_with_empty_adapter(self):
        """A trailing colon with nothing after it yields ``None`` as adapter."""
        base, lora = self._parse("model-name:")
        self.assertEqual(base, "model-name")
        self.assertIsNone(lora)

    def test_model_with_only_spaces_after_colon(self):
        """Whitespace-only text after the colon yields ``None`` as adapter."""
        base, lora = self._parse("model-name: ")
        self.assertEqual(base, "model-name")
        self.assertIsNone(lora)
|
|
|
|
|
|
class TestResolveLoraPath(unittest.TestCase):
    """Checks on ``_resolve_lora_path``.

    The method receives the ``model`` string and an explicit ``lora_path``
    (``None``, a string, or a list); these tests pin which of the two is
    expected to win.
    """

    def setUp(self):
        self.tokenizer_manager = MockTokenizerManager(enable_lora=True)
        self.serving = ConcreteServingBase(self.tokenizer_manager)

    def _resolve(self, model, explicit_lora_path):
        """Shorthand for calling the method under test."""
        return self.serving._resolve_lora_path(model, explicit_lora_path)

    def test_no_adapter_specified(self):
        """No adapter in the model and no explicit path resolves to ``None``."""
        resolved = self._resolve("model-name", None)
        self.assertIsNone(resolved)

    def test_adapter_in_model_only(self):
        """The adapter embedded in the model string is used when no path is given."""
        resolved = self._resolve("model:sql-expert", None)
        self.assertEqual(resolved, "sql-expert")

    def test_adapter_in_explicit_only(self):
        """The explicit path is used when the model string has no adapter."""
        resolved = self._resolve("model-name", "python-expert")
        self.assertEqual(resolved, "python-expert")

    def test_model_parameter_takes_precedence(self):
        """The adapter in the model string wins over an explicit string path."""
        resolved = self._resolve("model:sql-expert", "python-expert")
        self.assertEqual(resolved, "sql-expert")

    def test_with_list_explicit_lora_path(self):
        """An explicit list is returned as-is when the model has no adapter."""
        adapters = ["adapter1", "adapter2", None]
        resolved = self._resolve("model-name", adapters)
        self.assertEqual(resolved, adapters)

    def test_model_adapter_overrides_list(self):
        """The adapter in the model string wins even over an explicit list."""
        resolved = self._resolve("model:sql-expert", ["adapter1", "adapter2"])
        self.assertEqual(resolved, "sql-expert")

    def test_complex_model_name_with_adapter(self):
        """Resolution with an org-prefixed, versioned model name."""
        resolved = self._resolve("org/model-v2.1:adapter-name", "other-adapter")
        self.assertEqual(resolved, "adapter-name")
|
|
|
|
|
|
class TestValidateLoraEnabled(unittest.TestCase):
    """Checks on ``_validate_lora_enabled`` with LoRA switched on and off."""

    @staticmethod
    def _make_serving(enable_lora):
        """Build a serving object whose mocked server args carry ``enable_lora``."""
        manager = MockTokenizerManager(enable_lora=enable_lora)
        return ConcreteServingBase(manager)

    def test_validation_passes_when_lora_enabled(self):
        """With LoRA enabled, validating an adapter name must not raise."""
        serving = self._make_serving(True)

        # A ValueError here is reported as a test failure, not an error.
        try:
            serving._validate_lora_enabled("sql-expert")
        except ValueError:
            self.fail("_validate_lora_enabled raised ValueError unexpectedly")

    def test_validation_fails_when_lora_disabled(self):
        """With LoRA disabled, a ValueError with a helpful message is expected."""
        serving = self._make_serving(False)

        with self.assertRaises(ValueError) as caught:
            serving._validate_lora_enabled("sql-expert")

        message = str(caught.exception)
        # The message should name the adapter, the flag to set, and the cause.
        for fragment in ("sql-expert", "--enable-lora", "not enabled"):
            self.assertIn(fragment, message)

    def test_validation_error_mentions_adapter_name(self):
        """The error message must include the adapter name that was requested."""
        serving = self._make_serving(False)

        with self.assertRaises(ValueError) as caught:
            serving._validate_lora_enabled("my-custom-adapter")

        self.assertIn("my-custom-adapter", str(caught.exception))
|
|
|
|
|
|
class TestIntegrationScenarios(unittest.TestCase):
    """End-to-end style scenarios combining resolution and validation."""

    def setUp(self):
        self.tokenizer_manager = MockTokenizerManager(enable_lora=True)
        self.serving = ConcreteServingBase(self.tokenizer_manager)

    def test_openai_compatible_usage(self):
        """Adapter named inside the ``model`` string, no explicit path."""
        model_name = "meta-llama/Llama-3.1-8B:sql-expert"
        requested_lora = None

        resolved = self.serving._resolve_lora_path(model_name, requested_lora)
        self.assertEqual(resolved, "sql-expert")

        # LoRA is enabled in setUp, so this call is expected not to raise.
        self.serving._validate_lora_enabled(resolved)

    def test_backward_compatible_usage(self):
        """Plain model name with the adapter passed as explicit ``lora_path``."""
        model_name = "meta-llama/Llama-3.1-8B"
        requested_lora = "sql-expert"

        resolved = self.serving._resolve_lora_path(model_name, requested_lora)
        self.assertEqual(resolved, "sql-expert")

        # LoRA is enabled in setUp, so this call is expected not to raise.
        self.serving._validate_lora_enabled(resolved)

    def test_base_model_usage(self):
        """Plain model name and no adapter anywhere resolves to ``None``."""
        model_name = "meta-llama/Llama-3.1-8B"
        requested_lora = None

        resolved = self.serving._resolve_lora_path(model_name, requested_lora)
        self.assertIsNone(resolved)

        # Nothing to validate when no adapter was resolved.

    def test_batch_request_scenario(self):
        """Batch request: explicit list of adapters, none in the model string."""
        model_name = "meta-llama/Llama-3.1-8B"  # No adapter in model
        requested_lora = ["sql-expert", "python-expert", None]

        resolved = self.serving._resolve_lora_path(model_name, requested_lora)
        self.assertEqual(resolved, requested_lora)

        # Validate the first list entry, if there is a truthy one.
        if isinstance(resolved, list) and resolved[0]:
            self.serving._validate_lora_enabled(resolved[0])

    def test_adapter_in_model_overrides_batch_list(self):
        """Adapter in the model string takes priority over an explicit list."""
        model_name = "meta-llama/Llama-3.1-8B:preferred-adapter"
        requested_lora = ["adapter1", "adapter2"]

        resolved = self.serving._resolve_lora_path(model_name, requested_lora)
        self.assertEqual(resolved, "preferred-adapter")

    def test_error_when_lora_not_enabled(self):
        """Full error flow: adapter requested while the server has LoRA off."""
        # Server configured without LoRA support.
        manager = MockTokenizerManager(enable_lora=False)
        serving = ConcreteServingBase(manager)

        # The request still names an adapter through the model string.
        model_name = "meta-llama/Llama-3.1-8B:sql-expert"
        resolved = serving._resolve_lora_path(model_name, None)

        # Validation must reject it with a message naming flag and adapter.
        with self.assertRaises(ValueError) as caught:
            serving._validate_lora_enabled(resolved)

        message = str(caught.exception)
        for fragment in ("--enable-lora", "sql-expert"):
            self.assertIn(fragment, message)
|
|
|
|
|
|
class TestEdgeCases(unittest.TestCase):
    """Boundary inputs for parsing, resolution and validation."""

    def setUp(self):
        self.tokenizer_manager = MockTokenizerManager(enable_lora=True)
        self.serving = ConcreteServingBase(self.tokenizer_manager)

    def test_empty_string_model(self):
        """An empty model string parses to an empty base and no adapter."""
        base_part, adapter_part = self.serving._parse_model_parameter("")
        self.assertEqual(base_part, "")
        self.assertIsNone(adapter_part)

    def test_only_colon(self):
        """A lone colon parses to an empty base and no adapter."""
        base_part, adapter_part = self.serving._parse_model_parameter(":")
        self.assertEqual(base_part, "")
        self.assertIsNone(adapter_part)

    def test_empty_list_lora_path(self):
        """An empty explicit list is returned unchanged by resolution."""
        resolved = self.serving._resolve_lora_path("model-name", [])
        # NOTE: an empty list is falsy; presumably callers skip validation
        # for it. That caller behaviour is not exercised here.
        self.assertEqual(resolved, [])

    def test_list_with_none_first(self):
        """An explicit list whose first entry is ``None`` is returned unchanged."""
        resolved = self.serving._resolve_lora_path("model-name", [None, "adapter2"])
        self.assertEqual(resolved, [None, "adapter2"])
        # NOTE: in actual usage validation is expected to pick up "adapter2";
        # only resolution is checked here.

    def test_list_all_none(self):
        """An explicit list containing only ``None`` is returned unchanged."""
        resolved = self.serving._resolve_lora_path("model-name", [None, None])
        self.assertEqual(resolved, [None, None])
        # NOTE: in actual usage no validation is expected to occur, since
        # the list holds no non-None adapter; only resolution is checked here.

    def test_unicode_in_adapter_name(self):
        """Non-ASCII characters in the adapter name are preserved."""
        base_part, adapter_part = self.serving._parse_model_parameter(
            "model:adapter-名前"
        )
        self.assertEqual(base_part, "model")
        self.assertEqual(adapter_part, "adapter-名前")

    def test_special_characters_in_adapter(self):
        """Underscores, dots and dashes in the adapter name are preserved."""
        base_part, adapter_part = self.serving._parse_model_parameter(
            "model:adapter_v2.1-final"
        )
        self.assertEqual(base_part, "model")
        self.assertEqual(adapter_part, "adapter_v2.1-final")

    def test_none_as_explicit_lora_path(self):
        """With ``None`` as explicit path, the model-string adapter is used."""
        resolved = self.serving._resolve_lora_path("model:adapter", None)
        self.assertEqual(resolved, "adapter")

    def test_empty_string_as_explicit_lora_path(self):
        """An empty-string explicit path is returned as the empty string."""
        resolved = self.serving._resolve_lora_path("model-name", "")
        self.assertEqual(resolved, "")

    def test_validation_with_empty_adapter_name(self):
        """With LoRA disabled, even an empty adapter name raises ValueError."""
        manager = MockTokenizerManager(enable_lora=False)
        serving = ConcreteServingBase(manager)

        with self.assertRaises(ValueError):
            serving._validate_lora_enabled("")
|
|
|
|
|
|
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|