# NOTE: the lines below are export/listing metadata from the original source
# host, preserved here as comments so the module remains valid Python.
# File: sglang/test/srt/test_tokenizer_manager.py
# Snapshot: 2025-09-14 01:56:04 +08:00 — 380 lines, 15 KiB, Python
"""
Unit tests for TokenizerManager helper methods.
This tests the refactored tokenization functionality including input format detection,
tokenizer input preparation, and result extraction logic.
Usage:
python3 -m unittest test_tokenizer_manager.TestInputFormatDetection
python3 -m unittest test_tokenizer_manager.TestTokenizerInputPreparation
python3 -m unittest test_tokenizer_manager.TestTokenizerResultExtraction
python3 -m unittest test_tokenizer_manager.TestTokenizerManagerIntegration
"""
import unittest
from typing import List, Optional, Union
from unittest.mock import Mock, patch
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
class TestInputFormatDetection(unittest.TestCase):
    """Test cases for TokenizerManager._detect_input_format.

    Verifies that single strings, batches of strings, and cross-encoder
    (query, document) pairs are classified correctly, including malformed
    and empty inputs.
    """

    def setUp(self):
        """Build a TokenizerManager with device, ZMQ, and tokenizer loading mocked out."""
        with patch("sglang.srt.utils.get_device", return_value="cpu"):
            self.server_args = ServerArgs(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
        self.port_args = PortArgs.init_new(self.server_args)
        with patch("zmq.asyncio.Context"), patch(
            "sglang.srt.utils.get_zmq_socket"
        ), patch("sglang.srt.hf_transformers_utils.get_tokenizer") as mock_tokenizer:
            mock_tokenizer.return_value = Mock(vocab_size=32000)
            self.tokenizer_manager = TokenizerManager(self.server_args, self.port_args)

    def test_detect_single_string(self):
        """Test detection of single string input."""
        text = "Hello world"
        result = self.tokenizer_manager._detect_input_format(
            text, is_cross_encoder=False
        )
        self.assertEqual(result, "single_string")

    def test_detect_single_string_cross_encoder_disabled(self):
        """Test single string with cross_encoder ENABLED still returns single_string.

        NOTE(review): the method name is historical — the call below passes
        is_cross_encoder=True; the point is that a bare string is never
        treated as a cross-encoder pair.
        """
        text = "Hello world"
        result = self.tokenizer_manager._detect_input_format(
            text, is_cross_encoder=True
        )
        self.assertEqual(result, "single_string")

    def test_detect_batch_strings(self):
        """Test detection of batch string inputs."""
        texts = ["Hello", "World", "How are you?"]
        result = self.tokenizer_manager._detect_input_format(
            texts, is_cross_encoder=False
        )
        self.assertEqual(result, "batch_strings")

    def test_detect_batch_strings_cross_encoder_disabled(self):
        """Test batch strings with cross_encoder ENABLED still returns batch_strings.

        NOTE(review): the method name is historical — the call below passes
        is_cross_encoder=True; plain strings must not be classified as pairs
        even when cross-encoder mode is on.
        """
        texts = ["Hello", "World"]
        result = self.tokenizer_manager._detect_input_format(
            texts, is_cross_encoder=True
        )
        self.assertEqual(result, "batch_strings")

    def test_detect_cross_encoder_single_pair(self):
        """Test detection of cross-encoder single pair."""
        texts = [["query text", "document text"]]
        result = self.tokenizer_manager._detect_input_format(
            texts, is_cross_encoder=True
        )
        self.assertEqual(result, "cross_encoder_pairs")

    def test_detect_cross_encoder_multiple_pairs(self):
        """Test detection of cross-encoder multiple pairs."""
        texts = [["q1", "d1"], ["q2", "d2"], ["q3", "d3"]]
        result = self.tokenizer_manager._detect_input_format(
            texts, is_cross_encoder=True
        )
        self.assertEqual(result, "cross_encoder_pairs")

    def test_detect_cross_encoder_disabled_with_pairs(self):
        """Test pairs with cross_encoder disabled should return batch_strings."""
        texts = [["query", "document"]]
        result = self.tokenizer_manager._detect_input_format(
            texts, is_cross_encoder=False
        )
        self.assertEqual(result, "batch_strings")

    def test_detect_empty_list(self):
        """Test detection with empty list."""
        texts = []
        result = self.tokenizer_manager._detect_input_format(
            texts, is_cross_encoder=True
        )
        self.assertEqual(result, "batch_strings")

    def test_detect_malformed_cross_encoder_pairs(self):
        """Test malformed cross-encoder pairs (not length 2) fall back to batch_strings."""
        texts = [["query only"]]  # Single element, not a pair
        result = self.tokenizer_manager._detect_input_format(
            texts, is_cross_encoder=True
        )
        self.assertEqual(result, "batch_strings")
        texts = [["query", "doc", "extra"]]  # Three elements, not a pair
        result = self.tokenizer_manager._detect_input_format(
            texts, is_cross_encoder=True
        )
        self.assertEqual(result, "batch_strings")
class TestTokenizerInputPreparation(unittest.TestCase):
    """Test cases for TokenizerManager._prepare_tokenizer_input.

    Each test feeds a raw input plus a detected-format label and checks the
    list handed to the underlying tokenizer.
    """

    def setUp(self):
        """Construct a TokenizerManager with device, ZMQ, and tokenizer loading stubbed."""
        with patch("sglang.srt.utils.get_device", return_value="cpu"):
            self.server_args = ServerArgs(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
        self.port_args = PortArgs.init_new(self.server_args)
        with patch("zmq.asyncio.Context"), patch(
            "sglang.srt.utils.get_zmq_socket"
        ), patch("sglang.srt.hf_transformers_utils.get_tokenizer") as mock_tokenizer:
            mock_tokenizer.return_value = Mock(vocab_size=32000)
            self.tokenizer_manager = TokenizerManager(self.server_args, self.port_args)

    def test_prepare_single_string_input(self):
        """A lone string is wrapped into a one-element list."""
        prepared = self.tokenizer_manager._prepare_tokenizer_input(
            "Hello world", "single_string"
        )
        self.assertEqual(prepared, ["Hello world"])

    def test_prepare_batch_strings_input(self):
        """A batch of strings passes through unchanged."""
        batch = ["Hello", "World", "Test"]
        prepared = self.tokenizer_manager._prepare_tokenizer_input(
            batch, "batch_strings"
        )
        self.assertEqual(prepared, ["Hello", "World", "Test"])

    def test_prepare_cross_encoder_pairs_input(self):
        """Multiple (query, document) pairs pass through unchanged."""
        pairs = [["query1", "doc1"], ["query2", "doc2"]]
        prepared = self.tokenizer_manager._prepare_tokenizer_input(
            pairs, "cross_encoder_pairs"
        )
        self.assertEqual(prepared, [["query1", "doc1"], ["query2", "doc2"]])

    def test_prepare_cross_encoder_single_pair_input(self):
        """A single (query, document) pair passes through unchanged."""
        pair = [["query text", "document text"]]
        prepared = self.tokenizer_manager._prepare_tokenizer_input(
            pair, "cross_encoder_pairs"
        )
        self.assertEqual(prepared, [["query text", "document text"]])

    def test_prepare_unknown_input_format(self):
        """An unrecognized format label falls back to returning the input as-is."""
        prepared = self.tokenizer_manager._prepare_tokenizer_input(
            ["test"], "unknown_format"
        )
        self.assertEqual(prepared, ["test"])
class TestTokenizerResultExtraction(unittest.TestCase):
    """Test cases for TokenizerManager._extract_tokenizer_results.

    Checks that single-item formats are unwrapped to flat lists, batch
    formats keep their nesting, and None token_type_ids are propagated.
    """

    def setUp(self):
        """Construct a TokenizerManager with device, ZMQ, and tokenizer loading stubbed."""
        with patch("sglang.srt.utils.get_device", return_value="cpu"):
            self.server_args = ServerArgs(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
        self.port_args = PortArgs.init_new(self.server_args)
        with patch("zmq.asyncio.Context"), patch(
            "sglang.srt.utils.get_zmq_socket"
        ), patch("sglang.srt.hf_transformers_utils.get_tokenizer") as mock_tokenizer:
            mock_tokenizer.return_value = Mock(vocab_size=32000)
            self.tokenizer_manager = TokenizerManager(self.server_args, self.port_args)

    def test_extract_single_string_results(self):
        """A single-string result is unwrapped from its batch dimension."""
        ids, type_ids = self.tokenizer_manager._extract_tokenizer_results(
            [[101, 2129, 102]], [[0, 0, 0]], "single_string", original_batch_size=1
        )
        self.assertEqual(ids, [101, 2129, 102])
        self.assertEqual(type_ids, [0, 0, 0])

    def test_extract_single_cross_encoder_results(self):
        """A single cross-encoder pair result is unwrapped to flat lists."""
        ids, type_ids = self.tokenizer_manager._extract_tokenizer_results(
            [[101, 2129, 102, 4068, 102]],
            [[0, 0, 0, 1, 1]],
            "cross_encoder_pairs",
            original_batch_size=1,
        )
        self.assertEqual(ids, [101, 2129, 102, 4068, 102])
        self.assertEqual(type_ids, [0, 0, 0, 1, 1])

    def test_extract_batch_results(self):
        """Batch results keep their per-item nesting."""
        ids, type_ids = self.tokenizer_manager._extract_tokenizer_results(
            [[101, 2129, 102], [101, 4068, 102]],
            [[0, 0, 0], [0, 0, 0]],
            "batch_strings",
            original_batch_size=2,
        )
        self.assertEqual(ids, [[101, 2129, 102], [101, 4068, 102]])
        self.assertEqual(type_ids, [[0, 0, 0], [0, 0, 0]])

    def test_extract_multiple_cross_encoder_results(self):
        """Multiple cross-encoder pair results keep their per-pair nesting."""
        ids, type_ids = self.tokenizer_manager._extract_tokenizer_results(
            [[101, 2129, 102, 4068, 102], [101, 7592, 102, 2088, 102]],
            [[0, 0, 0, 1, 1], [0, 0, 0, 1, 1]],
            "cross_encoder_pairs",
            original_batch_size=2,
        )
        self.assertEqual(
            ids, [[101, 2129, 102, 4068, 102], [101, 7592, 102, 2088, 102]]
        )
        self.assertEqual(type_ids, [[0, 0, 0, 1, 1], [0, 0, 0, 1, 1]])

    def test_extract_empty_results(self):
        """Empty input_ids and None token_type_ids survive extraction."""
        ids, type_ids = self.tokenizer_manager._extract_tokenizer_results(
            [], None, "single_string", original_batch_size=1
        )
        self.assertEqual(ids, [])
        self.assertIsNone(type_ids)

    def test_extract_with_none_token_type_ids(self):
        """None token_type_ids stays None while input_ids are unwrapped."""
        ids, type_ids = self.tokenizer_manager._extract_tokenizer_results(
            [[101, 2129, 102]], None, "single_string", original_batch_size=1
        )
        self.assertEqual(ids, [101, 2129, 102])
        self.assertIsNone(type_ids)
class TestTokenizerManagerIntegration(unittest.TestCase):
    """Integration tests chaining _detect_input_format, _prepare_tokenizer_input,
    and _extract_tokenizer_results end to end (tokenizer output is simulated)."""

    def setUp(self):
        """Construct a TokenizerManager with device, ZMQ, and tokenizer loading stubbed."""
        with patch("sglang.srt.utils.get_device", return_value="cpu"):
            self.server_args = ServerArgs(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
        self.port_args = PortArgs.init_new(self.server_args)
        with patch("zmq.asyncio.Context"), patch(
            "sglang.srt.utils.get_zmq_socket"
        ), patch("sglang.srt.hf_transformers_utils.get_tokenizer") as mock_tokenizer:
            mock_tokenizer.return_value = Mock(vocab_size=32000)
            self.tokenizer_manager = TokenizerManager(self.server_args, self.port_args)

    def test_full_workflow_single_string(self):
        """Detect -> prepare -> extract for a single string input."""
        raw_input = "Hello world"
        # Detection: a bare string is a single_string.
        fmt = self.tokenizer_manager._detect_input_format(
            raw_input, is_cross_encoder=False
        )
        self.assertEqual(fmt, "single_string")
        # Preparation: wrapped into a one-element list for the tokenizer.
        prepared = self.tokenizer_manager._prepare_tokenizer_input(raw_input, fmt)
        self.assertEqual(prepared, ["Hello world"])
        # Extraction: simulated tokenizer output is unwrapped back to flat lists.
        ids, type_ids = self.tokenizer_manager._extract_tokenizer_results(
            [[101, 2129, 4248, 102]], None, fmt, original_batch_size=1
        )
        self.assertEqual(ids, [101, 2129, 4248, 102])
        self.assertIsNone(type_ids)

    def test_full_workflow_cross_encoder_pairs(self):
        """Detect -> prepare -> extract for a single cross-encoder pair."""
        raw_input = [
            ["How many people live in Berlin?", "Berlin is well known for its museums."]
        ]
        # Detection: a (query, document) pair under cross-encoder mode.
        fmt = self.tokenizer_manager._detect_input_format(
            raw_input, is_cross_encoder=True
        )
        self.assertEqual(fmt, "cross_encoder_pairs")
        # Preparation: pairs pass through untouched.
        prepared = self.tokenizer_manager._prepare_tokenizer_input(raw_input, fmt)
        self.assertEqual(prepared, raw_input)
        # Extraction: simulated cross-encoder output (token_type_ids mark segments).
        ids, type_ids = self.tokenizer_manager._extract_tokenizer_results(
            [[101, 2129, 2116, 102, 4068, 2003, 102]],
            [[0, 0, 0, 0, 1, 1, 1]],
            fmt,
            original_batch_size=1,
        )
        self.assertEqual(ids, [101, 2129, 2116, 102, 4068, 2003, 102])
        self.assertEqual(type_ids, [0, 0, 0, 0, 1, 1, 1])

    def test_full_workflow_batch_strings(self):
        """Detect -> prepare -> extract for a batch of strings."""
        raw_input = ["Hello", "World", "Test"]
        # Detection: a list of plain strings is batch_strings.
        fmt = self.tokenizer_manager._detect_input_format(
            raw_input, is_cross_encoder=False
        )
        self.assertEqual(fmt, "batch_strings")
        # Preparation: the batch passes through untouched.
        prepared = self.tokenizer_manager._prepare_tokenizer_input(raw_input, fmt)
        self.assertEqual(prepared, ["Hello", "World", "Test"])
        # Extraction: batch output keeps per-item nesting.
        ids, type_ids = self.tokenizer_manager._extract_tokenizer_results(
            [[101, 7592, 102], [101, 2088, 102], [101, 2774, 102]],
            None,
            fmt,
            original_batch_size=3,
        )
        self.assertEqual(
            ids, [[101, 7592, 102], [101, 2088, 102], [101, 2774, 102]]
        )
        self.assertIsNone(type_ids)
# Allow running this module directly; verbosity=2 prints each test name as it runs.
if __name__ == "__main__":
    unittest.main(verbosity=2)