From ea0696b92410ca7b19d2b172ead4551a53d33a33 Mon Sep 17 00:00:00 2001
From: Sundara Raman Ramachandran
Date: Mon, 25 Aug 2025 10:43:54 -0700
Subject: [PATCH] [Performance] Batch Send from Tokenizer Manager. (#9436)

---
 python/sglang/srt/managers/io_struct.py       | 30 +++++++++++
 python/sglang/srt/managers/scheduler.py       | 54 +++++++++++++++++--
 .../sglang/srt/managers/tokenizer_manager.py  | 39 ++++++++++++--
 3 files changed, 117 insertions(+), 6 deletions(-)

diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 65428e030..256868e4a 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -533,6 +533,21 @@ class TokenizedGenerateReqInput:
     dp_balance_id: int = -1
 
 
+@dataclass
+class BatchTokenizedGenerateReqInput:
+    # The batch of tokenized requests
+    batch: List[TokenizedGenerateReqInput]
+
+    def __len__(self):
+        return len(self.batch)
+
+    def __getitem__(self, i):
+        return self.batch[i]
+
+    def __iter__(self):
+        return iter(self.batch)
+
+
 @dataclass
 class EmbeddingReqInput:
     # The input prompt. It can be a single prompt or a batch of prompts.
@@ -668,6 +683,21 @@ class TokenizedEmbeddingReqInput:
     dp_balance_id: int = -1
 
 
+@dataclass
+class BatchTokenizedEmbeddingReqInput:
+    # The batch of tokenized embedding requests
+    batch: List[TokenizedEmbeddingReqInput]
+
+    def __len__(self):
+        return len(self.batch)
+
+    def __getitem__(self, i):
+        return self.batch[i]
+
+    def __iter__(self):
+        return iter(self.batch)
+
+
 @dataclass
 class BatchTokenIDOut:
     # The request id
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 1a82010a2..34c2b164c 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -67,6 +67,8 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe import initialize_moe_config
 from sglang.srt.managers.io_struct import (
     AbortReq,
+    BatchTokenizedEmbeddingReqInput,
+    BatchTokenizedGenerateReqInput,
     CloseSessionReqInput,
     ExpertDistributionReq,
     ExpertDistributionReqOutput,
@@ -510,6 +512,8 @@ class Scheduler(
             [
                 (TokenizedGenerateReqInput, self.handle_generate_request),
                 (TokenizedEmbeddingReqInput, self.handle_embedding_request),
+                (BatchTokenizedGenerateReqInput, self.handle_batch_generate_request),
+                (BatchTokenizedEmbeddingReqInput, self.handle_batch_embedding_request),
                 (FlushCacheReqInput, self.flush_cache_wrapped),
                 (AbortReq, self.abort_request),
                 (OpenSessionReqInput, self.open_session),
@@ -1018,14 +1022,26 @@ class Scheduler(
                 req
                 for req in recv_reqs
                 if isinstance(
-                    req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+                    req,
+                    (
+                        TokenizedGenerateReqInput,
+                        TokenizedEmbeddingReqInput,
+                        BatchTokenizedGenerateReqInput,
+                        BatchTokenizedEmbeddingReqInput,
+                    ),
                 )
             ]
             control_reqs = [
                 req
                 for req in recv_reqs
                 if not isinstance(
-                    req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+                    req,
+                    (
+                        TokenizedGenerateReqInput,
+                        TokenizedEmbeddingReqInput,
+                        BatchTokenizedGenerateReqInput,
+                        BatchTokenizedEmbeddingReqInput,
+                    ),
                 )
             ]
         else:
@@ -1253,6 +1269,17 @@ class Scheduler(
         else:
             self._add_request_to_queue(req)
 
+    def handle_batch_generate_request(
+        self,
+        recv_req: BatchTokenizedGenerateReqInput,
+    ):
+        """Handle optimized batch generate request."""
+        logger.debug(f"Processing batch generate request with {len(recv_req)} requests")
+
+        # Process each request in the batch
+        for tokenized_req in recv_req:
+            self.handle_generate_request(tokenized_req)
+
     def _add_request_to_queue(self, req: Req):
         req.queue_time_start = time.perf_counter()
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
@@ -1335,6 +1362,19 @@ class Scheduler(
             req.logprob_start_len = len(req.origin_input_ids) - 1
         self._add_request_to_queue(req)
 
+    def handle_batch_embedding_request(
+        self,
+        recv_req: BatchTokenizedEmbeddingReqInput,
+    ):
+        """Handle optimized batch embedding request."""
+        logger.debug(
+            f"Processing batch embedding request with {len(recv_req)} requests"
+        )
+
+        # Process each request in the batch
+        for tokenized_req in recv_req:
+            self.handle_embedding_request(tokenized_req)
+
     def self_check_during_idle(self):
         self.check_memory()
         self.check_tree_cache()
@@ -2513,7 +2553,15 @@ def is_health_check_generate_req(recv_req):
 
 
 def is_work_request(recv_req):
-    return isinstance(recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput))
+    return isinstance(
+        recv_req,
+        (
+            TokenizedGenerateReqInput,
+            TokenizedEmbeddingReqInput,
+            BatchTokenizedGenerateReqInput,
+            BatchTokenizedEmbeddingReqInput,
+        ),
+    )
 
 
 def run_scheduler_process(
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 1161cdf1a..7c09379cd 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -71,6 +71,8 @@ from sglang.srt.managers.io_struct import (
     BatchMultimodalOut,
     BatchStrOut,
     BatchTokenIDOut,
+    BatchTokenizedEmbeddingReqInput,
+    BatchTokenizedGenerateReqInput,
     CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
@@ -768,6 +770,30 @@ class TokenizerManager:
         self.rid_to_state[obj.rid] = state
         return state
 
+    def _send_batch_request(
+        self,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
+        tokenized_objs: List[
+            Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]
+        ],
+        created_time: Optional[float] = None,
+    ):
+        """Send a batch of tokenized requests as a single batched request to the scheduler."""
+        if isinstance(tokenized_objs[0], TokenizedGenerateReqInput):
+            batch_req = BatchTokenizedGenerateReqInput(batch=tokenized_objs)
+        else:
+            batch_req = BatchTokenizedEmbeddingReqInput(batch=tokenized_objs)
+
+        self.send_to_scheduler.send_pyobj(batch_req)
+
+        # Create states for each individual request in the batch
+        for i, tokenized_obj in enumerate(tokenized_objs):
+            tmp_obj = obj[i]
+            state = ReqState(
+                [], False, asyncio.Event(), tmp_obj, created_time=created_time
+            )
+            self.rid_to_state[tmp_obj.rid] = state
+
     async def _wait_one_response(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -870,10 +896,17 @@ class TokenizerManager:
 
             tokenized_objs = await self._batch_tokenize_and_process(batch_size, obj)
 
-            for i, tokenized_obj in enumerate(tokenized_objs):
+            # Send as a single batched request
+            self._send_batch_request(obj, tokenized_objs, created_time)
+
+            # Set up generators for each request in the batch
+            for i in range(batch_size):
                 tmp_obj = obj[i]
-                state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
-                generators.append(self._wait_one_response(tmp_obj, state, request))
+                generators.append(
+                    self._wait_one_response(
+                        tmp_obj, self.rid_to_state[tmp_obj.rid], request
+                    )
+                )
                 rids.append(tmp_obj.rid)
         else:
             # Sequential tokenization and processing
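
Editor's note: the mechanism above is small but easy to miss in the diff. The batch wrapper classes implement __len__/__getitem__/__iter__, so one send_pyobj call carries N requests, and the scheduler unrolls the wrapper with a plain for loop into its existing single-request handlers. Below is a minimal, self-contained sketch of that round trip; TokenizedReq, BatchTokenizedReq, and FakeSocket are hypothetical stand-ins (not SGLang APIs) for TokenizedGenerateReqInput, the Batch* wrapper, and the ZMQ socket.

# Editor's sketch (not part of the patch): the batch-send round trip in
# miniature. All names here are illustrative stand-ins.
from dataclasses import dataclass
from typing import Iterator, List


@dataclass
class TokenizedReq:
    rid: str
    input_ids: List[int]


@dataclass
class BatchTokenizedReq:
    # The sequence protocol (__len__/__getitem__/__iter__) mirrors the
    # wrappers added in io_struct.py, so callers treat the batch like a list.
    batch: List[TokenizedReq]

    def __len__(self) -> int:
        return len(self.batch)

    def __getitem__(self, i: int) -> TokenizedReq:
        return self.batch[i]

    def __iter__(self) -> Iterator[TokenizedReq]:
        return iter(self.batch)


class FakeSocket:
    """Stand-in for the ZMQ socket; records each send_pyobj call."""

    def __init__(self) -> None:
        self.sent: List[object] = []

    def send_pyobj(self, obj: object) -> None:
        self.sent.append(obj)


def send_batch(sock: FakeSocket, reqs: List[TokenizedReq]) -> None:
    # One IPC message for N requests instead of N messages: the saving
    # this patch is after.
    sock.send_pyobj(BatchTokenizedReq(batch=reqs))


def scheduler_recv(msg: object, handle_one) -> None:
    # The receiving side unwraps the batch and reuses the existing
    # single-request handler, as handle_batch_generate_request does above.
    if isinstance(msg, BatchTokenizedReq):
        for req in msg:
            handle_one(req)
    else:
        handle_one(msg)


if __name__ == "__main__":
    sock = FakeSocket()
    send_batch(sock, [TokenizedReq(f"r{i}", [1, 2, 3]) for i in range(4)])
    assert len(sock.sent) == 1  # a single message carries all four requests
    scheduler_recv(sock.sent[0], handle_one=lambda r: print("handled", r.rid))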
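
On the tokenizer side, _send_batch_request sends the batch first and then registers one ReqState per rid; _wait_one_response later awaits each state, and responses are routed back by rid, so they may complete in any order. A hedged asyncio sketch of that fan-out follows, with ReqState and rid_to_state reduced to an event-per-rid dict (all names are simplified stand-ins, not sglang's actual definitions).

# Editor's sketch (not part of the patch): tokenizer-side fan-out in
# miniature. ReqState is a simplified stand-in for sglang's ReqState;
# rid_to_state plays the role of TokenizerManager.rid_to_state.
import asyncio
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ReqState:
    out_list: List[dict] = field(default_factory=list)
    finished: bool = False
    event: asyncio.Event = field(default_factory=asyncio.Event)


async def wait_one_response(rid: str, rid_to_state: Dict[str, ReqState]) -> dict:
    # Mirrors _wait_one_response: block until the scheduler's output for
    # this rid arrives, then hand it back.
    state = rid_to_state[rid]
    await state.event.wait()
    return state.out_list[-1]


async def main() -> None:
    rid_to_state: Dict[str, ReqState] = {}
    rids = ["r0", "r1", "r2"]

    # 1) the batch goes out in a single send_pyobj (elided; see
    #    _send_batch_request in the diff above)
    # 2) one state per request is registered so responses route by rid
    for rid in rids:
        rid_to_state[rid] = ReqState()

    # 3) one waiter per request, awaited concurrently like the generators
    waiters = [wait_one_response(rid, rid_to_state) for rid in rids]

    # Simulate the scheduler finishing the requests out of order.
    for rid in reversed(rids):
        state = rid_to_state[rid]
        state.out_list.append({"rid": rid, "text": "..."})
        state.finished = True
        state.event.set()

    print(await asyncio.gather(*waiters))


asyncio.run(main())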