diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 3e325ca4d..3cd8cac5f 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -759,6 +759,14 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         """Handle batch tokenization for text inputs only."""
         logger.debug(f"Starting batch tokenization for {batch_size} text requests")
 
+        # If the batch has no text, there is nothing to tokenize,
+        # so build the return objects directly.
+        if not self._batch_has_text(batch_size, obj):
+            # Every request already carries input_ids (or input_embeds); no encoding needed.
+            return [await self._tokenize_one_request(obj[i]) for i in range(batch_size)]
+
+        self._validate_batch_tokenization_constraints(batch_size, obj)
+
         # Collect requests and texts
         requests = [obj[i] for i in range(batch_size)]
         texts = [req.text for req in requests]
@@ -808,6 +816,30 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                     "Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`."
                 )
 
+    def _batch_has_text(
+        self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput]
+    ) -> bool:
+        """Check if any request in the batch contains text or multimodal input."""
+        for i in range(batch_size):
+            if obj[i].text:
+                return True
+            elif self.is_generation and obj[i].contains_mm_input():
+                return True
+
+        return False
+
+    def _should_use_batch_tokenization(self, batch_size, requests) -> bool:
+        """Return True if we should run the tokenizer in batch mode.
+
+        Current policy:
+        - Respect the explicit server flag `enable_tokenizer_batch_encode`.
+        - Otherwise, if no request has text or multimodal input (all of them use pre-tokenized input_ids or input_embeds), batch the requests without tokenization.
+        """
+        return batch_size > 0 and (
+            self.server_args.enable_tokenizer_batch_encode
+            or not self._batch_has_text(batch_size, requests)
+        )
+
     def _send_one_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -942,13 +974,8 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         generators = []
         rids = []
         if getattr(obj, "parallel_sample_num", 1) == 1:
-            if self.server_args.enable_tokenizer_batch_encode:
-                # Validate batch tokenization constraints
-                self._validate_batch_tokenization_constraints(batch_size, obj)
-
+            if self._should_use_batch_tokenization(batch_size, obj):
                 tokenized_objs = await self._batch_tokenize_and_process(batch_size, obj)
-
-                # Send as a single batched request
                 self._send_batch_request(obj, tokenized_objs, created_time)
 
                 # Set up generators for each request in the batch
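
As a review aid, here is a minimal, self-contained sketch of the routing policy the new helpers encode, under the assumption that requests expose `text`, `input_ids`, and `contains_mm_input()` as shown in the diff. `StubReq`, `batch_has_text`, and `should_use_batch_tokenization` below are hypothetical stand-ins for the sglang types and methods, not the real implementation:

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class StubReq:
    # Hypothetical stand-in for a GenerateReqInput/EmbeddingReqInput item.
    text: Optional[str] = None
    input_ids: Optional[List[int]] = None
    mm_input: bool = False

    def contains_mm_input(self) -> bool:
        return self.mm_input

def batch_has_text(reqs: List[StubReq], is_generation: bool = True) -> bool:
    # Mirrors _batch_has_text: text, or (generation-only) multimodal input, counts.
    return any(r.text or (is_generation and r.contains_mm_input()) for r in reqs)

def should_use_batch_tokenization(reqs: List[StubReq], batch_encode_flag: bool) -> bool:
    # Mirrors _should_use_batch_tokenization: the explicit server flag, or a fully
    # pre-tokenized batch (no text/mm input anywhere), enables the batch path.
    return len(reqs) > 0 and (batch_encode_flag or not batch_has_text(reqs))

# A fully pre-tokenized batch takes the batched path even without the server flag.
assert should_use_batch_tokenization([StubReq(input_ids=[1, 2])], batch_encode_flag=False)
# Text present and flag off: falls back to per-request tokenization.
assert not should_use_batch_tokenization([StubReq(text="hi")], batch_encode_flag=False)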
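
A note on the design choice, as read from the diff: `_validate_batch_tokenization_constraints` now runs only when the batch actually contains text or multimodal input to encode, and a fully pre-tokenized batch takes the early return in `_batch_tokenize_and_process`, calling `_tokenize_one_request` per item. For requests that already carry `input_ids` or `input_embeds`, that call presumably just wraps the existing tokens rather than re-encoding, so such a batch can still be sent as a single `_send_batch_request` without `enable_tokenizer_batch_encode` being set.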