support tokenized batch request (#11091)
@@ -759,6 +759,14 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         """Handle batch tokenization for text inputs only."""
         logger.debug(f"Starting batch tokenization for {batch_size} text requests")
 
+        # If the batch has no text, there is nothing to tokenize,
+        # so let's construct the return object directly.
+        if not self._batch_has_text(batch_size, obj):
+            # All requests already have input_ids, no need to tokenize
+            return [await self._tokenize_one_request(obj[i]) for i in range(batch_size)]
+
+        self._validate_batch_tokenization_constraints(batch_size, obj)
+
         # Collect requests and texts
         requests = [obj[i] for i in range(batch_size)]
         texts = [req.text for req in requests]
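For orientation, batch-encoding the collected texts usually amounts to a single tokenizer call over the whole list. A minimal sketch, assuming a Hugging Face tokenizer and the placeholder model name "gpt2"; the exact call inside `_batch_tokenize_and_process` may differ:

    from transformers import AutoTokenizer

    # "gpt2" is a placeholder model name, purely for illustration.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    texts = ["hello world", "batch tokenization"]
    # One call over the list encodes all texts in a single pass.
    input_ids = tokenizer(texts)["input_ids"]
    print(input_ids)  # one list of token ids per input text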
@@ -808,6 +816,30 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                     "Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`."
                 )
 
+    def _batch_has_text(
+        self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput]
+    ) -> bool:
+        """Check if any request in the batch contains text input."""
+        for i in range(batch_size):
+            if obj[i].text:
+                return True
+            elif self.is_generation and obj[i].contains_mm_input():
+                return True
+
+        return False
+
+    def _should_use_batch_tokenization(self, batch_size, requests) -> bool:
+        """Return True if we should run the tokenizer in batch mode.
+
+        Current policy:
+        - Respect the explicit server flag `enable_tokenizer_batch_encode`.
+        - Or, if no request has text or multimodal input (all use pre-tokenized input_ids or input_embeds), batch the requests without tokenization.
+        """
+        return batch_size > 0 and (
+            self.server_args.enable_tokenizer_batch_encode
+            or not self._batch_has_text(batch_size, requests)
+        )
+
     def _send_one_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
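The new policy can be exercised in isolation. A minimal, self-contained sketch with a hypothetical `Req` stand-in for `GenerateReqInput` (the real `_batch_has_text` also treats multimodal generation inputs, via `contains_mm_input()`, as text-like):

    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class Req:
        # Hypothetical mini request: only the fields the policy inspects.
        text: Optional[str] = None
        input_ids: Optional[List[int]] = None

    def batch_has_text(batch: List[Req]) -> bool:
        # Text-only variant of _batch_has_text.
        return any(r.text for r in batch)

    def should_use_batch_tokenization(batch: List[Req], enable_flag: bool) -> bool:
        # Same shape as the new method: honor the explicit server flag,
        # or batch anyway when nothing in the batch needs the tokenizer.
        return len(batch) > 0 and (enable_flag or not batch_has_text(batch))

    # A fully pre-tokenized batch is batched even with the flag off:
    assert should_use_batch_tokenization([Req(input_ids=[1]), Req(input_ids=[2])], False)
    # A batch containing text is batched only when the flag is on:
    assert not should_use_batch_tokenization([Req(text="hi")], False)
    assert should_use_batch_tokenization([Req(text="hi")], True)

The net effect is that fully pre-tokenized batches now always take the batched path, which is what the commit title's "tokenized batch request" refers to.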
@@ -942,13 +974,8 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         generators = []
         rids = []
         if getattr(obj, "parallel_sample_num", 1) == 1:
-            if self.server_args.enable_tokenizer_batch_encode:
-                # Validate batch tokenization constraints
-                self._validate_batch_tokenization_constraints(batch_size, obj)
-
+            if self._should_use_batch_tokenization(batch_size, obj):
                 tokenized_objs = await self._batch_tokenize_and_process(batch_size, obj)
-
-                # Send as a single batched request
                 self._send_batch_request(obj, tokenized_objs, created_time)
 
                 # Set up generators for each request in the batch
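Taken together, the call site now makes one policy decision and the batch path validates its own constraints. A self-contained sketch of the resulting control flow, with toy stand-ins for the real methods:

    import asyncio
    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class Req:
        text: Optional[str] = None
        input_ids: Optional[List[int]] = None

    async def tokenize_one(req: Req) -> List[int]:
        # Pre-tokenized requests pass straight through, as _tokenize_one_request does.
        return req.input_ids if req.input_ids is not None else [ord(c) for c in req.text]

    async def batch_tokenize_and_process(batch: List[Req]) -> List[List[int]]:
        # Constraint validation now lives here, not at the dispatch site.
        return [await tokenize_one(r) for r in batch]

    async def handle_batch(batch: List[Req], enable_flag: bool) -> None:
        use_batch = len(batch) > 0 and (enable_flag or not any(r.text for r in batch))
        if use_batch:
            tokenized = await batch_tokenize_and_process(batch)
            print("sent as one batched request:", tokenized)
        else:
            for r in batch:
                print("sent individually:", await tokenize_one(r))

    # Pre-tokenized requests take the batched path even with the flag off.
    asyncio.run(handle_batch([Req(input_ids=[1, 2]), Req(input_ids=[3])], enable_flag=False))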