[Performance] Dynamic Batch Tokenizer (#9382)
This commit is contained in:
committed by
GitHub
parent
eca59f96c3
commit
94d0f656fb
295
test/srt/test_async_dynamic_batch_tokenizer.py
Normal file
295
test/srt/test_async_dynamic_batch_tokenizer.py
Normal file
@@ -0,0 +1,295 @@
|
||||
"""
|
||||
Unit tests for AsyncDynamicbatchTokenizer.
|
||||
|
||||
Tests the async dynamic batching functionality for tokenization,
|
||||
including batch efficiency, timeout handling, and error cases.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchTokenizer
|
||||
|
||||
|
||||
class TestAsyncDynamicbatchTokenizer:
|
||||
"""Test suite for AsyncDynamicbatchTokenizer."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tokenizer(self):
|
||||
"""Create a mock tokenizer that behaves like HuggingFace tokenizer."""
|
||||
|
||||
def mock_encode(texts, **kwargs):
|
||||
is_single = isinstance(texts, str)
|
||||
if is_single:
|
||||
texts = [texts]
|
||||
|
||||
# Simulate tokenization - convert text to mock token ids
|
||||
input_ids = []
|
||||
token_type_ids = []
|
||||
|
||||
for text in texts:
|
||||
# Simple mock: text length determines number of tokens
|
||||
tokens = [i for i in range(len(text.split()))]
|
||||
input_ids.append(tokens)
|
||||
|
||||
if kwargs.get("return_token_type_ids", False):
|
||||
token_type_ids.append([0] * len(tokens))
|
||||
|
||||
result = {"input_ids": input_ids}
|
||||
if kwargs.get("return_token_type_ids", False):
|
||||
result["token_type_ids"] = token_type_ids
|
||||
|
||||
# For single inputs, return individual result (not wrapped in a list)
|
||||
if is_single:
|
||||
result = {"input_ids": input_ids[0]}
|
||||
if kwargs.get("return_token_type_ids", False):
|
||||
result["token_type_ids"] = token_type_ids[0]
|
||||
|
||||
# Create a proper BatchEncoding-like object that supports dict operations
|
||||
class MockBatchEncoding(dict):
|
||||
def __init__(self, data):
|
||||
super().__init__(data)
|
||||
for key, value in data.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
return MockBatchEncoding(result)
|
||||
|
||||
# Return the function directly - the AsyncDynamicbatchTokenizer will call it
|
||||
return mock_encode
|
||||
|
||||
@pytest.fixture
|
||||
def async_tokenizer(self, mock_tokenizer):
|
||||
"""Create AsyncDynamicbatchTokenizer instance."""
|
||||
return AsyncDynamicbatchTokenizer(
|
||||
tokenizer=mock_tokenizer, max_batch_size=4, batch_wait_timeout_s=0.01
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_request(self, async_tokenizer):
|
||||
"""Test tokenizing a single request."""
|
||||
text = "hello world"
|
||||
result = await async_tokenizer.encode(text)
|
||||
|
||||
assert "input_ids" in result
|
||||
assert result["input_ids"] == [0, 1] # 2 words -> 2 tokens
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_request_with_token_type_ids(self, async_tokenizer):
|
||||
"""Test tokenizing with token type IDs."""
|
||||
text = "hello world"
|
||||
result = await async_tokenizer.encode(text, return_token_type_ids=True)
|
||||
|
||||
assert "input_ids" in result
|
||||
assert "token_type_ids" in result
|
||||
assert result["input_ids"] == [0, 1]
|
||||
assert result["token_type_ids"] == [0, 0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_requests_same_kwargs(self, async_tokenizer):
|
||||
"""Test that concurrent requests with same kwargs get batched."""
|
||||
texts = ["hello world", "how are you", "fine thanks", "good morning"]
|
||||
|
||||
# Start all requests concurrently
|
||||
tasks = [async_tokenizer.encode(text) for text in texts]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# Verify all results
|
||||
assert len(results) == 4
|
||||
for i, result in enumerate(results):
|
||||
assert "input_ids" in result
|
||||
expected_tokens = list(range(len(texts[i].split())))
|
||||
assert result["input_ids"] == expected_tokens
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_requests_different_kwargs(self, async_tokenizer):
|
||||
"""Test that requests with different kwargs are processed individually."""
|
||||
text1 = "hello world"
|
||||
text2 = "how are you"
|
||||
|
||||
# One with token_type_ids, one without
|
||||
task1 = async_tokenizer.encode(text1, return_token_type_ids=True)
|
||||
task2 = async_tokenizer.encode(text2)
|
||||
|
||||
result1, result2 = await asyncio.gather(task1, task2)
|
||||
|
||||
# First result should have token_type_ids
|
||||
assert "input_ids" in result1
|
||||
assert "token_type_ids" in result1
|
||||
assert result1["input_ids"] == [0, 1]
|
||||
assert result1["token_type_ids"] == [0, 0]
|
||||
|
||||
# Second result should not have token_type_ids
|
||||
assert "input_ids" in result2
|
||||
assert "token_type_ids" not in result2
|
||||
assert result2["input_ids"] == [0, 1, 2]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_timeout(self, async_tokenizer):
|
||||
"""Test that batching respects timeout."""
|
||||
# Send first request
|
||||
task1 = asyncio.create_task(async_tokenizer.encode("hello world"))
|
||||
|
||||
# Wait longer than batch timeout
|
||||
await asyncio.sleep(0.02) # Longer than 0.01s timeout
|
||||
|
||||
# Send second request
|
||||
task2 = asyncio.create_task(async_tokenizer.encode("how are you"))
|
||||
|
||||
results = await asyncio.gather(task1, task2)
|
||||
|
||||
# Both should complete successfully
|
||||
assert len(results) == 2
|
||||
assert results[0]["input_ids"] == [0, 1]
|
||||
assert results[1]["input_ids"] == [0, 1, 2]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_batch_size_limit(self, async_tokenizer):
|
||||
"""Test that batching respects max_batch_size."""
|
||||
# Send more requests than max_batch_size (4)
|
||||
texts = [f"text {i}" for i in range(6)]
|
||||
tasks = [async_tokenizer.encode(text) for text in texts]
|
||||
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# All should complete successfully
|
||||
assert len(results) == 6
|
||||
for i, result in enumerate(results):
|
||||
assert "input_ids" in result
|
||||
assert result["input_ids"] == [0, 1] # "text i" -> 2 tokens
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_callable_interface(self, async_tokenizer):
|
||||
"""Test that the tokenizer is callable."""
|
||||
text = "hello world"
|
||||
result = await async_tokenizer(text)
|
||||
|
||||
assert "input_ids" in result
|
||||
assert result["input_ids"] == [0, 1]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lazy_initialization(self, mock_tokenizer):
|
||||
"""Test that initialization happens lazily."""
|
||||
tokenizer = AsyncDynamicbatchTokenizer(mock_tokenizer)
|
||||
|
||||
# Should not be initialized yet
|
||||
assert not tokenizer._initialized
|
||||
|
||||
# First encode should initialize
|
||||
await tokenizer.encode("hello")
|
||||
|
||||
# Should now be initialized
|
||||
assert tokenizer._initialized
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_handling_in_tokenizer(self, mock_tokenizer):
|
||||
"""Test error handling when tokenizer fails."""
|
||||
|
||||
# Create a new async tokenizer with a failing tokenizer
|
||||
def failing_tokenizer(*args, **kwargs):
|
||||
raise ValueError("Tokenizer error")
|
||||
|
||||
async_tokenizer = AsyncDynamicbatchTokenizer(
|
||||
tokenizer=failing_tokenizer, max_batch_size=4, batch_wait_timeout_s=0.01
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Tokenizer error"):
|
||||
await async_tokenizer.encode("hello world")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_processing_logs(self, async_tokenizer, caplog):
|
||||
"""Test that batch processing logs are generated."""
|
||||
caplog.set_level(logging.DEBUG)
|
||||
|
||||
# Send multiple requests to trigger batching
|
||||
tasks = [
|
||||
async_tokenizer.encode("hello world"),
|
||||
async_tokenizer.encode("how are you"),
|
||||
]
|
||||
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
# Should have batch processing log
|
||||
assert any(
|
||||
"Processing dynamic batch of size" in record.message
|
||||
for record in caplog.records
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_queue_immediate_processing(self, async_tokenizer):
|
||||
"""Test that single requests are processed immediately when queue is empty."""
|
||||
start_time = time.time()
|
||||
result = await async_tokenizer.encode("hello world")
|
||||
end_time = time.time()
|
||||
|
||||
# Should complete quickly (much less than batch timeout)
|
||||
assert end_time - start_time < 0.005 # 5ms should be plenty
|
||||
assert result["input_ids"] == [0, 1]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_tokenizer_integration(self):
|
||||
"""Test with a real HuggingFace tokenizer."""
|
||||
try:
|
||||
# Use a small, fast tokenizer for testing
|
||||
real_tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||
async_tokenizer = AsyncDynamicbatchTokenizer(
|
||||
tokenizer=real_tokenizer, max_batch_size=2, batch_wait_timeout_s=0.01
|
||||
)
|
||||
|
||||
text = "Hello, world!"
|
||||
result = await async_tokenizer.encode(text)
|
||||
|
||||
# Should get actual token IDs
|
||||
assert "input_ids" in result
|
||||
assert isinstance(result["input_ids"], list)
|
||||
assert len(result["input_ids"]) > 0
|
||||
assert all(isinstance(token_id, int) for token_id in result["input_ids"])
|
||||
|
||||
except Exception as e:
|
||||
pytest.skip(f"Real tokenizer test skipped: {e}")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_mixed_requests(self, async_tokenizer):
|
||||
"""Test mixing single and batched requests."""
|
||||
# Start some requests
|
||||
task1 = asyncio.create_task(async_tokenizer.encode("hello"))
|
||||
task2 = asyncio.create_task(async_tokenizer.encode("world"))
|
||||
|
||||
# Wait a bit
|
||||
await asyncio.sleep(0.005)
|
||||
|
||||
# Start more requests
|
||||
task3 = asyncio.create_task(async_tokenizer.encode("how are"))
|
||||
task4 = asyncio.create_task(async_tokenizer.encode("you doing"))
|
||||
|
||||
results = await asyncio.gather(task1, task2, task3, task4)
|
||||
|
||||
# All should complete successfully
|
||||
assert len(results) == 4
|
||||
for result in results:
|
||||
assert "input_ids" in result
|
||||
assert isinstance(result["input_ids"], list)
|
||||
|
||||
def test_cleanup_on_destruction(self, mock_tokenizer):
|
||||
"""Test that resources are cleaned up properly."""
|
||||
tokenizer = AsyncDynamicbatchTokenizer(mock_tokenizer)
|
||||
|
||||
# Mock the executor and task
|
||||
tokenizer._executor = Mock()
|
||||
tokenizer._batcher_task = Mock()
|
||||
tokenizer._batcher_task.done.return_value = False
|
||||
|
||||
# Call destructor
|
||||
tokenizer.__del__()
|
||||
|
||||
# Should cancel task and shutdown executor
|
||||
tokenizer._batcher_task.cancel.assert_called_once()
|
||||
tokenizer._executor.shutdown.assert_called_once_with(wait=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
379
test/srt/test_tokenizer_manager.py
Normal file
379
test/srt/test_tokenizer_manager.py
Normal file
@@ -0,0 +1,379 @@
|
||||
"""
|
||||
Unit tests for TokenizerManager helper methods.
|
||||
|
||||
This tests the refactored tokenization functionality including input format detection,
|
||||
tokenizer input preparation, and result extraction logic.
|
||||
|
||||
Usage:
|
||||
python3 -m unittest test_tokenizer_manager.TestInputFormatDetection
|
||||
python3 -m unittest test_tokenizer_manager.TestTokenizerInputPreparation
|
||||
python3 -m unittest test_tokenizer_manager.TestTokenizerResultExtraction
|
||||
python3 -m unittest test_tokenizer_manager.TestTokenizerManagerIntegration
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from typing import List, Optional, Union
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
|
||||
|
||||
class TestInputFormatDetection(unittest.TestCase):
|
||||
"""Test cases for _detect_input_format method."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
with patch("sglang.srt.utils.get_device", return_value="cpu"):
|
||||
self.server_args = ServerArgs(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
|
||||
self.port_args = PortArgs.init_new(self.server_args)
|
||||
|
||||
with patch("zmq.asyncio.Context"), patch(
|
||||
"sglang.srt.utils.get_zmq_socket"
|
||||
), patch("sglang.srt.hf_transformers_utils.get_tokenizer") as mock_tokenizer:
|
||||
mock_tokenizer.return_value = Mock(vocab_size=32000)
|
||||
self.tokenizer_manager = TokenizerManager(self.server_args, self.port_args)
|
||||
|
||||
def test_detect_single_string(self):
|
||||
"""Test detection of single string input."""
|
||||
text = "Hello world"
|
||||
result = self.tokenizer_manager._detect_input_format(
|
||||
text, is_cross_encoder=False
|
||||
)
|
||||
self.assertEqual(result, "single_string")
|
||||
|
||||
def test_detect_single_string_cross_encoder_disabled(self):
|
||||
"""Test single string with cross_encoder disabled still returns single_string."""
|
||||
text = "Hello world"
|
||||
result = self.tokenizer_manager._detect_input_format(
|
||||
text, is_cross_encoder=True
|
||||
)
|
||||
self.assertEqual(result, "single_string")
|
||||
|
||||
def test_detect_batch_strings(self):
|
||||
"""Test detection of batch string inputs."""
|
||||
texts = ["Hello", "World", "How are you?"]
|
||||
result = self.tokenizer_manager._detect_input_format(
|
||||
texts, is_cross_encoder=False
|
||||
)
|
||||
self.assertEqual(result, "batch_strings")
|
||||
|
||||
def test_detect_batch_strings_cross_encoder_disabled(self):
|
||||
"""Test batch strings with cross_encoder disabled."""
|
||||
texts = ["Hello", "World"]
|
||||
result = self.tokenizer_manager._detect_input_format(
|
||||
texts, is_cross_encoder=True
|
||||
)
|
||||
self.assertEqual(result, "batch_strings")
|
||||
|
||||
def test_detect_cross_encoder_single_pair(self):
|
||||
"""Test detection of cross-encoder single pair."""
|
||||
texts = [["query text", "document text"]]
|
||||
result = self.tokenizer_manager._detect_input_format(
|
||||
texts, is_cross_encoder=True
|
||||
)
|
||||
self.assertEqual(result, "cross_encoder_pairs")
|
||||
|
||||
def test_detect_cross_encoder_multiple_pairs(self):
|
||||
"""Test detection of cross-encoder multiple pairs."""
|
||||
texts = [["q1", "d1"], ["q2", "d2"], ["q3", "d3"]]
|
||||
result = self.tokenizer_manager._detect_input_format(
|
||||
texts, is_cross_encoder=True
|
||||
)
|
||||
self.assertEqual(result, "cross_encoder_pairs")
|
||||
|
||||
def test_detect_cross_encoder_disabled_with_pairs(self):
|
||||
"""Test pairs with cross_encoder disabled should return batch_strings."""
|
||||
texts = [["query", "document"]]
|
||||
result = self.tokenizer_manager._detect_input_format(
|
||||
texts, is_cross_encoder=False
|
||||
)
|
||||
self.assertEqual(result, "batch_strings")
|
||||
|
||||
def test_detect_empty_list(self):
|
||||
"""Test detection with empty list."""
|
||||
texts = []
|
||||
result = self.tokenizer_manager._detect_input_format(
|
||||
texts, is_cross_encoder=True
|
||||
)
|
||||
self.assertEqual(result, "batch_strings")
|
||||
|
||||
def test_detect_malformed_cross_encoder_pairs(self):
|
||||
"""Test malformed cross-encoder pairs (not length 2)."""
|
||||
texts = [["query only"]] # Single element, not a pair
|
||||
result = self.tokenizer_manager._detect_input_format(
|
||||
texts, is_cross_encoder=True
|
||||
)
|
||||
self.assertEqual(result, "batch_strings")
|
||||
|
||||
texts = [["query", "doc", "extra"]] # Three elements, not a pair
|
||||
result = self.tokenizer_manager._detect_input_format(
|
||||
texts, is_cross_encoder=True
|
||||
)
|
||||
self.assertEqual(result, "batch_strings")
|
||||
|
||||
|
||||
class TestTokenizerInputPreparation(unittest.TestCase):
|
||||
"""Test cases for _prepare_tokenizer_input method."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
with patch("sglang.srt.utils.get_device", return_value="cpu"):
|
||||
self.server_args = ServerArgs(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
|
||||
self.port_args = PortArgs.init_new(self.server_args)
|
||||
|
||||
with patch("zmq.asyncio.Context"), patch(
|
||||
"sglang.srt.utils.get_zmq_socket"
|
||||
), patch("sglang.srt.hf_transformers_utils.get_tokenizer") as mock_tokenizer:
|
||||
mock_tokenizer.return_value = Mock(vocab_size=32000)
|
||||
self.tokenizer_manager = TokenizerManager(self.server_args, self.port_args)
|
||||
|
||||
def test_prepare_single_string_input(self):
|
||||
"""Test preparation of single string input."""
|
||||
text = "Hello world"
|
||||
result = self.tokenizer_manager._prepare_tokenizer_input(text, "single_string")
|
||||
self.assertEqual(result, ["Hello world"])
|
||||
|
||||
def test_prepare_batch_strings_input(self):
|
||||
"""Test preparation of batch strings input."""
|
||||
texts = ["Hello", "World", "Test"]
|
||||
result = self.tokenizer_manager._prepare_tokenizer_input(texts, "batch_strings")
|
||||
self.assertEqual(result, ["Hello", "World", "Test"])
|
||||
|
||||
def test_prepare_cross_encoder_pairs_input(self):
|
||||
"""Test preparation of cross-encoder pairs input."""
|
||||
texts = [["query1", "doc1"], ["query2", "doc2"]]
|
||||
result = self.tokenizer_manager._prepare_tokenizer_input(
|
||||
texts, "cross_encoder_pairs"
|
||||
)
|
||||
self.assertEqual(result, [["query1", "doc1"], ["query2", "doc2"]])
|
||||
|
||||
def test_prepare_cross_encoder_single_pair_input(self):
|
||||
"""Test preparation of single cross-encoder pair."""
|
||||
texts = [["query text", "document text"]]
|
||||
result = self.tokenizer_manager._prepare_tokenizer_input(
|
||||
texts, "cross_encoder_pairs"
|
||||
)
|
||||
self.assertEqual(result, [["query text", "document text"]])
|
||||
|
||||
def test_prepare_unknown_input_format(self):
|
||||
"""Test preparation with unknown input format falls back to returning as-is."""
|
||||
texts = ["test"]
|
||||
result = self.tokenizer_manager._prepare_tokenizer_input(
|
||||
texts, "unknown_format"
|
||||
)
|
||||
self.assertEqual(result, ["test"])
|
||||
|
||||
|
||||
class TestTokenizerResultExtraction(unittest.TestCase):
|
||||
"""Test cases for _extract_tokenizer_results method."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
with patch("sglang.srt.utils.get_device", return_value="cpu"):
|
||||
self.server_args = ServerArgs(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
|
||||
self.port_args = PortArgs.init_new(self.server_args)
|
||||
|
||||
with patch("zmq.asyncio.Context"), patch(
|
||||
"sglang.srt.utils.get_zmq_socket"
|
||||
), patch("sglang.srt.hf_transformers_utils.get_tokenizer") as mock_tokenizer:
|
||||
mock_tokenizer.return_value = Mock(vocab_size=32000)
|
||||
self.tokenizer_manager = TokenizerManager(self.server_args, self.port_args)
|
||||
|
||||
def test_extract_single_string_results(self):
|
||||
"""Test extraction for single string input."""
|
||||
input_ids = [[101, 2129, 102]]
|
||||
token_type_ids = [[0, 0, 0]]
|
||||
|
||||
result_input_ids, result_token_type_ids = (
|
||||
self.tokenizer_manager._extract_tokenizer_results(
|
||||
input_ids, token_type_ids, "single_string", original_batch_size=1
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(result_input_ids, [101, 2129, 102])
|
||||
self.assertEqual(result_token_type_ids, [0, 0, 0])
|
||||
|
||||
def test_extract_single_cross_encoder_results(self):
|
||||
"""Test extraction for single cross-encoder pair."""
|
||||
input_ids = [[101, 2129, 102, 4068, 102]]
|
||||
token_type_ids = [[0, 0, 0, 1, 1]]
|
||||
|
||||
result_input_ids, result_token_type_ids = (
|
||||
self.tokenizer_manager._extract_tokenizer_results(
|
||||
input_ids, token_type_ids, "cross_encoder_pairs", original_batch_size=1
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(result_input_ids, [101, 2129, 102, 4068, 102])
|
||||
self.assertEqual(result_token_type_ids, [0, 0, 0, 1, 1])
|
||||
|
||||
def test_extract_batch_results(self):
|
||||
"""Test extraction for batch inputs."""
|
||||
input_ids = [[101, 2129, 102], [101, 4068, 102]]
|
||||
token_type_ids = [[0, 0, 0], [0, 0, 0]]
|
||||
|
||||
result_input_ids, result_token_type_ids = (
|
||||
self.tokenizer_manager._extract_tokenizer_results(
|
||||
input_ids, token_type_ids, "batch_strings", original_batch_size=2
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(result_input_ids, [[101, 2129, 102], [101, 4068, 102]])
|
||||
self.assertEqual(result_token_type_ids, [[0, 0, 0], [0, 0, 0]])
|
||||
|
||||
def test_extract_multiple_cross_encoder_results(self):
|
||||
"""Test extraction for multiple cross-encoder pairs."""
|
||||
input_ids = [[101, 2129, 102, 4068, 102], [101, 7592, 102, 2088, 102]]
|
||||
token_type_ids = [[0, 0, 0, 1, 1], [0, 0, 0, 1, 1]]
|
||||
|
||||
result_input_ids, result_token_type_ids = (
|
||||
self.tokenizer_manager._extract_tokenizer_results(
|
||||
input_ids, token_type_ids, "cross_encoder_pairs", original_batch_size=2
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
result_input_ids, [[101, 2129, 102, 4068, 102], [101, 7592, 102, 2088, 102]]
|
||||
)
|
||||
self.assertEqual(result_token_type_ids, [[0, 0, 0, 1, 1], [0, 0, 0, 1, 1]])
|
||||
|
||||
def test_extract_empty_results(self):
|
||||
"""Test extraction with empty results."""
|
||||
input_ids = []
|
||||
token_type_ids = None
|
||||
|
||||
result_input_ids, result_token_type_ids = (
|
||||
self.tokenizer_manager._extract_tokenizer_results(
|
||||
input_ids, token_type_ids, "single_string", original_batch_size=1
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(result_input_ids, [])
|
||||
self.assertIsNone(result_token_type_ids)
|
||||
|
||||
def test_extract_with_none_token_type_ids(self):
|
||||
"""Test extraction when token_type_ids is None."""
|
||||
input_ids = [[101, 2129, 102]]
|
||||
token_type_ids = None
|
||||
|
||||
result_input_ids, result_token_type_ids = (
|
||||
self.tokenizer_manager._extract_tokenizer_results(
|
||||
input_ids, token_type_ids, "single_string", original_batch_size=1
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(result_input_ids, [101, 2129, 102])
|
||||
self.assertIsNone(result_token_type_ids)
|
||||
|
||||
|
||||
class TestTokenizerManagerIntegration(unittest.TestCase):
|
||||
"""Integration tests combining multiple helper methods."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
with patch("sglang.srt.utils.get_device", return_value="cpu"):
|
||||
self.server_args = ServerArgs(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
|
||||
self.port_args = PortArgs.init_new(self.server_args)
|
||||
|
||||
with patch("zmq.asyncio.Context"), patch(
|
||||
"sglang.srt.utils.get_zmq_socket"
|
||||
), patch("sglang.srt.hf_transformers_utils.get_tokenizer") as mock_tokenizer:
|
||||
mock_tokenizer.return_value = Mock(vocab_size=32000)
|
||||
self.tokenizer_manager = TokenizerManager(self.server_args, self.port_args)
|
||||
|
||||
def test_full_workflow_single_string(self):
|
||||
"""Test complete workflow for single string input."""
|
||||
text = "Hello world"
|
||||
|
||||
# Step 1: Detect format
|
||||
input_format = self.tokenizer_manager._detect_input_format(
|
||||
text, is_cross_encoder=False
|
||||
)
|
||||
self.assertEqual(input_format, "single_string")
|
||||
|
||||
# Step 2: Prepare input
|
||||
tokenizer_input = self.tokenizer_manager._prepare_tokenizer_input(
|
||||
text, input_format
|
||||
)
|
||||
self.assertEqual(tokenizer_input, ["Hello world"])
|
||||
|
||||
# Step 3: Extract results (simulated tokenizer output)
|
||||
mock_input_ids = [[101, 2129, 4248, 102]]
|
||||
mock_token_type_ids = None
|
||||
|
||||
result_input_ids, result_token_type_ids = (
|
||||
self.tokenizer_manager._extract_tokenizer_results(
|
||||
mock_input_ids, mock_token_type_ids, input_format, original_batch_size=1
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(result_input_ids, [101, 2129, 4248, 102])
|
||||
self.assertIsNone(result_token_type_ids)
|
||||
|
||||
def test_full_workflow_cross_encoder_pairs(self):
|
||||
"""Test complete workflow for cross-encoder pairs."""
|
||||
texts = [
|
||||
["How many people live in Berlin?", "Berlin is well known for its museums."]
|
||||
]
|
||||
|
||||
# Step 1: Detect format
|
||||
input_format = self.tokenizer_manager._detect_input_format(
|
||||
texts, is_cross_encoder=True
|
||||
)
|
||||
self.assertEqual(input_format, "cross_encoder_pairs")
|
||||
|
||||
# Step 2: Prepare input
|
||||
tokenizer_input = self.tokenizer_manager._prepare_tokenizer_input(
|
||||
texts, input_format
|
||||
)
|
||||
self.assertEqual(tokenizer_input, texts)
|
||||
|
||||
# Step 3: Extract results (simulated tokenizer output for cross-encoder)
|
||||
mock_input_ids = [[101, 2129, 2116, 102, 4068, 2003, 102]]
|
||||
mock_token_type_ids = [[0, 0, 0, 0, 1, 1, 1]]
|
||||
|
||||
result_input_ids, result_token_type_ids = (
|
||||
self.tokenizer_manager._extract_tokenizer_results(
|
||||
mock_input_ids, mock_token_type_ids, input_format, original_batch_size=1
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(result_input_ids, [101, 2129, 2116, 102, 4068, 2003, 102])
|
||||
self.assertEqual(result_token_type_ids, [0, 0, 0, 0, 1, 1, 1])
|
||||
|
||||
def test_full_workflow_batch_strings(self):
|
||||
"""Test complete workflow for batch strings."""
|
||||
texts = ["Hello", "World", "Test"]
|
||||
|
||||
# Step 1: Detect format
|
||||
input_format = self.tokenizer_manager._detect_input_format(
|
||||
texts, is_cross_encoder=False
|
||||
)
|
||||
self.assertEqual(input_format, "batch_strings")
|
||||
|
||||
# Step 2: Prepare input
|
||||
tokenizer_input = self.tokenizer_manager._prepare_tokenizer_input(
|
||||
texts, input_format
|
||||
)
|
||||
self.assertEqual(tokenizer_input, ["Hello", "World", "Test"])
|
||||
|
||||
# Step 3: Extract results (simulated tokenizer output)
|
||||
mock_input_ids = [[101, 7592, 102], [101, 2088, 102], [101, 2774, 102]]
|
||||
mock_token_type_ids = None
|
||||
|
||||
result_input_ids, result_token_type_ids = (
|
||||
self.tokenizer_manager._extract_tokenizer_results(
|
||||
mock_input_ids, mock_token_type_ids, input_format, original_batch_size=3
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
result_input_ids, [[101, 7592, 102], [101, 2088, 102], [101, 2774, 102]]
|
||||
)
|
||||
self.assertIsNone(result_token_type_ids)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
Reference in New Issue
Block a user