[Performance] Dynamic Batch Tokenizer (#9382)
This commit is contained in:
committed by
GitHub
parent
eca59f96c3
commit
94d0f656fb
170
python/sglang/srt/managers/async_dynamic_batch_tokenizer.py
Normal file
170
python/sglang/srt/managers/async_dynamic_batch_tokenizer.py
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
"""
|
||||||
|
Asynchronous dynamic batch tokenizer for SGLang.
|
||||||
|
|
||||||
|
This module provides an async tokenizer with dynamic batching capabilities
|
||||||
|
to reduce tokenization overhead when multiple requests arrive concurrently.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from functools import partial
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncDynamicbatchTokenizer:
    """Asynchronous tokenizer with dynamic batching for single string prompts.

    Dynamically batches pending encode requests from a queue to reduce
    per-call tokenization overhead. Only handles single string prompts -
    regular batch processing of multiple strings per request should be
    handled at a higher level.

    A single-thread ThreadPoolExecutor is used so the event loop stays
    responsive while the blocking tokenizer call runs.

    Note: Uses lazy initialization for asyncio components because this class
    is instantiated in TokenizerManager.__init__() before the event loop starts.
    """

    def __init__(
        self,
        tokenizer,
        max_batch_size: int = 32,
        batch_wait_timeout_s: float = 0.002,
    ) -> None:
        self.tokenizer = tokenizer
        self.max_batch_size = max_batch_size
        self.batch_wait_timeout_s = batch_wait_timeout_s

        # Single queue for all encode requests - initialized lazily because
        # no event loop may exist yet at construction time.
        self._queue: Optional[asyncio.Queue] = None
        self._batcher_task: Optional[asyncio.Task] = None

        # Single-thread executor so blocking tokenizer calls never stall
        # the event loop.
        self._executor = ThreadPoolExecutor(max_workers=1)
        self._initialized = False

    def _ensure_initialized(self) -> None:
        """Lazy initialization of event loop dependent components."""
        if not self._initialized:
            self._queue = asyncio.Queue()
            self._batcher_task = asyncio.create_task(self._dynamic_batch_loop())
            self._initialized = True

    async def __call__(self, prompt: str, **kwargs) -> Any:
        """Encode a single prompt (callable alias for :meth:`encode`)."""
        return await self.encode(prompt, **kwargs)

    async def encode(self, prompt: str, **kwargs) -> Any:
        """Encode a single prompt.

        The request is enqueued and resolved by the background batching loop;
        the returned awaitable completes with the tokenizer's output (or
        raises the tokenizer's exception).
        """
        self._ensure_initialized()
        result_future: asyncio.Future = asyncio.get_running_loop().create_future()
        await self._queue.put((prompt, kwargs, result_future))
        return await result_future

    async def _dynamic_batch_loop(self):
        """Dynamically batch incoming encode requests for efficiency."""
        while True:
            # Reset per-iteration state so the error path below knows which
            # futures were dequeued in this round.
            result_futures: List[asyncio.Future] = []
            try:
                # Block until the first request arrives.
                prompt, kwargs, result_future = await self._queue.get()

                prompts = [prompt]
                kwargs_list = [kwargs]
                result_futures = [result_future]

                # If the queue is empty, process the single request right away
                # (no added latency). Otherwise spend up to
                # batch_wait_timeout_s gathering more requests, capped at
                # max_batch_size.
                if not self._queue.empty():
                    start_time = asyncio.get_running_loop().time()
                    while len(prompts) < self.max_batch_size:
                        elapsed = asyncio.get_running_loop().time() - start_time
                        if elapsed >= self.batch_wait_timeout_s:
                            break

                        remaining_time = self.batch_wait_timeout_s - elapsed
                        try:
                            prompt, kwargs, result_future = await asyncio.wait_for(
                                self._queue.get(), remaining_time
                            )
                        except asyncio.TimeoutError:
                            break
                        prompts.append(prompt)
                        kwargs_list.append(kwargs)
                        result_futures.append(result_future)

                # Lazy %-style args: skip formatting unless DEBUG is enabled.
                logger.debug(
                    "AsyncDynamicbatchTokenizer: Processing dynamic batch of size %d",
                    len(prompts),
                )

                await self._process_dynamic_batch(prompts, kwargs_list, result_futures)

            except Exception as e:
                logger.exception("Error in dynamic batch loop: %s", e)
                # Fail every future dequeued in this iteration; otherwise the
                # callers awaiting them would hang forever. Then continue the
                # loop to serve subsequent requests.
                for fut in result_futures:
                    if not fut.done():
                        fut.set_exception(e)

    async def _process_dynamic_batch(
        self,
        prompts: List[str],
        kwargs_list: List[Dict],
        result_futures: List[asyncio.Future],
    ) -> None:
        """Process a dynamic batch of encode requests for single string prompts.

        Runs one batched tokenizer call when every request uses identical
        kwargs, otherwise falls back to per-request calls. Every future in
        ``result_futures`` is resolved, either with its result or with the
        raised exception.
        """
        # Keys within one dict are unique, so sorted(kw.items()) never
        # compares values; the stringified form is a safe equality key.
        can_batch = len(set(str(sorted(kw.items())) for kw in kwargs_list)) == 1
        kwargs = kwargs_list[0] if can_batch else None

        try:
            # If every request uses identical kwargs we can run a single
            # batch tokenizer call for a big speed-up.
            if can_batch and len(prompts) > 1:
                encode_fn = partial(self.tokenizer, prompts, **kwargs)
                results = await asyncio.get_running_loop().run_in_executor(
                    self._executor, encode_fn
                )

                for i, fut in enumerate(result_futures):
                    if not fut.done():
                        # Split the columnar batch output into one dict
                        # per request.
                        data = {k: v[i] for k, v in results.items()}
                        fut.set_result(data)
            else:
                # Process each request individually due to different kwargs.
                if len(prompts) > 1 and not can_batch:
                    logger.warning(
                        f"AsyncDynamicbatchTokenizer: Dynamic batching disabled for batch of {len(prompts)} "
                        f"requests due to differing kwargs. This reduces performance benefits. "
                        f"Consider using consistent tokenization parameters across requests."
                    )

                def encode_fn(prompts=prompts, kwargs_list=kwargs_list):
                    # Defaults bind the current batch so the executor thread
                    # sees a stable snapshot.
                    return [
                        self.tokenizer(p, **kw) for p, kw in zip(prompts, kwargs_list)
                    ]

                results = await asyncio.get_running_loop().run_in_executor(
                    self._executor, encode_fn
                )

                for fut, res in zip(result_futures, results):
                    if not fut.done():
                        fut.set_result(res)
        except Exception as e:
            logger.exception("Error in dynamic batch processing: %s", e)
            for fut in result_futures:
                if not fut.done():
                    fut.set_exception(e)

    def __del__(self):
        """Clean up background tasks and the worker thread."""
        if hasattr(self, "_batcher_task") and self._batcher_task:
            if not self._batcher_task.done():
                self._batcher_task.cancel()
        if hasattr(self, "_executor"):
            self._executor.shutdown(wait=False)
|
||||||
@@ -49,6 +49,7 @@ from sglang.srt.hf_transformers_utils import (
|
|||||||
get_tokenizer_from_processor,
|
get_tokenizer_from_processor,
|
||||||
)
|
)
|
||||||
from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
|
from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
|
||||||
|
from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchTokenizer
|
||||||
from sglang.srt.managers.disagg_service import start_disagg_service
|
from sglang.srt.managers.disagg_service import start_disagg_service
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
AbortReq,
|
AbortReq,
|
||||||
@@ -216,6 +217,18 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
trust_remote_code=server_args.trust_remote_code,
|
trust_remote_code=server_args.trust_remote_code,
|
||||||
revision=server_args.revision,
|
revision=server_args.revision,
|
||||||
)
|
)
|
||||||
|
# Initialize async dynamic batch tokenizer if enabled (common for both multimodal and non-multimodal)
|
||||||
|
if (
|
||||||
|
server_args.enable_dynamic_batch_tokenizer
|
||||||
|
and not server_args.skip_tokenizer_init
|
||||||
|
):
|
||||||
|
self.async_dynamic_batch_tokenizer = AsyncDynamicbatchTokenizer(
|
||||||
|
self.tokenizer,
|
||||||
|
max_batch_size=server_args.dynamic_batch_tokenizer_batch_size,
|
||||||
|
batch_wait_timeout_s=server_args.dynamic_batch_tokenizer_batch_timeout,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.async_dynamic_batch_tokenizer = None
|
||||||
|
|
||||||
# Init inter-process communication
|
# Init inter-process communication
|
||||||
context = zmq.asyncio.Context(2)
|
context = zmq.asyncio.Context(2)
|
||||||
@@ -370,6 +383,144 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
):
|
):
|
||||||
yield response
|
yield response
|
||||||
|
|
||||||
|
def _detect_input_format(
|
||||||
|
self, texts: Union[str, List[str]], is_cross_encoder: bool
|
||||||
|
) -> str:
|
||||||
|
"""Detect the format of input texts for proper tokenization handling.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- "single_string": Regular single text like "Hello world"
|
||||||
|
- "batch_strings": Regular batch like ["Hello", "World"]
|
||||||
|
- "cross_encoder_pairs": Cross-encoder pairs like [["query", "document"]]
|
||||||
|
"""
|
||||||
|
if isinstance(texts, str):
|
||||||
|
return "single_string"
|
||||||
|
|
||||||
|
if (
|
||||||
|
is_cross_encoder
|
||||||
|
and len(texts) > 0
|
||||||
|
and isinstance(texts[0], list)
|
||||||
|
and len(texts[0]) == 2
|
||||||
|
):
|
||||||
|
return "cross_encoder_pairs"
|
||||||
|
|
||||||
|
return "batch_strings"
|
||||||
|
|
||||||
|
def _prepare_tokenizer_input(
|
||||||
|
self, texts: Union[str, List[str]], input_format: str
|
||||||
|
) -> Union[List[str], List[List[str]]]:
|
||||||
|
"""Prepare input for the tokenizer based on detected format."""
|
||||||
|
if input_format == "single_string":
|
||||||
|
return [texts] # Wrap single string for batch processing
|
||||||
|
elif input_format == "cross_encoder_pairs":
|
||||||
|
return texts # Already in correct format: [["query", "doc"]]
|
||||||
|
else: # batch_strings
|
||||||
|
return texts # Already in correct format: ["text1", "text2"]
|
||||||
|
|
||||||
|
def _extract_tokenizer_results(
|
||||||
|
self,
|
||||||
|
input_ids: List[List[int]],
|
||||||
|
token_type_ids: Optional[List[List[int]]],
|
||||||
|
input_format: str,
|
||||||
|
original_batch_size: int,
|
||||||
|
) -> Union[
|
||||||
|
Tuple[List[int], Optional[List[int]]],
|
||||||
|
Tuple[List[List[int]], Optional[List[List[int]]]],
|
||||||
|
]:
|
||||||
|
"""Extract results from tokenizer output based on input format."""
|
||||||
|
|
||||||
|
# For single inputs (string or single cross-encoder pair), extract first element
|
||||||
|
if (
|
||||||
|
input_format in ["single_string", "cross_encoder_pairs"]
|
||||||
|
and original_batch_size == 1
|
||||||
|
):
|
||||||
|
single_input_ids = input_ids[0] if input_ids else []
|
||||||
|
single_token_type_ids = token_type_ids[0] if token_type_ids else None
|
||||||
|
return single_input_ids, single_token_type_ids
|
||||||
|
|
||||||
|
# For true batches, return as-is
|
||||||
|
return input_ids, token_type_ids
|
||||||
|
|
||||||
|
async def _tokenize_texts(
    self, texts: Union[str, List[str]], is_cross_encoder: bool = False
) -> Union[
    Tuple[List[int], Optional[List[int]]],
    Tuple[List[List[int]], Optional[List[List[int]]]],
]:
    """Tokenize text(s), choosing the appropriate tokenizer strategy.

    Accepts a single string ("How are you?"), a batch of strings
    (["Hello", "World"]), or cross-encoder sentence pairs
    ([["query", "doc"], ...]). Single strings go through the async dynamic
    batch tokenizer when it is enabled; everything else uses the regular
    tokenizer directly.

    Args:
        texts: Input text in any of the formats above.
        is_cross_encoder: When True, token_type_ids are requested and
            returned alongside input_ids (segment IDs for sentence pairs).

    Returns:
        (input_ids, token_type_ids) — flat lists for single inputs, e.g.
        ([101, 2129, 102], [0, 0, 0]); nested lists for batches, e.g.
        ([[101, ...], [101, ...]], None). token_type_ids is None unless
        is_cross_encoder=True.

    Raises:
        ValueError: if *texts* is empty or no tokenizer is available.
    """
    if not texts or self.tokenizer is None:
        raise ValueError("texts cannot be empty and tokenizer must be initialized")

    # Step 1: detect the input shape and normalize it for the tokenizer.
    input_format = self._detect_input_format(texts, is_cross_encoder)
    tokenizer_input = self._prepare_tokenizer_input(texts, input_format)
    original_batch_size = 1 if isinstance(texts, str) else len(texts)

    # Step 2: tokenizer kwargs (segment IDs only for cross-encoders).
    tokenizer_kwargs = (
        {"return_token_type_ids": is_cross_encoder} if is_cross_encoder else {}
    )

    # Step 3: pick a strategy — the async dynamic batch tokenizer only
    # accepts single strings.
    use_async_tokenizer = (
        input_format == "single_string"
        and self.async_dynamic_batch_tokenizer is not None
    )

    if use_async_tokenizer:
        logger.debug("Using async dynamic batch tokenizer for single text")
        result = await self.async_dynamic_batch_tokenizer.encode(
            tokenizer_input[0], **tokenizer_kwargs
        )
        # Re-wrap as a batch of one so extraction below is uniform.
        input_ids = [result["input_ids"]]
        if is_cross_encoder and result.get("token_type_ids"):
            token_type_ids = [result["token_type_ids"]]
        else:
            token_type_ids = None
    else:
        logger.debug(f"Using regular tokenizer for {len(tokenizer_input)} inputs")
        encoded = self.tokenizer(tokenizer_input, **tokenizer_kwargs)
        input_ids = encoded["input_ids"]
        token_type_ids = encoded.get("token_type_ids") if is_cross_encoder else None

    # Step 4: unwrap single inputs / pass batches through.
    return self._extract_tokenizer_results(
        input_ids, token_type_ids, input_format, original_batch_size
    )
|
||||||
|
|
||||||
async def _tokenize_one_request(
|
async def _tokenize_one_request(
|
||||||
self,
|
self,
|
||||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||||
@@ -400,14 +551,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
"accept text prompts. Please provide input_ids or re-initialize "
|
"accept text prompts. Please provide input_ids or re-initialize "
|
||||||
"the engine with skip_tokenizer_init=False."
|
"the engine with skip_tokenizer_init=False."
|
||||||
)
|
)
|
||||||
encoded = self.tokenizer(
|
|
||||||
input_text, return_token_type_ids=is_cross_encoder_request
|
|
||||||
)
|
|
||||||
|
|
||||||
input_ids = encoded["input_ids"]
|
input_ids, token_type_ids = await self._tokenize_texts(
|
||||||
if is_cross_encoder_request:
|
input_text, is_cross_encoder_request
|
||||||
input_ids = encoded["input_ids"][0]
|
)
|
||||||
token_type_ids = encoded.get("token_type_ids", [None])[0]
|
|
||||||
|
|
||||||
if self.mm_processor and obj.contains_mm_input():
|
if self.mm_processor and obj.contains_mm_input():
|
||||||
if not isinstance(obj.image_data, list):
|
if not isinstance(obj.image_data, list):
|
||||||
@@ -582,17 +729,27 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
requests = [obj[i] for i in range(batch_size)]
|
requests = [obj[i] for i in range(batch_size)]
|
||||||
texts = [req.text for req in requests]
|
texts = [req.text for req in requests]
|
||||||
|
|
||||||
# Batch tokenize all texts
|
# Check if any request is a cross-encoder request
|
||||||
encoded = self.tokenizer(texts)
|
is_cross_encoder_request = any(
|
||||||
input_ids_list = encoded["input_ids"]
|
isinstance(req, EmbeddingReqInput) and req.is_cross_encoder_request
|
||||||
|
for req in requests
|
||||||
|
)
|
||||||
|
|
||||||
|
# Batch tokenize all texts using unified method
|
||||||
|
input_ids_list, token_type_ids_list = await self._tokenize_texts(
|
||||||
|
texts, is_cross_encoder_request
|
||||||
|
)
|
||||||
|
|
||||||
# Process all requests
|
# Process all requests
|
||||||
tokenized_objs = []
|
tokenized_objs = []
|
||||||
for i, req in enumerate(requests):
|
for i, req in enumerate(requests):
|
||||||
self._validate_one_request(obj[i], input_ids_list[i])
|
self._validate_one_request(obj[i], input_ids_list[i])
|
||||||
|
token_type_ids = (
|
||||||
|
token_type_ids_list[i] if token_type_ids_list is not None else None
|
||||||
|
)
|
||||||
tokenized_objs.append(
|
tokenized_objs.append(
|
||||||
self._create_tokenized_object(
|
self._create_tokenized_object(
|
||||||
req, req.text, input_ids_list[i], None, None
|
req, req.text, input_ids_list[i], None, None, token_type_ids
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
logger.debug(f"Completed batch processing for {batch_size} requests")
|
logger.debug(f"Completed batch processing for {batch_size} requests")
|
||||||
|
|||||||
@@ -373,6 +373,11 @@ class ServerArgs:
|
|||||||
scheduler_recv_interval: int = 1
|
scheduler_recv_interval: int = 1
|
||||||
numa_node: Optional[List[int]] = None
|
numa_node: Optional[List[int]] = None
|
||||||
|
|
||||||
|
# Dynamic batch tokenizer
|
||||||
|
enable_dynamic_batch_tokenizer: bool = False
|
||||||
|
dynamic_batch_tokenizer_batch_size: int = 32
|
||||||
|
dynamic_batch_tokenizer_batch_timeout: float = 0.002
|
||||||
|
|
||||||
# Debug tensor dumps
|
# Debug tensor dumps
|
||||||
debug_tensor_dump_output_folder: Optional[str] = None
|
debug_tensor_dump_output_folder: Optional[str] = None
|
||||||
debug_tensor_dump_input_file: Optional[str] = None
|
debug_tensor_dump_input_file: Optional[str] = None
|
||||||
@@ -874,6 +879,13 @@ class ServerArgs:
|
|||||||
self.disable_cuda_graph = True
|
self.disable_cuda_graph = True
|
||||||
logger.warning("Cuda graph is disabled for prefill server")
|
logger.warning("Cuda graph is disabled for prefill server")
|
||||||
|
|
||||||
|
# Validation: prevent both tokenizer batching features from being enabled
|
||||||
|
if self.enable_tokenizer_batch_encode and self.enable_dynamic_batch_tokenizer:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot enable both --enable-tokenizer-batch-encode and --enable-dynamic-batch-tokenizer. "
|
||||||
|
"Please choose one tokenizer batching approach."
|
||||||
|
)
|
||||||
|
|
||||||
# Propagate env vars
|
# Propagate env vars
|
||||||
os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
|
os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
|
||||||
"1" if self.enable_torch_compile else "0"
|
"1" if self.enable_torch_compile else "0"
|
||||||
@@ -2162,6 +2174,23 @@ class ServerArgs:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Only dump the tensors for prefill requests (i.e. batch size > 1).",
|
help="Only dump the tensors for prefill requests (i.e. batch size > 1).",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-dynamic-batch-tokenizer",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable async dynamic batch tokenizer for improved performance when multiple requests arrive concurrently.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dynamic-batch-tokenizer-batch-size",
|
||||||
|
type=int,
|
||||||
|
default=ServerArgs.dynamic_batch_tokenizer_batch_size,
|
||||||
|
help="[Only used if --enable-dynamic-batch-tokenizer is set] Maximum batch size for dynamic batch tokenizer.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dynamic-batch-tokenizer-batch-timeout",
|
||||||
|
type=float,
|
||||||
|
default=ServerArgs.dynamic_batch_tokenizer_batch_timeout,
|
||||||
|
help="[Only used if --enable-dynamic-batch-tokenizer is set] Timeout in seconds for batching tokenization requests.",
|
||||||
|
)
|
||||||
|
|
||||||
# PD disaggregation
|
# PD disaggregation
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|||||||
295
test/srt/test_async_dynamic_batch_tokenizer.py
Normal file
295
test/srt/test_async_dynamic_batch_tokenizer.py
Normal file
@@ -0,0 +1,295 @@
|
|||||||
|
"""
|
||||||
|
Unit tests for AsyncDynamicbatchTokenizer.
|
||||||
|
|
||||||
|
Tests the async dynamic batching functionality for tokenization,
|
||||||
|
including batch efficiency, timeout handling, and error cases.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncDynamicbatchTokenizer:
|
||||||
|
"""Test suite for AsyncDynamicbatchTokenizer."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_tokenizer(self):
|
||||||
|
"""Create a mock tokenizer that behaves like HuggingFace tokenizer."""
|
||||||
|
|
||||||
|
def mock_encode(texts, **kwargs):
|
||||||
|
is_single = isinstance(texts, str)
|
||||||
|
if is_single:
|
||||||
|
texts = [texts]
|
||||||
|
|
||||||
|
# Simulate tokenization - convert text to mock token ids
|
||||||
|
input_ids = []
|
||||||
|
token_type_ids = []
|
||||||
|
|
||||||
|
for text in texts:
|
||||||
|
# Simple mock: text length determines number of tokens
|
||||||
|
tokens = [i for i in range(len(text.split()))]
|
||||||
|
input_ids.append(tokens)
|
||||||
|
|
||||||
|
if kwargs.get("return_token_type_ids", False):
|
||||||
|
token_type_ids.append([0] * len(tokens))
|
||||||
|
|
||||||
|
result = {"input_ids": input_ids}
|
||||||
|
if kwargs.get("return_token_type_ids", False):
|
||||||
|
result["token_type_ids"] = token_type_ids
|
||||||
|
|
||||||
|
# For single inputs, return individual result (not wrapped in a list)
|
||||||
|
if is_single:
|
||||||
|
result = {"input_ids": input_ids[0]}
|
||||||
|
if kwargs.get("return_token_type_ids", False):
|
||||||
|
result["token_type_ids"] = token_type_ids[0]
|
||||||
|
|
||||||
|
# Create a proper BatchEncoding-like object that supports dict operations
|
||||||
|
class MockBatchEncoding(dict):
|
||||||
|
def __init__(self, data):
|
||||||
|
super().__init__(data)
|
||||||
|
for key, value in data.items():
|
||||||
|
setattr(self, key, value)
|
||||||
|
|
||||||
|
return MockBatchEncoding(result)
|
||||||
|
|
||||||
|
# Return the function directly - the AsyncDynamicbatchTokenizer will call it
|
||||||
|
return mock_encode
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def async_tokenizer(self, mock_tokenizer):
|
||||||
|
"""Create AsyncDynamicbatchTokenizer instance."""
|
||||||
|
return AsyncDynamicbatchTokenizer(
|
||||||
|
tokenizer=mock_tokenizer, max_batch_size=4, batch_wait_timeout_s=0.01
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_single_request(self, async_tokenizer):
|
||||||
|
"""Test tokenizing a single request."""
|
||||||
|
text = "hello world"
|
||||||
|
result = await async_tokenizer.encode(text)
|
||||||
|
|
||||||
|
assert "input_ids" in result
|
||||||
|
assert result["input_ids"] == [0, 1] # 2 words -> 2 tokens
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_single_request_with_token_type_ids(self, async_tokenizer):
|
||||||
|
"""Test tokenizing with token type IDs."""
|
||||||
|
text = "hello world"
|
||||||
|
result = await async_tokenizer.encode(text, return_token_type_ids=True)
|
||||||
|
|
||||||
|
assert "input_ids" in result
|
||||||
|
assert "token_type_ids" in result
|
||||||
|
assert result["input_ids"] == [0, 1]
|
||||||
|
assert result["token_type_ids"] == [0, 0]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_concurrent_requests_same_kwargs(self, async_tokenizer):
|
||||||
|
"""Test that concurrent requests with same kwargs get batched."""
|
||||||
|
texts = ["hello world", "how are you", "fine thanks", "good morning"]
|
||||||
|
|
||||||
|
# Start all requests concurrently
|
||||||
|
tasks = [async_tokenizer.encode(text) for text in texts]
|
||||||
|
results = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
# Verify all results
|
||||||
|
assert len(results) == 4
|
||||||
|
for i, result in enumerate(results):
|
||||||
|
assert "input_ids" in result
|
||||||
|
expected_tokens = list(range(len(texts[i].split())))
|
||||||
|
assert result["input_ids"] == expected_tokens
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_concurrent_requests_different_kwargs(self, async_tokenizer):
|
||||||
|
"""Test that requests with different kwargs are processed individually."""
|
||||||
|
text1 = "hello world"
|
||||||
|
text2 = "how are you"
|
||||||
|
|
||||||
|
# One with token_type_ids, one without
|
||||||
|
task1 = async_tokenizer.encode(text1, return_token_type_ids=True)
|
||||||
|
task2 = async_tokenizer.encode(text2)
|
||||||
|
|
||||||
|
result1, result2 = await asyncio.gather(task1, task2)
|
||||||
|
|
||||||
|
# First result should have token_type_ids
|
||||||
|
assert "input_ids" in result1
|
||||||
|
assert "token_type_ids" in result1
|
||||||
|
assert result1["input_ids"] == [0, 1]
|
||||||
|
assert result1["token_type_ids"] == [0, 0]
|
||||||
|
|
||||||
|
# Second result should not have token_type_ids
|
||||||
|
assert "input_ids" in result2
|
||||||
|
assert "token_type_ids" not in result2
|
||||||
|
assert result2["input_ids"] == [0, 1, 2]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_batch_timeout(self, async_tokenizer):
|
||||||
|
"""Test that batching respects timeout."""
|
||||||
|
# Send first request
|
||||||
|
task1 = asyncio.create_task(async_tokenizer.encode("hello world"))
|
||||||
|
|
||||||
|
# Wait longer than batch timeout
|
||||||
|
await asyncio.sleep(0.02) # Longer than 0.01s timeout
|
||||||
|
|
||||||
|
# Send second request
|
||||||
|
task2 = asyncio.create_task(async_tokenizer.encode("how are you"))
|
||||||
|
|
||||||
|
results = await asyncio.gather(task1, task2)
|
||||||
|
|
||||||
|
# Both should complete successfully
|
||||||
|
assert len(results) == 2
|
||||||
|
assert results[0]["input_ids"] == [0, 1]
|
||||||
|
assert results[1]["input_ids"] == [0, 1, 2]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_max_batch_size_limit(self, async_tokenizer):
|
||||||
|
"""Test that batching respects max_batch_size."""
|
||||||
|
# Send more requests than max_batch_size (4)
|
||||||
|
texts = [f"text {i}" for i in range(6)]
|
||||||
|
tasks = [async_tokenizer.encode(text) for text in texts]
|
||||||
|
|
||||||
|
results = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
# All should complete successfully
|
||||||
|
assert len(results) == 6
|
||||||
|
for i, result in enumerate(results):
|
||||||
|
assert "input_ids" in result
|
||||||
|
assert result["input_ids"] == [0, 1] # "text i" -> 2 tokens
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_callable_interface(self, async_tokenizer):
|
||||||
|
"""Test that the tokenizer is callable."""
|
||||||
|
text = "hello world"
|
||||||
|
result = await async_tokenizer(text)
|
||||||
|
|
||||||
|
assert "input_ids" in result
|
||||||
|
assert result["input_ids"] == [0, 1]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_lazy_initialization(self, mock_tokenizer):
|
||||||
|
"""Test that initialization happens lazily."""
|
||||||
|
tokenizer = AsyncDynamicbatchTokenizer(mock_tokenizer)
|
||||||
|
|
||||||
|
# Should not be initialized yet
|
||||||
|
assert not tokenizer._initialized
|
||||||
|
|
||||||
|
# First encode should initialize
|
||||||
|
await tokenizer.encode("hello")
|
||||||
|
|
||||||
|
# Should now be initialized
|
||||||
|
assert tokenizer._initialized
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_error_handling_in_tokenizer(self, mock_tokenizer):
|
||||||
|
"""Test error handling when tokenizer fails."""
|
||||||
|
|
||||||
|
# Create a new async tokenizer with a failing tokenizer
|
||||||
|
def failing_tokenizer(*args, **kwargs):
|
||||||
|
raise ValueError("Tokenizer error")
|
||||||
|
|
||||||
|
async_tokenizer = AsyncDynamicbatchTokenizer(
|
||||||
|
tokenizer=failing_tokenizer, max_batch_size=4, batch_wait_timeout_s=0.01
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Tokenizer error"):
|
||||||
|
await async_tokenizer.encode("hello world")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_batch_processing_logs(self, async_tokenizer, caplog):
|
||||||
|
"""Test that batch processing logs are generated."""
|
||||||
|
caplog.set_level(logging.DEBUG)
|
||||||
|
|
||||||
|
# Send multiple requests to trigger batching
|
||||||
|
tasks = [
|
||||||
|
async_tokenizer.encode("hello world"),
|
||||||
|
async_tokenizer.encode("how are you"),
|
||||||
|
]
|
||||||
|
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
# Should have batch processing log
|
||||||
|
assert any(
|
||||||
|
"Processing dynamic batch of size" in record.message
|
||||||
|
for record in caplog.records
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_empty_queue_immediate_processing(self, async_tokenizer):
|
||||||
|
"""Test that single requests are processed immediately when queue is empty."""
|
||||||
|
start_time = time.time()
|
||||||
|
result = await async_tokenizer.encode("hello world")
|
||||||
|
end_time = time.time()
|
||||||
|
|
||||||
|
# Should complete quickly (much less than batch timeout)
|
||||||
|
assert end_time - start_time < 0.005 # 5ms should be plenty
|
||||||
|
assert result["input_ids"] == [0, 1]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_real_tokenizer_integration(self):
|
||||||
|
"""Test with a real HuggingFace tokenizer."""
|
||||||
|
try:
|
||||||
|
# Use a small, fast tokenizer for testing
|
||||||
|
real_tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||||
|
async_tokenizer = AsyncDynamicbatchTokenizer(
|
||||||
|
tokenizer=real_tokenizer, max_batch_size=2, batch_wait_timeout_s=0.01
|
||||||
|
)
|
||||||
|
|
||||||
|
text = "Hello, world!"
|
||||||
|
result = await async_tokenizer.encode(text)
|
||||||
|
|
||||||
|
# Should get actual token IDs
|
||||||
|
assert "input_ids" in result
|
||||||
|
assert isinstance(result["input_ids"], list)
|
||||||
|
assert len(result["input_ids"]) > 0
|
||||||
|
assert all(isinstance(token_id, int) for token_id in result["input_ids"])
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
pytest.skip(f"Real tokenizer test skipped: {e}")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_concurrent_mixed_requests(self, async_tokenizer):
|
||||||
|
"""Test mixing single and batched requests."""
|
||||||
|
# Start some requests
|
||||||
|
task1 = asyncio.create_task(async_tokenizer.encode("hello"))
|
||||||
|
task2 = asyncio.create_task(async_tokenizer.encode("world"))
|
||||||
|
|
||||||
|
# Wait a bit
|
||||||
|
await asyncio.sleep(0.005)
|
||||||
|
|
||||||
|
# Start more requests
|
||||||
|
task3 = asyncio.create_task(async_tokenizer.encode("how are"))
|
||||||
|
task4 = asyncio.create_task(async_tokenizer.encode("you doing"))
|
||||||
|
|
||||||
|
results = await asyncio.gather(task1, task2, task3, task4)
|
||||||
|
|
||||||
|
# All should complete successfully
|
||||||
|
assert len(results) == 4
|
||||||
|
for result in results:
|
||||||
|
assert "input_ids" in result
|
||||||
|
assert isinstance(result["input_ids"], list)
|
||||||
|
|
||||||
|
def test_cleanup_on_destruction(self, mock_tokenizer):
|
||||||
|
"""Test that resources are cleaned up properly."""
|
||||||
|
tokenizer = AsyncDynamicbatchTokenizer(mock_tokenizer)
|
||||||
|
|
||||||
|
# Mock the executor and task
|
||||||
|
tokenizer._executor = Mock()
|
||||||
|
tokenizer._batcher_task = Mock()
|
||||||
|
tokenizer._batcher_task.done.return_value = False
|
||||||
|
|
||||||
|
# Call destructor
|
||||||
|
tokenizer.__del__()
|
||||||
|
|
||||||
|
# Should cancel task and shutdown executor
|
||||||
|
tokenizer._batcher_task.cancel.assert_called_once()
|
||||||
|
tokenizer._executor.shutdown.assert_called_once_with(wait=False)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Allow running this test module directly: `python <this file>`.
    pytest.main([__file__])
|
||||||
379
test/srt/test_tokenizer_manager.py
Normal file
379
test/srt/test_tokenizer_manager.py
Normal file
@@ -0,0 +1,379 @@
|
|||||||
|
"""
|
||||||
|
Unit tests for TokenizerManager helper methods.
|
||||||
|
|
||||||
|
This tests the refactored tokenization functionality including input format detection,
|
||||||
|
tokenizer input preparation, and result extraction logic.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python3 -m unittest test_tokenizer_manager.TestInputFormatDetection
|
||||||
|
python3 -m unittest test_tokenizer_manager.TestTokenizerInputPreparation
|
||||||
|
python3 -m unittest test_tokenizer_manager.TestTokenizerResultExtraction
|
||||||
|
python3 -m unittest test_tokenizer_manager.TestTokenizerManagerIntegration
|
||||||
|
"""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
from typing import List, Optional, Union
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||||
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
|
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||||
|
|
||||||
|
|
||||||
|
class TestInputFormatDetection(unittest.TestCase):
    """Test cases for the _detect_input_format method of TokenizerManager."""

    def setUp(self):
        """Set up test fixtures.

        get_device is patched so ServerArgs construction does not probe for an
        accelerator; ZMQ and tokenizer loading are mocked because only the
        pure input-format helper is under test.
        """
        with patch("sglang.srt.utils.get_device", return_value="cpu"):
            self.server_args = ServerArgs(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
            self.port_args = PortArgs.init_new(self.server_args)

        with patch("zmq.asyncio.Context"), patch(
            "sglang.srt.utils.get_zmq_socket"
        ), patch("sglang.srt.hf_transformers_utils.get_tokenizer") as mock_tokenizer:
            mock_tokenizer.return_value = Mock(vocab_size=32000)
            self.tokenizer_manager = TokenizerManager(self.server_args, self.port_args)

    def test_detect_single_string(self):
        """Test detection of single string input."""
        text = "Hello world"
        result = self.tokenizer_manager._detect_input_format(
            text, is_cross_encoder=False
        )
        self.assertEqual(result, "single_string")

    def test_detect_single_string_cross_encoder_disabled(self):
        """Test a single string still returns single_string when is_cross_encoder=True."""
        # NOTE(review): despite the "disabled" suffix in the method name, this
        # case passes is_cross_encoder=True; the name is kept unchanged so
        # existing `python -m unittest <path>` invocations keep working.
        text = "Hello world"
        result = self.tokenizer_manager._detect_input_format(
            text, is_cross_encoder=True
        )
        self.assertEqual(result, "single_string")

    def test_detect_batch_strings(self):
        """Test detection of batch string inputs."""
        texts = ["Hello", "World", "How are you?"]
        result = self.tokenizer_manager._detect_input_format(
            texts, is_cross_encoder=False
        )
        self.assertEqual(result, "batch_strings")

    def test_detect_batch_strings_cross_encoder_disabled(self):
        """Test plain batch strings still return batch_strings when is_cross_encoder=True."""
        # NOTE(review): as above, the flag is actually True here despite the
        # "disabled" suffix in the method name.
        texts = ["Hello", "World"]
        result = self.tokenizer_manager._detect_input_format(
            texts, is_cross_encoder=True
        )
        self.assertEqual(result, "batch_strings")

    def test_detect_cross_encoder_single_pair(self):
        """Test detection of a single cross-encoder (query, document) pair."""
        texts = [["query text", "document text"]]
        result = self.tokenizer_manager._detect_input_format(
            texts, is_cross_encoder=True
        )
        self.assertEqual(result, "cross_encoder_pairs")

    def test_detect_cross_encoder_multiple_pairs(self):
        """Test detection of multiple cross-encoder pairs."""
        texts = [["q1", "d1"], ["q2", "d2"], ["q3", "d3"]]
        result = self.tokenizer_manager._detect_input_format(
            texts, is_cross_encoder=True
        )
        self.assertEqual(result, "cross_encoder_pairs")

    def test_detect_cross_encoder_disabled_with_pairs(self):
        """Test pairs with cross_encoder disabled should return batch_strings."""
        texts = [["query", "document"]]
        result = self.tokenizer_manager._detect_input_format(
            texts, is_cross_encoder=False
        )
        self.assertEqual(result, "batch_strings")

    def test_detect_empty_list(self):
        """Test detection with an empty list falls back to batch_strings."""
        texts = []
        result = self.tokenizer_manager._detect_input_format(
            texts, is_cross_encoder=True
        )
        self.assertEqual(result, "batch_strings")

    def test_detect_malformed_cross_encoder_pairs(self):
        """Test malformed cross-encoder pairs (not length 2) fall back to batch_strings."""
        texts = [["query only"]]  # Single element, not a pair
        result = self.tokenizer_manager._detect_input_format(
            texts, is_cross_encoder=True
        )
        self.assertEqual(result, "batch_strings")

        texts = [["query", "doc", "extra"]]  # Three elements, not a pair
        result = self.tokenizer_manager._detect_input_format(
            texts, is_cross_encoder=True
        )
        self.assertEqual(result, "batch_strings")
|
||||||
|
|
||||||
|
|
||||||
|
class TestTokenizerInputPreparation(unittest.TestCase):
    """Test cases for the _prepare_tokenizer_input method of TokenizerManager."""

    def setUp(self):
        """Build a TokenizerManager with its external dependencies mocked out."""
        with patch("sglang.srt.utils.get_device", return_value="cpu"):
            self.server_args = ServerArgs(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
            self.port_args = PortArgs.init_new(self.server_args)

        with patch("zmq.asyncio.Context"), patch(
            "sglang.srt.utils.get_zmq_socket"
        ), patch("sglang.srt.hf_transformers_utils.get_tokenizer") as tokenizer_patch:
            tokenizer_patch.return_value = Mock(vocab_size=32000)
            self.tokenizer_manager = TokenizerManager(self.server_args, self.port_args)

    def test_prepare_single_string_input(self):
        """A lone string is wrapped into a one-element list."""
        prepared = self.tokenizer_manager._prepare_tokenizer_input(
            "Hello world", "single_string"
        )
        self.assertEqual(prepared, ["Hello world"])

    def test_prepare_batch_strings_input(self):
        """A list of strings passes through unchanged."""
        batch = ["Hello", "World", "Test"]
        self.assertEqual(
            self.tokenizer_manager._prepare_tokenizer_input(batch, "batch_strings"),
            ["Hello", "World", "Test"],
        )

    def test_prepare_cross_encoder_pairs_input(self):
        """Query/document pairs pass through unchanged."""
        pairs = [["query1", "doc1"], ["query2", "doc2"]]
        self.assertEqual(
            self.tokenizer_manager._prepare_tokenizer_input(
                pairs, "cross_encoder_pairs"
            ),
            [["query1", "doc1"], ["query2", "doc2"]],
        )

    def test_prepare_cross_encoder_single_pair_input(self):
        """A single query/document pair passes through unchanged."""
        pair = [["query text", "document text"]]
        self.assertEqual(
            self.tokenizer_manager._prepare_tokenizer_input(
                pair, "cross_encoder_pairs"
            ),
            [["query text", "document text"]],
        )

    def test_prepare_unknown_input_format(self):
        """An unrecognized format string falls back to returning input as-is."""
        self.assertEqual(
            self.tokenizer_manager._prepare_tokenizer_input(
                ["test"], "unknown_format"
            ),
            ["test"],
        )
|
||||||
|
|
||||||
|
|
||||||
|
class TestTokenizerResultExtraction(unittest.TestCase):
    """Test cases for the _extract_tokenizer_results method of TokenizerManager."""

    def setUp(self):
        """Build a TokenizerManager with its external dependencies mocked out."""
        with patch("sglang.srt.utils.get_device", return_value="cpu"):
            self.server_args = ServerArgs(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
            self.port_args = PortArgs.init_new(self.server_args)

        with patch("zmq.asyncio.Context"), patch(
            "sglang.srt.utils.get_zmq_socket"
        ), patch("sglang.srt.hf_transformers_utils.get_tokenizer") as tokenizer_patch:
            tokenizer_patch.return_value = Mock(vocab_size=32000)
            self.tokenizer_manager = TokenizerManager(self.server_args, self.port_args)

    def _extract(self, input_ids, token_type_ids, input_format, batch_size):
        """Shorthand for invoking the extraction helper under test."""
        return self.tokenizer_manager._extract_tokenizer_results(
            input_ids, token_type_ids, input_format, original_batch_size=batch_size
        )

    def test_extract_single_string_results(self):
        """A single-string result is unwrapped from its batch dimension."""
        ids, type_ids = self._extract(
            [[101, 2129, 102]], [[0, 0, 0]], "single_string", 1
        )
        self.assertEqual(ids, [101, 2129, 102])
        self.assertEqual(type_ids, [0, 0, 0])

    def test_extract_single_cross_encoder_results(self):
        """A single cross-encoder pair is unwrapped from its batch dimension."""
        ids, type_ids = self._extract(
            [[101, 2129, 102, 4068, 102]], [[0, 0, 0, 1, 1]], "cross_encoder_pairs", 1
        )
        self.assertEqual(ids, [101, 2129, 102, 4068, 102])
        self.assertEqual(type_ids, [0, 0, 0, 1, 1])

    def test_extract_batch_results(self):
        """Batch inputs keep their outer list structure."""
        ids, type_ids = self._extract(
            [[101, 2129, 102], [101, 4068, 102]],
            [[0, 0, 0], [0, 0, 0]],
            "batch_strings",
            2,
        )
        self.assertEqual(ids, [[101, 2129, 102], [101, 4068, 102]])
        self.assertEqual(type_ids, [[0, 0, 0], [0, 0, 0]])

    def test_extract_multiple_cross_encoder_results(self):
        """Multiple cross-encoder pairs keep their outer list structure."""
        ids, type_ids = self._extract(
            [[101, 2129, 102, 4068, 102], [101, 7592, 102, 2088, 102]],
            [[0, 0, 0, 1, 1], [0, 0, 0, 1, 1]],
            "cross_encoder_pairs",
            2,
        )
        self.assertEqual(
            ids, [[101, 2129, 102, 4068, 102], [101, 7592, 102, 2088, 102]]
        )
        self.assertEqual(type_ids, [[0, 0, 0, 1, 1], [0, 0, 0, 1, 1]])

    def test_extract_empty_results(self):
        """Empty input ids come back empty; missing token types stay None."""
        ids, type_ids = self._extract([], None, "single_string", 1)
        self.assertEqual(ids, [])
        self.assertIsNone(type_ids)

    def test_extract_with_none_token_type_ids(self):
        """token_type_ids=None propagates through extraction as None."""
        ids, type_ids = self._extract([[101, 2129, 102]], None, "single_string", 1)
        self.assertEqual(ids, [101, 2129, 102])
        self.assertIsNone(type_ids)
|
||||||
|
|
||||||
|
|
||||||
|
class TestTokenizerManagerIntegration(unittest.TestCase):
    """Integration tests combining detect, prepare, and extract helpers."""

    def setUp(self):
        """Build a TokenizerManager with its external dependencies mocked out."""
        with patch("sglang.srt.utils.get_device", return_value="cpu"):
            self.server_args = ServerArgs(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
            self.port_args = PortArgs.init_new(self.server_args)

        with patch("zmq.asyncio.Context"), patch(
            "sglang.srt.utils.get_zmq_socket"
        ), patch("sglang.srt.hf_transformers_utils.get_tokenizer") as tokenizer_patch:
            tokenizer_patch.return_value = Mock(vocab_size=32000)
            self.tokenizer_manager = TokenizerManager(self.server_args, self.port_args)

    def test_full_workflow_single_string(self):
        """Test complete workflow for single string input."""
        text = "Hello world"
        manager = self.tokenizer_manager

        # Step 1: detect the input format.
        detected = manager._detect_input_format(text, is_cross_encoder=False)
        self.assertEqual(detected, "single_string")

        # Step 2: prepare the tokenizer input.
        prepared = manager._prepare_tokenizer_input(text, detected)
        self.assertEqual(prepared, ["Hello world"])

        # Step 3: extract results from simulated tokenizer output.
        ids, type_ids = manager._extract_tokenizer_results(
            [[101, 2129, 4248, 102]], None, detected, original_batch_size=1
        )
        self.assertEqual(ids, [101, 2129, 4248, 102])
        self.assertIsNone(type_ids)

    def test_full_workflow_cross_encoder_pairs(self):
        """Test complete workflow for cross-encoder pairs."""
        pairs = [
            ["How many people live in Berlin?", "Berlin is well known for its museums."]
        ]
        manager = self.tokenizer_manager

        # Step 1: detect the input format.
        detected = manager._detect_input_format(pairs, is_cross_encoder=True)
        self.assertEqual(detected, "cross_encoder_pairs")

        # Step 2: prepare the tokenizer input.
        prepared = manager._prepare_tokenizer_input(pairs, detected)
        self.assertEqual(prepared, pairs)

        # Step 3: extract results from simulated cross-encoder output.
        ids, type_ids = manager._extract_tokenizer_results(
            [[101, 2129, 2116, 102, 4068, 2003, 102]],
            [[0, 0, 0, 0, 1, 1, 1]],
            detected,
            original_batch_size=1,
        )
        self.assertEqual(ids, [101, 2129, 2116, 102, 4068, 2003, 102])
        self.assertEqual(type_ids, [0, 0, 0, 0, 1, 1, 1])

    def test_full_workflow_batch_strings(self):
        """Test complete workflow for batch strings."""
        batch = ["Hello", "World", "Test"]
        manager = self.tokenizer_manager

        # Step 1: detect the input format.
        detected = manager._detect_input_format(batch, is_cross_encoder=False)
        self.assertEqual(detected, "batch_strings")

        # Step 2: prepare the tokenizer input.
        prepared = manager._prepare_tokenizer_input(batch, detected)
        self.assertEqual(prepared, ["Hello", "World", "Test"])

        # Step 3: extract results from simulated tokenizer output.
        ids, type_ids = manager._extract_tokenizer_results(
            [[101, 7592, 102], [101, 2088, 102], [101, 2774, 102]],
            None,
            detected,
            original_batch_size=3,
        )
        self.assertEqual(
            ids, [[101, 7592, 102], [101, 2088, 102], [101, 2774, 102]]
        )
        self.assertIsNone(type_ids)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Allow running this test module directly with verbose output.
    unittest.main(verbosity=2)
|
||||||
Reference in New Issue
Block a user