380 lines
15 KiB
Python
380 lines
15 KiB
Python
"""
Unit tests for TokenizerManager helper methods.

These tests cover the refactored tokenization helpers: input format detection,
tokenizer input preparation, and result extraction logic.

Usage:
    python3 -m unittest test_tokenizer_manager.TestInputFormatDetection
    python3 -m unittest test_tokenizer_manager.TestTokenizerInputPreparation
    python3 -m unittest test_tokenizer_manager.TestTokenizerResultExtraction
    python3 -m unittest test_tokenizer_manager.TestTokenizerManagerIntegration
"""
|
|
|
|
import unittest
|
|
from typing import List, Optional, Union
|
|
from unittest.mock import Mock, patch
|
|
|
|
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
|
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
|
|
|
|
|
class TestInputFormatDetection(unittest.TestCase):
    """Test cases for the ``_detect_input_format`` method.

    Each test passes one shape of input together with an ``is_cross_encoder``
    flag and checks the returned format label, which is one of
    ``"single_string"``, ``"batch_strings"`` or ``"cross_encoder_pairs"``.
    """

    def setUp(self):
        """Build a fresh TokenizerManager with device, ZMQ and tokenizer loading patched."""
        # ServerArgs / PortArgs are constructed while get_device is forced to "cpu".
        with patch("sglang.srt.utils.get_device", return_value="cpu"):
            self.server_args = ServerArgs(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
            self.port_args = PortArgs.init_new(self.server_args)

        # The manager itself is constructed with the ZMQ context, the socket
        # factory and the tokenizer loader replaced by mocks.
        with patch("zmq.asyncio.Context"), patch(
            "sglang.srt.utils.get_zmq_socket"
        ), patch("sglang.srt.hf_transformers_utils.get_tokenizer") as mock_tokenizer:
            mock_tokenizer.return_value = Mock(vocab_size=32000)
            self.tokenizer_manager = TokenizerManager(self.server_args, self.port_args)

    def test_detect_single_string(self):
        """A single string with is_cross_encoder=False is detected as single_string."""
        text = "Hello world"
        result = self.tokenizer_manager._detect_input_format(
            text, is_cross_encoder=False
        )
        self.assertEqual(result, "single_string")

    def test_detect_single_string_cross_encoder_disabled(self):
        """A single string with is_cross_encoder=True still returns single_string.

        NOTE: despite the "disabled" suffix in the method name (kept so the
        test ID stays stable), this test passes ``is_cross_encoder=True``.
        """
        text = "Hello world"
        result = self.tokenizer_manager._detect_input_format(
            text, is_cross_encoder=True
        )
        self.assertEqual(result, "single_string")

    def test_detect_batch_strings(self):
        """A list of strings with is_cross_encoder=False is detected as batch_strings."""
        texts = ["Hello", "World", "How are you?"]
        result = self.tokenizer_manager._detect_input_format(
            texts, is_cross_encoder=False
        )
        self.assertEqual(result, "batch_strings")

    def test_detect_batch_strings_cross_encoder_disabled(self):
        """A list of plain strings with is_cross_encoder=True still returns batch_strings.

        NOTE: despite the "disabled" suffix in the method name (kept so the
        test ID stays stable), this test passes ``is_cross_encoder=True``.
        """
        texts = ["Hello", "World"]
        result = self.tokenizer_manager._detect_input_format(
            texts, is_cross_encoder=True
        )
        self.assertEqual(result, "batch_strings")

    def test_detect_cross_encoder_single_pair(self):
        """One [query, document] pair with is_cross_encoder=True is cross_encoder_pairs."""
        texts = [["query text", "document text"]]
        result = self.tokenizer_manager._detect_input_format(
            texts, is_cross_encoder=True
        )
        self.assertEqual(result, "cross_encoder_pairs")

    def test_detect_cross_encoder_multiple_pairs(self):
        """Several [query, document] pairs with is_cross_encoder=True are cross_encoder_pairs."""
        texts = [["q1", "d1"], ["q2", "d2"], ["q3", "d3"]]
        result = self.tokenizer_manager._detect_input_format(
            texts, is_cross_encoder=True
        )
        self.assertEqual(result, "cross_encoder_pairs")

    def test_detect_cross_encoder_disabled_with_pairs(self):
        """Pairs passed with is_cross_encoder=False should return batch_strings."""
        texts = [["query", "document"]]
        result = self.tokenizer_manager._detect_input_format(
            texts, is_cross_encoder=False
        )
        self.assertEqual(result, "batch_strings")

    def test_detect_empty_list(self):
        """An empty list with is_cross_encoder=True is expected to return batch_strings."""
        texts = []
        result = self.tokenizer_manager._detect_input_format(
            texts, is_cross_encoder=True
        )
        self.assertEqual(result, "batch_strings")

    def test_detect_malformed_cross_encoder_pairs(self):
        """Inner lists whose length is not 2 are expected to return batch_strings."""
        malformed_inputs = (
            [["query only"]],  # Single element, not a pair
            [["query", "doc", "extra"]],  # Three elements, not a pair
        )
        for texts in malformed_inputs:
            # subTest reports each malformed shape separately, so a failure on
            # the first case no longer hides the result of the second.
            with self.subTest(texts=texts):
                result = self.tokenizer_manager._detect_input_format(
                    texts, is_cross_encoder=True
                )
                self.assertEqual(result, "batch_strings")
|
|
|
|
|
|
class TestTokenizerInputPreparation(unittest.TestCase):
    """Checks for the ``_prepare_tokenizer_input`` helper."""

    def setUp(self):
        """Create the TokenizerManager under test with external pieces patched."""
        with patch("sglang.srt.utils.get_device", return_value="cpu"):
            self.server_args = ServerArgs(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
            self.port_args = PortArgs.init_new(self.server_args)

        tokenizer_target = "sglang.srt.hf_transformers_utils.get_tokenizer"
        with patch("zmq.asyncio.Context"):
            with patch("sglang.srt.utils.get_zmq_socket"):
                with patch(tokenizer_target) as get_tokenizer_mock:
                    get_tokenizer_mock.return_value = Mock(vocab_size=32000)
                    self.tokenizer_manager = TokenizerManager(
                        self.server_args, self.port_args
                    )

    def test_prepare_single_string_input(self):
        """A lone string is expected back as a one-element list."""
        prepared = self.tokenizer_manager._prepare_tokenizer_input(
            "Hello world", "single_string"
        )
        self.assertEqual(prepared, ["Hello world"])

    def test_prepare_batch_strings_input(self):
        """A batch of strings is expected back with the same contents."""
        batch = ["Hello", "World", "Test"]
        prepared = self.tokenizer_manager._prepare_tokenizer_input(
            batch, "batch_strings"
        )
        self.assertEqual(prepared, ["Hello", "World", "Test"])

    def test_prepare_cross_encoder_pairs_input(self):
        """Two cross-encoder pairs are expected back with the same contents."""
        pairs = [["query1", "doc1"], ["query2", "doc2"]]
        prepared = self.tokenizer_manager._prepare_tokenizer_input(
            pairs, "cross_encoder_pairs"
        )
        expected = [["query1", "doc1"], ["query2", "doc2"]]
        self.assertEqual(prepared, expected)

    def test_prepare_cross_encoder_single_pair_input(self):
        """A single cross-encoder pair is expected back with the same contents."""
        pairs = [["query text", "document text"]]
        prepared = self.tokenizer_manager._prepare_tokenizer_input(
            pairs, "cross_encoder_pairs"
        )
        expected = [["query text", "document text"]]
        self.assertEqual(prepared, expected)

    def test_prepare_unknown_input_format(self):
        """An unrecognised format label is expected to hand the input back as-is."""
        payload = ["test"]
        prepared = self.tokenizer_manager._prepare_tokenizer_input(
            payload, "unknown_format"
        )
        self.assertEqual(prepared, ["test"])
|
|
|
|
|
|
class TestTokenizerResultExtraction(unittest.TestCase):
    """Checks for the ``_extract_tokenizer_results`` helper."""

    def setUp(self):
        """Create the TokenizerManager under test with external pieces patched."""
        with patch("sglang.srt.utils.get_device", return_value="cpu"):
            self.server_args = ServerArgs(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
            self.port_args = PortArgs.init_new(self.server_args)

        tokenizer_target = "sglang.srt.hf_transformers_utils.get_tokenizer"
        with patch("zmq.asyncio.Context"):
            with patch("sglang.srt.utils.get_zmq_socket"):
                with patch(tokenizer_target) as get_tokenizer_mock:
                    get_tokenizer_mock.return_value = Mock(vocab_size=32000)
                    self.tokenizer_manager = TokenizerManager(
                        self.server_args, self.port_args
                    )

    def _extract(self, input_ids, token_type_ids, input_format, batch_size):
        """Forward to ``_extract_tokenizer_results`` and return whatever it returns."""
        return self.tokenizer_manager._extract_tokenizer_results(
            input_ids, token_type_ids, input_format, original_batch_size=batch_size
        )

    def test_extract_single_string_results(self):
        """A single-string result is expected to be unwrapped to flat lists."""
        ids, type_ids = self._extract(
            [[101, 2129, 102]], [[0, 0, 0]], "single_string", 1
        )

        self.assertEqual(ids, [101, 2129, 102])
        self.assertEqual(type_ids, [0, 0, 0])

    def test_extract_single_cross_encoder_results(self):
        """A single cross-encoder pair result is expected to be unwrapped to flat lists."""
        ids, type_ids = self._extract(
            [[101, 2129, 102, 4068, 102]],
            [[0, 0, 0, 1, 1]],
            "cross_encoder_pairs",
            1,
        )

        self.assertEqual(ids, [101, 2129, 102, 4068, 102])
        self.assertEqual(type_ids, [0, 0, 0, 1, 1])

    def test_extract_batch_results(self):
        """A two-item string batch is expected to come back as nested lists."""
        batch_ids = [[101, 2129, 102], [101, 4068, 102]]
        batch_type_ids = [[0, 0, 0], [0, 0, 0]]

        ids, type_ids = self._extract(batch_ids, batch_type_ids, "batch_strings", 2)

        self.assertEqual(ids, [[101, 2129, 102], [101, 4068, 102]])
        self.assertEqual(type_ids, [[0, 0, 0], [0, 0, 0]])

    def test_extract_multiple_cross_encoder_results(self):
        """Two cross-encoder pair results are expected to come back as nested lists."""
        batch_ids = [[101, 2129, 102, 4068, 102], [101, 7592, 102, 2088, 102]]
        batch_type_ids = [[0, 0, 0, 1, 1], [0, 0, 0, 1, 1]]

        ids, type_ids = self._extract(
            batch_ids, batch_type_ids, "cross_encoder_pairs", 2
        )

        expected_ids = [[101, 2129, 102, 4068, 102], [101, 7592, 102, 2088, 102]]
        self.assertEqual(ids, expected_ids)
        self.assertEqual(type_ids, [[0, 0, 0, 1, 1], [0, 0, 0, 1, 1]])

    def test_extract_empty_results(self):
        """Empty input_ids with no token_type_ids are expected to yield ([], None)."""
        ids, type_ids = self._extract([], None, "single_string", 1)

        self.assertEqual(ids, [])
        self.assertIsNone(type_ids)

    def test_extract_with_none_token_type_ids(self):
        """token_type_ids=None is expected to stay None while input_ids are unwrapped."""
        ids, type_ids = self._extract([[101, 2129, 102]], None, "single_string", 1)

        self.assertEqual(ids, [101, 2129, 102])
        self.assertIsNone(type_ids)
|
|
|
|
|
|
class TestTokenizerManagerIntegration(unittest.TestCase):
    """End-to-end checks chaining detection, preparation and extraction helpers."""

    def setUp(self):
        """Create the TokenizerManager under test with external pieces patched."""
        with patch("sglang.srt.utils.get_device", return_value="cpu"):
            self.server_args = ServerArgs(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
            self.port_args = PortArgs.init_new(self.server_args)

        tokenizer_target = "sglang.srt.hf_transformers_utils.get_tokenizer"
        with patch("zmq.asyncio.Context"):
            with patch("sglang.srt.utils.get_zmq_socket"):
                with patch(tokenizer_target) as get_tokenizer_mock:
                    get_tokenizer_mock.return_value = Mock(vocab_size=32000)
                    self.tokenizer_manager = TokenizerManager(
                        self.server_args, self.port_args
                    )

    def test_full_workflow_single_string(self):
        """Run detect -> prepare -> extract for one plain string."""
        manager = self.tokenizer_manager
        text = "Hello world"

        # Detection
        fmt = manager._detect_input_format(text, is_cross_encoder=False)
        self.assertEqual(fmt, "single_string")

        # Preparation
        prepared = manager._prepare_tokenizer_input(text, fmt)
        self.assertEqual(prepared, ["Hello world"])

        # Extraction, fed with a simulated tokenizer output
        fake_ids = [[101, 2129, 4248, 102]]
        fake_type_ids = None
        ids, type_ids = manager._extract_tokenizer_results(
            fake_ids, fake_type_ids, fmt, original_batch_size=1
        )

        self.assertEqual(ids, [101, 2129, 4248, 102])
        self.assertIsNone(type_ids)

    def test_full_workflow_cross_encoder_pairs(self):
        """Run detect -> prepare -> extract for one cross-encoder pair."""
        manager = self.tokenizer_manager
        texts = [
            ["How many people live in Berlin?", "Berlin is well known for its museums."]
        ]

        # Detection
        fmt = manager._detect_input_format(texts, is_cross_encoder=True)
        self.assertEqual(fmt, "cross_encoder_pairs")

        # Preparation
        prepared = manager._prepare_tokenizer_input(texts, fmt)
        self.assertEqual(prepared, texts)

        # Extraction, fed with a simulated cross-encoder tokenizer output
        fake_ids = [[101, 2129, 2116, 102, 4068, 2003, 102]]
        fake_type_ids = [[0, 0, 0, 0, 1, 1, 1]]
        ids, type_ids = manager._extract_tokenizer_results(
            fake_ids, fake_type_ids, fmt, original_batch_size=1
        )

        self.assertEqual(ids, [101, 2129, 2116, 102, 4068, 2003, 102])
        self.assertEqual(type_ids, [0, 0, 0, 0, 1, 1, 1])

    def test_full_workflow_batch_strings(self):
        """Run detect -> prepare -> extract for a batch of three strings."""
        manager = self.tokenizer_manager
        texts = ["Hello", "World", "Test"]

        # Detection
        fmt = manager._detect_input_format(texts, is_cross_encoder=False)
        self.assertEqual(fmt, "batch_strings")

        # Preparation
        prepared = manager._prepare_tokenizer_input(texts, fmt)
        self.assertEqual(prepared, ["Hello", "World", "Test"])

        # Extraction, fed with a simulated tokenizer output
        fake_ids = [[101, 7592, 102], [101, 2088, 102], [101, 2774, 102]]
        fake_type_ids = None
        ids, type_ids = manager._extract_tokenizer_results(
            fake_ids, fake_type_ids, fmt, original_batch_size=3
        )

        expected_ids = [[101, 7592, 102], [101, 2088, 102], [101, 2774, 102]]
        self.assertEqual(ids, expected_ids)
        self.assertIsNone(type_ids)
|
|
|
|
|
|
# Allow running this file directly; verbosity=2 prints one line per test.
if __name__ == "__main__":
    unittest.main(verbosity=2)
|