diff --git a/python/sglang/srt/managers/async_dynamic_batch_tokenizer.py b/python/sglang/srt/managers/async_dynamic_batch_tokenizer.py
new file mode 100644
index 000000000..ef1a8307f
--- /dev/null
+++ b/python/sglang/srt/managers/async_dynamic_batch_tokenizer.py
@@ -0,0 +1,166 @@
+"""
+Asynchronous dynamic batch tokenizer for SGLang.
+
+This module provides an async tokenizer with dynamic batching capabilities
+to reduce tokenization overhead when multiple requests arrive concurrently.
+"""
+
+import asyncio
+import logging
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
+from typing import Any, Dict, List, Optional
+
+logger = logging.getLogger(__name__)
+
+
+class AsyncDynamicbatchTokenizer:
+    """Asynchronous tokenizer with dynamic batching for single string prompts.
+
+    Dynamically batches pending encode requests from a queue to reduce overhead.
+    Only handles single string prompts - regular batch processing of multiple
+    strings per request should be handled at a higher level.
+    A single-thread ThreadPoolExecutor is used so the event loop stays responsive.
+
+    Note: Uses lazy initialization for asyncio components because this class
+    is instantiated in TokenizerManager.__init__() before the event loop starts.
+    """
+
+    def __init__(
+        self,
+        tokenizer,
+        max_batch_size: int = 32,
+        batch_wait_timeout_s: float = 0.002,
+    ) -> None:
+        self.tokenizer = tokenizer
+        self.max_batch_size = max_batch_size
+        self.batch_wait_timeout_s = batch_wait_timeout_s
+
+        # Single queue for all encode requests - initialized lazily
+        self._queue: Optional[asyncio.Queue] = None
+        self._batcher_task: Optional[asyncio.Task] = None
+
+        # Single-thread executor for blocking tokenizer calls
+        self._executor = ThreadPoolExecutor(max_workers=1)
+        self._initialized = False
+
+    def _ensure_initialized(self):
+        """Lazy initialization of event loop dependent components."""
+        if not self._initialized:
+            self._queue = asyncio.Queue()
+            self._batcher_task = asyncio.create_task(self._dynamic_batch_loop())
+            self._initialized = True
+
+    async def __call__(self, prompt: str, **kwargs) -> Any:
+        """Encode a single prompt."""
+        return await self.encode(prompt, **kwargs)
+
+    async def encode(self, prompt: str, **kwargs) -> Any:
+        """Encode a single prompt."""
+        self._ensure_initialized()
+        result_future: asyncio.Future = asyncio.get_running_loop().create_future()
+        await self._queue.put((prompt, kwargs, result_future))
+        return await result_future
+
+    async def _dynamic_batch_loop(self):
+        """Dynamically batch incoming encode requests for efficiency."""
+        while True:
+            try:
+                # Get the first request
+                prompt, kwargs, result_future = await self._queue.get()
+
+                # Collect requests into dynamic batch
+                prompts = [prompt]
+                kwargs_list = [kwargs]
+                result_futures = [result_future]
+
+                # If the queue is empty, process the single item immediately;
+                # otherwise wait up to batch_wait_timeout_s for more requests.
+                if not self._queue.empty():
+                    start_time = asyncio.get_running_loop().time()
+
+                    # Collect more requests up to max_batch_size or batch_wait_timeout_s
+                    while len(prompts) < self.max_batch_size:
+                        elapsed = asyncio.get_running_loop().time() - start_time
+                        if elapsed >= self.batch_wait_timeout_s:
+                            break
+
+                        remaining_time = self.batch_wait_timeout_s - elapsed
+                        try:
+                            prompt, kwargs, result_future = await asyncio.wait_for(
+                                self._queue.get(), remaining_time
+                            )
+                            prompts.append(prompt)
+                            kwargs_list.append(kwargs)
+                            result_futures.append(result_future)
+                        except asyncio.TimeoutError:
+                            break
+
+                # Log dynamic batch information
+                logger.debug(
+                    f"AsyncDynamicbatchTokenizer: Processing dynamic batch of size {len(prompts)}"
+                )
+
+                # Process the dynamic batch
+                await self._process_dynamic_batch(prompts, kwargs_list, result_futures)
+
+            except Exception as e:
+                logger.error(f"Error in dynamic batch loop: {e}")
+                # Continue the loop to handle other requests
+
+    async def _process_dynamic_batch(
+        self,
+        prompts: List[str],
+        kwargs_list: List[Dict],
+        result_futures: List[asyncio.Future],
+    ) -> None:
+        """Process a dynamic batch of encode requests for single string prompts."""
+        # Check if all kwargs are identical for efficient batch processing
+        can_batch = len(set(str(sorted(kw.items())) for kw in kwargs_list)) == 1
+        kwargs = kwargs_list[0] if can_batch else None
+
+        try:
+            # If every request uses identical kwargs we can run a single
+            # batch tokenizer call for a big speed-up.
+            if can_batch and len(prompts) > 1:
+                encode_fn = partial(self.tokenizer, prompts, **kwargs)
+                results = await asyncio.get_running_loop().run_in_executor(
+                    self._executor, encode_fn
+                )
+
+                for i, fut in enumerate(result_futures):
+                    if not fut.done():
+                        data = {k: v[i] for k, v in results.items()}
+                        fut.set_result(data)
+            else:
+                # Process each request individually due to different kwargs
+                if len(prompts) > 1 and not can_batch:
+                    logger.warning(
+                        f"AsyncDynamicbatchTokenizer: Dynamic batching disabled for batch of {len(prompts)} "
+                        f"requests due to differing kwargs. This reduces performance benefits. "
+                        f"Consider using consistent tokenization parameters across requests."
+                    )
+
+                encode_fn = lambda prompts=prompts, kwargs_list=kwargs_list: [
+                    self.tokenizer(p, **kw) for p, kw in zip(prompts, kwargs_list)
+                ]
+                results = await asyncio.get_running_loop().run_in_executor(
+                    self._executor, encode_fn
+                )
+
+                for fut, res in zip(result_futures, results):
+                    if not fut.done():
+                        fut.set_result(res)
+        except Exception as e:
+            logger.error(f"Error in dynamic batch processing: {e}")
+            for fut in result_futures:
+                if not fut.done():
+                    fut.set_exception(e)
+
+    def __del__(self):
+        """Clean up background tasks."""
+        if hasattr(self, "_batcher_task") and self._batcher_task:
+            if not self._batcher_task.done():
+                self._batcher_task.cancel()
+        if hasattr(self, "_executor"):
+            self._executor.shutdown(wait=False)
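For orientation, a minimal usage sketch of the class above (not part of the patch; assumes a HuggingFace-style tokenizer). Concurrent encode calls awaited on the same event loop are coalesced into a single batched tokenizer call by _dynamic_batch_loop:

    import asyncio

    from transformers import AutoTokenizer

    from sglang.srt.managers.async_dynamic_batch_tokenizer import (
        AsyncDynamicbatchTokenizer,
    )

    async def main():
        # gpt2 is an arbitrary small tokenizer chosen for illustration
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
        async_tokenizer = AsyncDynamicbatchTokenizer(tokenizer)
        # Issued concurrently with identical kwargs, these requests can be
        # served by one batched tokenizer call under the hood.
        results = await asyncio.gather(
            async_tokenizer.encode("Hello world"),
            async_tokenizer.encode("How are you?"),
        )
        print([r["input_ids"] for r in results])

    asyncio.run(main())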
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 4812ca180..730ee05aa 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -49,6 +49,7 @@ from sglang.srt.hf_transformers_utils import (
     get_tokenizer_from_processor,
 )
 from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
+from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchTokenizer
 from sglang.srt.managers.disagg_service import start_disagg_service
 from sglang.srt.managers.io_struct import (
     AbortReq,
@@ -216,6 +217,18 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 trust_remote_code=server_args.trust_remote_code,
                 revision=server_args.revision,
             )
+            # Initialize async dynamic batch tokenizer if enabled (common for both multimodal and non-multimodal)
+            if (
+                server_args.enable_dynamic_batch_tokenizer
+                and not server_args.skip_tokenizer_init
+            ):
+                self.async_dynamic_batch_tokenizer = AsyncDynamicbatchTokenizer(
+                    self.tokenizer,
+                    max_batch_size=server_args.dynamic_batch_tokenizer_batch_size,
+                    batch_wait_timeout_s=server_args.dynamic_batch_tokenizer_batch_timeout,
+                )
+            else:
+                self.async_dynamic_batch_tokenizer = None

         # Init inter-process communication
         context = zmq.asyncio.Context(2)
@@ -370,6 +383,144 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         ):
             yield response

+    def _detect_input_format(
+        self, texts: Union[str, List[str]], is_cross_encoder: bool
+    ) -> str:
+        """Detect the format of input texts for proper tokenization handling.
+
+        Returns:
+            - "single_string": Regular single text like "Hello world"
+            - "batch_strings": Regular batch like ["Hello", "World"]
+            - "cross_encoder_pairs": Cross-encoder pairs like [["query", "document"]]
+        """
+        if isinstance(texts, str):
+            return "single_string"
+
+        if (
+            is_cross_encoder
+            and len(texts) > 0
+            and isinstance(texts[0], list)
+            and len(texts[0]) == 2
+        ):
+            return "cross_encoder_pairs"
+
+        return "batch_strings"
+
+    def _prepare_tokenizer_input(
+        self, texts: Union[str, List[str]], input_format: str
+    ) -> Union[List[str], List[List[str]]]:
+        """Prepare input for the tokenizer based on detected format."""
+        if input_format == "single_string":
+            return [texts]  # Wrap single string for batch processing
+        elif input_format == "cross_encoder_pairs":
+            return texts  # Already in correct format: [["query", "doc"]]
+        else:  # batch_strings
+            return texts  # Already in correct format: ["text1", "text2"]
+
+    def _extract_tokenizer_results(
+        self,
+        input_ids: List[List[int]],
+        token_type_ids: Optional[List[List[int]]],
+        input_format: str,
+        original_batch_size: int,
+    ) -> Union[
+        Tuple[List[int], Optional[List[int]]],
+        Tuple[List[List[int]], Optional[List[List[int]]]],
+    ]:
+        """Extract results from tokenizer output based on input format."""
+
+        # For single inputs (string or single cross-encoder pair), extract first element
+        if (
+            input_format in ["single_string", "cross_encoder_pairs"]
+            and original_batch_size == 1
+        ):
+            single_input_ids = input_ids[0] if input_ids else []
+            single_token_type_ids = token_type_ids[0] if token_type_ids else None
+            return single_input_ids, single_token_type_ids
+
+        # For true batches, return as-is
+        return input_ids, token_type_ids
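The three return values of _detect_input_format, for reference (illustrative calls; tm stands for an initialized TokenizerManager):

    tm._detect_input_format("Hello world", is_cross_encoder=False)       # "single_string"
    tm._detect_input_format(["Hello", "World"], is_cross_encoder=False)  # "batch_strings"
    tm._detect_input_format([["query", "doc"]], is_cross_encoder=True)   # "cross_encoder_pairs"
    # Pairs are only recognized when is_cross_encoder is set:
    tm._detect_input_format([["query", "doc"]], is_cross_encoder=False)  # "batch_strings"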
+    async def _tokenize_texts(
+        self, texts: Union[str, List[str]], is_cross_encoder: bool = False
+    ) -> Union[
+        Tuple[List[int], Optional[List[int]]],
+        Tuple[List[List[int]], Optional[List[List[int]]]],
+    ]:
+        """
+        Tokenize text(s) using the appropriate tokenizer strategy.
+
+        This method handles multiple input formats and chooses between the async
+        dynamic batch tokenizer (for single texts only) and the regular tokenizer.
+
+        Args:
+            texts: Text input in various formats:
+
+                Regular cases:
+                - Single string: "How are you?"
+                - Batch of strings: ["Hello", "World", "How are you?"]
+
+                Cross-encoder cases (sentence pairs for similarity/ranking):
+                - Single pair: [["query text", "document text"]]
+                - Multiple pairs: [["q1", "d1"], ["q2", "d2"], ["q3", "d3"]]
+
+            is_cross_encoder: Whether to return token_type_ids for cross-encoder models.
+                Enables proper handling of sentence pairs with segment IDs.
+
+        Returns:
+            Single input cases:
+                Tuple[List[int], Optional[List[int]]]: (input_ids, token_type_ids)
+                Example: ([101, 2129, 102], [0, 0, 0]) for single text
+                Example: ([101, 2129, 102, 4068, 102], [0, 0, 0, 1, 1]) for cross-encoder pair
+
+            Batch input cases:
+                Tuple[List[List[int]], Optional[List[List[int]]]]: (batch_input_ids, batch_token_type_ids)
+                Example: ([[101, 2129, 102], [101, 4068, 102]], None) for regular batch
+
+        Note: token_type_ids is None unless is_cross_encoder=True.
+        """
+        if not texts or self.tokenizer is None:
+            raise ValueError("texts cannot be empty and tokenizer must be initialized")
+
+        # Step 1: Detect input format and prepare for tokenization
+        input_format = self._detect_input_format(texts, is_cross_encoder)
+        tokenizer_input = self._prepare_tokenizer_input(texts, input_format)
+        original_batch_size = len(texts) if not isinstance(texts, str) else 1
+
+        # Step 2: Set up tokenizer arguments
+        tokenizer_kwargs = (
+            {"return_token_type_ids": is_cross_encoder} if is_cross_encoder else {}
+        )
+
+        # Step 3: Choose tokenization strategy
+        use_async_tokenizer = (
+            self.async_dynamic_batch_tokenizer is not None
+            and input_format == "single_string"
+        )
+
+        if use_async_tokenizer:
+            logger.debug("Using async dynamic batch tokenizer for single text")
+            result = await self.async_dynamic_batch_tokenizer.encode(
+                tokenizer_input[0], **tokenizer_kwargs
+            )
+            # Convert to batch format for consistency
+            input_ids = [result["input_ids"]]
+            token_type_ids = (
+                [result["token_type_ids"]]
+                if is_cross_encoder and result.get("token_type_ids")
+                else None
+            )
+        else:
+            logger.debug(f"Using regular tokenizer for {len(tokenizer_input)} inputs")
+            encoded = self.tokenizer(tokenizer_input, **tokenizer_kwargs)
+            input_ids = encoded["input_ids"]
+            token_type_ids = encoded.get("token_type_ids") if is_cross_encoder else None
+
+        # Step 4: Extract results based on input format
+        return self._extract_tokenizer_results(
+            input_ids, token_type_ids, input_format, original_batch_size
+        )
+
     async def _tokenize_one_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
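The shape contract of _tokenize_texts, condensed (illustrative token ids; tm as above):

    # Single string -> flat ids, no token_type_ids
    ids, tt = await tm._tokenize_texts("How are you?")
    # e.g. ([101, 2129, 102], None)

    # Batch of strings -> list of id lists
    ids, tt = await tm._tokenize_texts(["Hello", "World"])
    # e.g. ([[101, 7592, 102], [101, 2088, 102]], None)

    # Single cross-encoder pair -> flat ids plus segment ids
    ids, tt = await tm._tokenize_texts([["query", "doc"]], is_cross_encoder=True)
    # e.g. ([101, 2129, 102, 4068, 102], [0, 0, 0, 1, 1])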
@@ -400,14 +551,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 "accept text prompts. Please provide input_ids or re-initialize "
                 "the engine with skip_tokenizer_init=False."
             )

-        encoded = self.tokenizer(
-            input_text, return_token_type_ids=is_cross_encoder_request
-        )
-        input_ids = encoded["input_ids"]
-        if is_cross_encoder_request:
-            input_ids = encoded["input_ids"][0]
-            token_type_ids = encoded.get("token_type_ids", [None])[0]
+        input_ids, token_type_ids = await self._tokenize_texts(
+            input_text, is_cross_encoder_request
+        )

         if self.mm_processor and obj.contains_mm_input():
             if not isinstance(obj.image_data, list):
@@ -582,17 +729,27 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             requests = [obj[i] for i in range(batch_size)]
             texts = [req.text for req in requests]

-            # Batch tokenize all texts
-            encoded = self.tokenizer(texts)
-            input_ids_list = encoded["input_ids"]
+            # Check if any request is a cross-encoder request
+            is_cross_encoder_request = any(
+                isinstance(req, EmbeddingReqInput) and req.is_cross_encoder_request
+                for req in requests
+            )
+
+            # Batch tokenize all texts using the unified method
+            input_ids_list, token_type_ids_list = await self._tokenize_texts(
+                texts, is_cross_encoder_request
+            )

             # Process all requests
             tokenized_objs = []
             for i, req in enumerate(requests):
                 self._validate_one_request(obj[i], input_ids_list[i])
+                token_type_ids = (
+                    token_type_ids_list[i] if token_type_ids_list is not None else None
+                )
                 tokenized_objs.append(
                     self._create_tokenized_object(
-                        req, req.text, input_ids_list[i], None, None
+                        req, req.text, input_ids_list[i], None, None, token_type_ids
                     )
                 )
             logger.debug(f"Completed batch processing for {batch_size} requests")
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 897c732b5..2e2fc458b 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -373,6 +373,11 @@ class ServerArgs:
     scheduler_recv_interval: int = 1
     numa_node: Optional[List[int]] = None

+    # Dynamic batch tokenizer
+    enable_dynamic_batch_tokenizer: bool = False
+    dynamic_batch_tokenizer_batch_size: int = 32
+    dynamic_batch_tokenizer_batch_timeout: float = 0.002
+
     # Debug tensor dumps
     debug_tensor_dump_output_folder: Optional[str] = None
     debug_tensor_dump_input_file: Optional[str] = None
@@ -874,6 +879,13 @@ class ServerArgs:
             self.disable_cuda_graph = True
             logger.warning("Cuda graph is disabled for prefill server")

+        # Validation: prevent both tokenizer batching features from being enabled
+        if self.enable_tokenizer_batch_encode and self.enable_dynamic_batch_tokenizer:
+            raise ValueError(
+                "Cannot enable both --enable-tokenizer-batch-encode and --enable-dynamic-batch-tokenizer. "
+                "Please choose one tokenizer batching approach."
+            )
+
         # Propagate env vars
         os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
             "1" if self.enable_torch_compile else "0"
         )
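The mutual-exclusion check above fails fast when both batching modes are requested. A sketch of the failure mode, assuming the check runs while the ServerArgs instance is being finalized (as the surrounding hunk suggests; the model path is a placeholder):

    from sglang.srt.server_args import ServerArgs

    args = ServerArgs(
        model_path="dummy-model",  # placeholder
        enable_tokenizer_batch_encode=True,
        enable_dynamic_batch_tokenizer=True,
    )
    # -> ValueError: Cannot enable both --enable-tokenizer-batch-encode and
    #    --enable-dynamic-batch-tokenizer. Please choose one tokenizer batching approach.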
batch size > 1).", ) + parser.add_argument( + "--enable-dynamic-batch-tokenizer", + action="store_true", + help="Enable async dynamic batch tokenizer for improved performance when multiple requests arrive concurrently.", + ) + parser.add_argument( + "--dynamic-batch-tokenizer-batch-size", + type=int, + default=ServerArgs.dynamic_batch_tokenizer_batch_size, + help="[Only used if --enable-dynamic-batch-tokenizer is set] Maximum batch size for dynamic batch tokenizer.", + ) + parser.add_argument( + "--dynamic-batch-tokenizer-batch-timeout", + type=float, + default=ServerArgs.dynamic_batch_tokenizer_batch_timeout, + help="[Only used if --enable-dynamic-batch-tokenizer is set] Timeout in seconds for batching tokenization requests.", + ) # PD disaggregation parser.add_argument( diff --git a/test/srt/test_async_dynamic_batch_tokenizer.py b/test/srt/test_async_dynamic_batch_tokenizer.py new file mode 100644 index 000000000..930e23e54 --- /dev/null +++ b/test/srt/test_async_dynamic_batch_tokenizer.py @@ -0,0 +1,295 @@ +""" +Unit tests for AsyncDynamicbatchTokenizer. + +Tests the async dynamic batching functionality for tokenization, +including batch efficiency, timeout handling, and error cases. +""" + +import asyncio +import logging +import time +from unittest.mock import AsyncMock, Mock, patch + +import pytest +from transformers import AutoTokenizer + +from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchTokenizer + + +class TestAsyncDynamicbatchTokenizer: + """Test suite for AsyncDynamicbatchTokenizer.""" + + @pytest.fixture + def mock_tokenizer(self): + """Create a mock tokenizer that behaves like HuggingFace tokenizer.""" + + def mock_encode(texts, **kwargs): + is_single = isinstance(texts, str) + if is_single: + texts = [texts] + + # Simulate tokenization - convert text to mock token ids + input_ids = [] + token_type_ids = [] + + for text in texts: + # Simple mock: text length determines number of tokens + tokens = [i for i in range(len(text.split()))] + input_ids.append(tokens) + + if kwargs.get("return_token_type_ids", False): + token_type_ids.append([0] * len(tokens)) + + result = {"input_ids": input_ids} + if kwargs.get("return_token_type_ids", False): + result["token_type_ids"] = token_type_ids + + # For single inputs, return individual result (not wrapped in a list) + if is_single: + result = {"input_ids": input_ids[0]} + if kwargs.get("return_token_type_ids", False): + result["token_type_ids"] = token_type_ids[0] + + # Create a proper BatchEncoding-like object that supports dict operations + class MockBatchEncoding(dict): + def __init__(self, data): + super().__init__(data) + for key, value in data.items(): + setattr(self, key, value) + + return MockBatchEncoding(result) + + # Return the function directly - the AsyncDynamicbatchTokenizer will call it + return mock_encode + + @pytest.fixture + def async_tokenizer(self, mock_tokenizer): + """Create AsyncDynamicbatchTokenizer instance.""" + return AsyncDynamicbatchTokenizer( + tokenizer=mock_tokenizer, max_batch_size=4, batch_wait_timeout_s=0.01 + ) + + @pytest.mark.asyncio + async def test_single_request(self, async_tokenizer): + """Test tokenizing a single request.""" + text = "hello world" + result = await async_tokenizer.encode(text) + + assert "input_ids" in result + assert result["input_ids"] == [0, 1] # 2 words -> 2 tokens + + @pytest.mark.asyncio + async def test_single_request_with_token_type_ids(self, async_tokenizer): + """Test tokenizing with token type IDs.""" + text = "hello world" + result = await 
diff --git a/test/srt/test_async_dynamic_batch_tokenizer.py b/test/srt/test_async_dynamic_batch_tokenizer.py
new file mode 100644
index 000000000..930e23e54
--- /dev/null
+++ b/test/srt/test_async_dynamic_batch_tokenizer.py
@@ -0,0 +1,295 @@
+"""
+Unit tests for AsyncDynamicbatchTokenizer.
+
+Tests the async dynamic batching functionality for tokenization,
+including batch efficiency, timeout handling, and error cases.
+"""
+
+import asyncio
+import logging
+import time
+from unittest.mock import Mock
+
+import pytest
+from transformers import AutoTokenizer
+
+from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchTokenizer
+
+
+class TestAsyncDynamicbatchTokenizer:
+    """Test suite for AsyncDynamicbatchTokenizer."""
+
+    @pytest.fixture
+    def mock_tokenizer(self):
+        """Create a mock tokenizer that behaves like a HuggingFace tokenizer."""
+
+        def mock_encode(texts, **kwargs):
+            is_single = isinstance(texts, str)
+            if is_single:
+                texts = [texts]
+
+            # Simulate tokenization - convert text to mock token ids
+            input_ids = []
+            token_type_ids = []
+
+            for text in texts:
+                # Simple mock: the number of words determines the number of tokens
+                tokens = list(range(len(text.split())))
+                input_ids.append(tokens)
+
+                if kwargs.get("return_token_type_ids", False):
+                    token_type_ids.append([0] * len(tokens))
+
+            result = {"input_ids": input_ids}
+            if kwargs.get("return_token_type_ids", False):
+                result["token_type_ids"] = token_type_ids
+
+            # For single inputs, return the individual result (not wrapped in a list)
+            if is_single:
+                result = {"input_ids": input_ids[0]}
+                if kwargs.get("return_token_type_ids", False):
+                    result["token_type_ids"] = token_type_ids[0]
+
+            # Create a proper BatchEncoding-like object that supports dict operations
+            class MockBatchEncoding(dict):
+                def __init__(self, data):
+                    super().__init__(data)
+                    for key, value in data.items():
+                        setattr(self, key, value)
+
+            return MockBatchEncoding(result)
+
+        # Return the function directly - the AsyncDynamicbatchTokenizer will call it
+        return mock_encode
+
+    @pytest.fixture
+    def async_tokenizer(self, mock_tokenizer):
+        """Create an AsyncDynamicbatchTokenizer instance."""
+        return AsyncDynamicbatchTokenizer(
+            tokenizer=mock_tokenizer, max_batch_size=4, batch_wait_timeout_s=0.01
+        )
+
+    @pytest.mark.asyncio
+    async def test_single_request(self, async_tokenizer):
+        """Test tokenizing a single request."""
+        text = "hello world"
+        result = await async_tokenizer.encode(text)
+
+        assert "input_ids" in result
+        assert result["input_ids"] == [0, 1]  # 2 words -> 2 tokens
+
+    @pytest.mark.asyncio
+    async def test_single_request_with_token_type_ids(self, async_tokenizer):
+        """Test tokenizing with token type IDs."""
+        text = "hello world"
+        result = await async_tokenizer.encode(text, return_token_type_ids=True)
+
+        assert "input_ids" in result
+        assert "token_type_ids" in result
+        assert result["input_ids"] == [0, 1]
+        assert result["token_type_ids"] == [0, 0]
+
+    @pytest.mark.asyncio
+    async def test_concurrent_requests_same_kwargs(self, async_tokenizer):
+        """Test that concurrent requests with the same kwargs get batched."""
+        texts = ["hello world", "how are you", "fine thanks", "good morning"]
+
+        # Start all requests concurrently
+        tasks = [async_tokenizer.encode(text) for text in texts]
+        results = await asyncio.gather(*tasks)
+
+        # Verify all results
+        assert len(results) == 4
+        for i, result in enumerate(results):
+            assert "input_ids" in result
+            expected_tokens = list(range(len(texts[i].split())))
+            assert result["input_ids"] == expected_tokens
+
+    @pytest.mark.asyncio
+    async def test_concurrent_requests_different_kwargs(self, async_tokenizer):
+        """Test that requests with different kwargs are processed individually."""
+        text1 = "hello world"
+        text2 = "how are you"
+
+        # One with token_type_ids, one without
+        task1 = async_tokenizer.encode(text1, return_token_type_ids=True)
+        task2 = async_tokenizer.encode(text2)
+
+        result1, result2 = await asyncio.gather(task1, task2)
+
+        # First result should have token_type_ids
+        assert "input_ids" in result1
+        assert "token_type_ids" in result1
+        assert result1["input_ids"] == [0, 1]
+        assert result1["token_type_ids"] == [0, 0]
+
+        # Second result should not have token_type_ids
+        assert "input_ids" in result2
+        assert "token_type_ids" not in result2
+        assert result2["input_ids"] == [0, 1, 2]
+
+    @pytest.mark.asyncio
+    async def test_batch_timeout(self, async_tokenizer):
+        """Test that batching respects the timeout."""
+        # Send first request
+        task1 = asyncio.create_task(async_tokenizer.encode("hello world"))
+
+        # Wait longer than the batch timeout
+        await asyncio.sleep(0.02)  # Longer than the 0.01s timeout
+
+        # Send second request
+        task2 = asyncio.create_task(async_tokenizer.encode("how are you"))
+
+        results = await asyncio.gather(task1, task2)
+
+        # Both should complete successfully
+        assert len(results) == 2
+        assert results[0]["input_ids"] == [0, 1]
+        assert results[1]["input_ids"] == [0, 1, 2]
+
+    @pytest.mark.asyncio
+    async def test_max_batch_size_limit(self, async_tokenizer):
+        """Test that batching respects max_batch_size."""
+        # Send more requests than max_batch_size (4)
+        texts = [f"text {i}" for i in range(6)]
+        tasks = [async_tokenizer.encode(text) for text in texts]
+
+        results = await asyncio.gather(*tasks)
+
+        # All should complete successfully
+        assert len(results) == 6
+        for i, result in enumerate(results):
+            assert "input_ids" in result
+            assert result["input_ids"] == [0, 1]  # "text i" -> 2 tokens
+
+    @pytest.mark.asyncio
+    async def test_callable_interface(self, async_tokenizer):
+        """Test that the tokenizer is callable."""
+        text = "hello world"
+        result = await async_tokenizer(text)
+
+        assert "input_ids" in result
+        assert result["input_ids"] == [0, 1]
+
+    @pytest.mark.asyncio
+    async def test_lazy_initialization(self, mock_tokenizer):
+        """Test that initialization happens lazily."""
+        tokenizer = AsyncDynamicbatchTokenizer(mock_tokenizer)
+
+        # Should not be initialized yet
+        assert not tokenizer._initialized
+
+        # First encode should initialize
+        await tokenizer.encode("hello")
+
+        # Should now be initialized
+        assert tokenizer._initialized
+
+    @pytest.mark.asyncio
+    async def test_error_handling_in_tokenizer(self, mock_tokenizer):
+        """Test error handling when the tokenizer fails."""
+
+        # Create a new async tokenizer with a failing tokenizer
+        def failing_tokenizer(*args, **kwargs):
+            raise ValueError("Tokenizer error")
+
+        async_tokenizer = AsyncDynamicbatchTokenizer(
+            tokenizer=failing_tokenizer, max_batch_size=4, batch_wait_timeout_s=0.01
+        )
+
+        with pytest.raises(ValueError, match="Tokenizer error"):
+            await async_tokenizer.encode("hello world")
+
+    @pytest.mark.asyncio
+    async def test_batch_processing_logs(self, async_tokenizer, caplog):
+        """Test that batch processing logs are generated."""
+        caplog.set_level(logging.DEBUG)
+
+        # Send multiple requests to trigger batching
+        tasks = [
+            async_tokenizer.encode("hello world"),
+            async_tokenizer.encode("how are you"),
+        ]
+
+        await asyncio.gather(*tasks)
+
+        # Should have a batch processing log
+        assert any(
+            "Processing dynamic batch of size" in record.message
+            for record in caplog.records
+        )
+
+    @pytest.mark.asyncio
+    async def test_empty_queue_immediate_processing(self, async_tokenizer):
+        """Test that single requests are processed immediately when the queue is empty."""
+        start_time = time.time()
+        result = await async_tokenizer.encode("hello world")
+        end_time = time.time()
+
+        # Should complete quickly (much less than the batch timeout)
+        assert end_time - start_time < 0.005  # 5ms should be plenty
+        assert result["input_ids"] == [0, 1]
+
+    @pytest.mark.asyncio
+    async def test_real_tokenizer_integration(self):
+        """Test with a real HuggingFace tokenizer."""
+        try:
+            # Use a small, fast tokenizer for testing
+            real_tokenizer = AutoTokenizer.from_pretrained("gpt2")
+            async_tokenizer = AsyncDynamicbatchTokenizer(
+                tokenizer=real_tokenizer, max_batch_size=2, batch_wait_timeout_s=0.01
+            )
+
+            text = "Hello, world!"
+            result = await async_tokenizer.encode(text)
+
+            # Should get actual token IDs
+            assert "input_ids" in result
+            assert isinstance(result["input_ids"], list)
+            assert len(result["input_ids"]) > 0
+            assert all(isinstance(token_id, int) for token_id in result["input_ids"])
+
+        except Exception as e:
+            pytest.skip(f"Real tokenizer test skipped: {e}")
+
+    @pytest.mark.asyncio
+    async def test_concurrent_mixed_requests(self, async_tokenizer):
+        """Test mixing single and batched requests."""
+        # Start some requests
+        task1 = asyncio.create_task(async_tokenizer.encode("hello"))
+        task2 = asyncio.create_task(async_tokenizer.encode("world"))
+
+        # Wait a bit
+        await asyncio.sleep(0.005)
+
+        # Start more requests
+        task3 = asyncio.create_task(async_tokenizer.encode("how are"))
+        task4 = asyncio.create_task(async_tokenizer.encode("you doing"))
+
+        results = await asyncio.gather(task1, task2, task3, task4)
+
+        # All should complete successfully
+        assert len(results) == 4
+        for result in results:
+            assert "input_ids" in result
+            assert isinstance(result["input_ids"], list)
+
+    def test_cleanup_on_destruction(self, mock_tokenizer):
+        """Test that resources are cleaned up properly."""
+        tokenizer = AsyncDynamicbatchTokenizer(mock_tokenizer)
+
+        # Mock the executor and task
+        tokenizer._executor = Mock()
+        tokenizer._batcher_task = Mock()
+        tokenizer._batcher_task.done.return_value = False
+
+        # Call destructor
+        tokenizer.__del__()
+
+        # Should cancel the task and shut down the executor
+        tokenizer._batcher_task.cancel.assert_called_once()
+        tokenizer._executor.shutdown.assert_called_once_with(wait=False)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
diff --git a/test/srt/test_tokenizer_manager.py b/test/srt/test_tokenizer_manager.py
new file mode 100644
index 000000000..a12987279
--- /dev/null
+++ b/test/srt/test_tokenizer_manager.py
@@ -0,0 +1,378 @@
+"""
+Unit tests for TokenizerManager helper methods.
+
+This tests the refactored tokenization functionality including input format detection,
+tokenizer input preparation, and result extraction logic.
+
+Usage:
+python3 -m unittest test_tokenizer_manager.TestInputFormatDetection
+python3 -m unittest test_tokenizer_manager.TestTokenizerInputPreparation
+python3 -m unittest test_tokenizer_manager.TestTokenizerResultExtraction
+python3 -m unittest test_tokenizer_manager.TestTokenizerManagerIntegration
+"""
+
+import unittest
+from unittest.mock import Mock, patch
+
+from sglang.srt.managers.tokenizer_manager import TokenizerManager
+from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+
+
+class TestInputFormatDetection(unittest.TestCase):
+    """Test cases for the _detect_input_format method."""
+
+    def setUp(self):
+        """Set up test fixtures."""
+        with patch("sglang.srt.utils.get_device", return_value="cpu"):
+            self.server_args = ServerArgs(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
+            self.port_args = PortArgs.init_new(self.server_args)
+
+        with patch("zmq.asyncio.Context"), patch(
+            "sglang.srt.utils.get_zmq_socket"
+        ), patch("sglang.srt.hf_transformers_utils.get_tokenizer") as mock_tokenizer:
+            mock_tokenizer.return_value = Mock(vocab_size=32000)
+            self.tokenizer_manager = TokenizerManager(self.server_args, self.port_args)
+
+    def test_detect_single_string(self):
+        """Test detection of single string input."""
+        text = "Hello world"
+        result = self.tokenizer_manager._detect_input_format(
+            text, is_cross_encoder=False
+        )
+        self.assertEqual(result, "single_string")
+
+    def test_detect_single_string_cross_encoder_enabled(self):
+        """Test that a single string with cross_encoder enabled still returns single_string."""
+        text = "Hello world"
+        result = self.tokenizer_manager._detect_input_format(
+            text, is_cross_encoder=True
+        )
+        self.assertEqual(result, "single_string")
+
+    def test_detect_batch_strings(self):
+        """Test detection of batch string inputs."""
+        texts = ["Hello", "World", "How are you?"]
+        result = self.tokenizer_manager._detect_input_format(
+            texts, is_cross_encoder=False
+        )
+        self.assertEqual(result, "batch_strings")
+
+    def test_detect_batch_strings_cross_encoder_enabled(self):
+        """Test that batch strings with cross_encoder enabled still return batch_strings."""
+        texts = ["Hello", "World"]
+        result = self.tokenizer_manager._detect_input_format(
+            texts, is_cross_encoder=True
+        )
+        self.assertEqual(result, "batch_strings")
+
+    def test_detect_cross_encoder_single_pair(self):
+        """Test detection of a single cross-encoder pair."""
+        texts = [["query text", "document text"]]
+        result = self.tokenizer_manager._detect_input_format(
+            texts, is_cross_encoder=True
+        )
+        self.assertEqual(result, "cross_encoder_pairs")
+
+    def test_detect_cross_encoder_multiple_pairs(self):
+        """Test detection of multiple cross-encoder pairs."""
+        texts = [["q1", "d1"], ["q2", "d2"], ["q3", "d3"]]
+        result = self.tokenizer_manager._detect_input_format(
+            texts, is_cross_encoder=True
+        )
+        self.assertEqual(result, "cross_encoder_pairs")
+
+    def test_detect_cross_encoder_disabled_with_pairs(self):
+        """Test that pairs with cross_encoder disabled return batch_strings."""
+        texts = [["query", "document"]]
+        result = self.tokenizer_manager._detect_input_format(
+            texts, is_cross_encoder=False
+        )
+        self.assertEqual(result, "batch_strings")
+
+    def test_detect_empty_list(self):
+        """Test detection with an empty list."""
+        texts = []
+        result = self.tokenizer_manager._detect_input_format(
+            texts, is_cross_encoder=True
+        )
+        self.assertEqual(result, "batch_strings")
+
+    def test_detect_malformed_cross_encoder_pairs(self):
+        """Test malformed cross-encoder pairs (not length 2)."""
+        texts = [["query only"]]  # Single element, not a pair
+        result = self.tokenizer_manager._detect_input_format(
+            texts, is_cross_encoder=True
+        )
+        self.assertEqual(result, "batch_strings")
+
+        texts = [["query", "doc", "extra"]]  # Three elements, not a pair
+        result = self.tokenizer_manager._detect_input_format(
+            texts, is_cross_encoder=True
+        )
+        self.assertEqual(result, "batch_strings")
+
+
+class TestTokenizerInputPreparation(unittest.TestCase):
+    """Test cases for the _prepare_tokenizer_input method."""
+
+    def setUp(self):
+        """Set up test fixtures."""
+        with patch("sglang.srt.utils.get_device", return_value="cpu"):
+            self.server_args = ServerArgs(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
+            self.port_args = PortArgs.init_new(self.server_args)
+
+        with patch("zmq.asyncio.Context"), patch(
+            "sglang.srt.utils.get_zmq_socket"
+        ), patch("sglang.srt.hf_transformers_utils.get_tokenizer") as mock_tokenizer:
+            mock_tokenizer.return_value = Mock(vocab_size=32000)
+            self.tokenizer_manager = TokenizerManager(self.server_args, self.port_args)
+
+    def test_prepare_single_string_input(self):
+        """Test preparation of single string input."""
+        text = "Hello world"
+        result = self.tokenizer_manager._prepare_tokenizer_input(text, "single_string")
+        self.assertEqual(result, ["Hello world"])
+
+    def test_prepare_batch_strings_input(self):
+        """Test preparation of batch strings input."""
+        texts = ["Hello", "World", "Test"]
+        result = self.tokenizer_manager._prepare_tokenizer_input(
+            texts, "batch_strings"
+        )
+        self.assertEqual(result, ["Hello", "World", "Test"])
+
+    def test_prepare_cross_encoder_pairs_input(self):
+        """Test preparation of cross-encoder pairs input."""
+        texts = [["query1", "doc1"], ["query2", "doc2"]]
+        result = self.tokenizer_manager._prepare_tokenizer_input(
+            texts, "cross_encoder_pairs"
+        )
+        self.assertEqual(result, [["query1", "doc1"], ["query2", "doc2"]])
+
+    def test_prepare_cross_encoder_single_pair_input(self):
+        """Test preparation of a single cross-encoder pair."""
+        texts = [["query text", "document text"]]
+        result = self.tokenizer_manager._prepare_tokenizer_input(
+            texts, "cross_encoder_pairs"
+        )
+        self.assertEqual(result, [["query text", "document text"]])
+
+    def test_prepare_unknown_input_format(self):
+        """Test that an unknown input format falls back to returning the input as-is."""
+        texts = ["test"]
+        result = self.tokenizer_manager._prepare_tokenizer_input(
+            texts, "unknown_format"
+        )
+        self.assertEqual(result, ["test"])
+
+
+class TestTokenizerResultExtraction(unittest.TestCase):
+    """Test cases for the _extract_tokenizer_results method."""
+
+    def setUp(self):
+        """Set up test fixtures."""
+        with patch("sglang.srt.utils.get_device", return_value="cpu"):
+            self.server_args = ServerArgs(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
+            self.port_args = PortArgs.init_new(self.server_args)
+
+        with patch("zmq.asyncio.Context"), patch(
+            "sglang.srt.utils.get_zmq_socket"
+        ), patch("sglang.srt.hf_transformers_utils.get_tokenizer") as mock_tokenizer:
+            mock_tokenizer.return_value = Mock(vocab_size=32000)
+            self.tokenizer_manager = TokenizerManager(self.server_args, self.port_args)
+
+    def test_extract_single_string_results(self):
+        """Test extraction for single string input."""
+        input_ids = [[101, 2129, 102]]
+        token_type_ids = [[0, 0, 0]]
+
+        result_input_ids, result_token_type_ids = (
+            self.tokenizer_manager._extract_tokenizer_results(
+                input_ids, token_type_ids, "single_string", original_batch_size=1
+            )
+        )
+
+        self.assertEqual(result_input_ids, [101, 2129, 102])
+        self.assertEqual(result_token_type_ids, [0, 0, 0])
+
+    def test_extract_single_cross_encoder_results(self):
+        """Test extraction for a single cross-encoder pair."""
+        input_ids = [[101, 2129, 102, 4068, 102]]
+        token_type_ids = [[0, 0, 0, 1, 1]]
+
+        result_input_ids, result_token_type_ids = (
+            self.tokenizer_manager._extract_tokenizer_results(
+                input_ids, token_type_ids, "cross_encoder_pairs", original_batch_size=1
+            )
+        )
+
+        self.assertEqual(result_input_ids, [101, 2129, 102, 4068, 102])
+        self.assertEqual(result_token_type_ids, [0, 0, 0, 1, 1])
+
+    def test_extract_batch_results(self):
+        """Test extraction for batch inputs."""
+        input_ids = [[101, 2129, 102], [101, 4068, 102]]
+        token_type_ids = [[0, 0, 0], [0, 0, 0]]
+
+        result_input_ids, result_token_type_ids = (
+            self.tokenizer_manager._extract_tokenizer_results(
+                input_ids, token_type_ids, "batch_strings", original_batch_size=2
+            )
+        )
+
+        self.assertEqual(result_input_ids, [[101, 2129, 102], [101, 4068, 102]])
+        self.assertEqual(result_token_type_ids, [[0, 0, 0], [0, 0, 0]])
+
+    def test_extract_multiple_cross_encoder_results(self):
+        """Test extraction for multiple cross-encoder pairs."""
+        input_ids = [[101, 2129, 102, 4068, 102], [101, 7592, 102, 2088, 102]]
+        token_type_ids = [[0, 0, 0, 1, 1], [0, 0, 0, 1, 1]]
+
+        result_input_ids, result_token_type_ids = (
+            self.tokenizer_manager._extract_tokenizer_results(
+                input_ids, token_type_ids, "cross_encoder_pairs", original_batch_size=2
+            )
+        )
+
+        self.assertEqual(
+            result_input_ids, [[101, 2129, 102, 4068, 102], [101, 7592, 102, 2088, 102]]
+        )
+        self.assertEqual(result_token_type_ids, [[0, 0, 0, 1, 1], [0, 0, 0, 1, 1]])
+
+    def test_extract_empty_results(self):
+        """Test extraction with empty results."""
+        input_ids = []
+        token_type_ids = None
+
+        result_input_ids, result_token_type_ids = (
+            self.tokenizer_manager._extract_tokenizer_results(
+                input_ids, token_type_ids, "single_string", original_batch_size=1
+            )
+        )
+
+        self.assertEqual(result_input_ids, [])
+        self.assertIsNone(result_token_type_ids)
+
+    def test_extract_with_none_token_type_ids(self):
+        """Test extraction when token_type_ids is None."""
+        input_ids = [[101, 2129, 102]]
+        token_type_ids = None
+
+        result_input_ids, result_token_type_ids = (
+            self.tokenizer_manager._extract_tokenizer_results(
+                input_ids, token_type_ids, "single_string", original_batch_size=1
+            )
+        )
+
+        self.assertEqual(result_input_ids, [101, 2129, 102])
+        self.assertIsNone(result_token_type_ids)
+
+
+class TestTokenizerManagerIntegration(unittest.TestCase):
+    """Integration tests combining multiple helper methods."""
+
+    def setUp(self):
+        """Set up test fixtures."""
+        with patch("sglang.srt.utils.get_device", return_value="cpu"):
+            self.server_args = ServerArgs(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
+            self.port_args = PortArgs.init_new(self.server_args)
+
+        with patch("zmq.asyncio.Context"), patch(
+            "sglang.srt.utils.get_zmq_socket"
+        ), patch("sglang.srt.hf_transformers_utils.get_tokenizer") as mock_tokenizer:
+            mock_tokenizer.return_value = Mock(vocab_size=32000)
+            self.tokenizer_manager = TokenizerManager(self.server_args, self.port_args)
+
+    def test_full_workflow_single_string(self):
+        """Test the complete workflow for single string input."""
+        text = "Hello world"
+
+        # Step 1: Detect format
+        input_format = self.tokenizer_manager._detect_input_format(
+            text, is_cross_encoder=False
+        )
+        self.assertEqual(input_format, "single_string")
+
+        # Step 2: Prepare input
+        tokenizer_input = self.tokenizer_manager._prepare_tokenizer_input(
+            text, input_format
+        )
+        self.assertEqual(tokenizer_input, ["Hello world"])
+
+        # Step 3: Extract results (simulated tokenizer output)
+        mock_input_ids = [[101, 2129, 4248, 102]]
+        mock_token_type_ids = None
+
+        result_input_ids, result_token_type_ids = (
+            self.tokenizer_manager._extract_tokenizer_results(
+                mock_input_ids, mock_token_type_ids, input_format, original_batch_size=1
+            )
+        )
+
+        self.assertEqual(result_input_ids, [101, 2129, 4248, 102])
+        self.assertIsNone(result_token_type_ids)
+
+    def test_full_workflow_cross_encoder_pairs(self):
+        """Test the complete workflow for cross-encoder pairs."""
+        texts = [
+            ["How many people live in Berlin?", "Berlin is well known for its museums."]
+        ]
+
+        # Step 1: Detect format
+        input_format = self.tokenizer_manager._detect_input_format(
+            texts, is_cross_encoder=True
+        )
+        self.assertEqual(input_format, "cross_encoder_pairs")
+
+        # Step 2: Prepare input
+        tokenizer_input = self.tokenizer_manager._prepare_tokenizer_input(
+            texts, input_format
+        )
+        self.assertEqual(tokenizer_input, texts)
+
+        # Step 3: Extract results (simulated tokenizer output for cross-encoder)
+        mock_input_ids = [[101, 2129, 2116, 102, 4068, 2003, 102]]
+        mock_token_type_ids = [[0, 0, 0, 0, 1, 1, 1]]
+
+        result_input_ids, result_token_type_ids = (
+            self.tokenizer_manager._extract_tokenizer_results(
+                mock_input_ids, mock_token_type_ids, input_format, original_batch_size=1
+            )
+        )
+
+        self.assertEqual(result_input_ids, [101, 2129, 2116, 102, 4068, 2003, 102])
+        self.assertEqual(result_token_type_ids, [0, 0, 0, 0, 1, 1, 1])
+
+    def test_full_workflow_batch_strings(self):
+        """Test the complete workflow for batch strings."""
+        texts = ["Hello", "World", "Test"]
+
+        # Step 1: Detect format
+        input_format = self.tokenizer_manager._detect_input_format(
+            texts, is_cross_encoder=False
+        )
+        self.assertEqual(input_format, "batch_strings")
+
+        # Step 2: Prepare input
+        tokenizer_input = self.tokenizer_manager._prepare_tokenizer_input(
+            texts, input_format
+        )
+        self.assertEqual(tokenizer_input, ["Hello", "World", "Test"])
+
+        # Step 3: Extract results (simulated tokenizer output)
+        mock_input_ids = [[101, 7592, 102], [101, 2088, 102], [101, 2774, 102]]
+        mock_token_type_ids = None
+
+        result_input_ids, result_token_type_ids = (
+            self.tokenizer_manager._extract_tokenizer_results(
+                mock_input_ids, mock_token_type_ids, input_format, original_batch_size=3
+            )
+        )
+
+        self.assertEqual(
+            result_input_ids, [[101, 7592, 102], [101, 2088, 102], [101, 2774, 102]]
+        )
+        self.assertIsNone(result_token_type_ids)
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)