[Performance] Batch Send from Tokenizer Manager. (#9436)

This commit is contained in:
Sundara Raman Ramachandran
2025-08-25 10:43:54 -07:00
committed by GitHub
parent 3aec3d4f8b
commit ea0696b924
3 changed files with 117 additions and 6 deletions

View File

@@ -71,6 +71,8 @@ from sglang.srt.managers.io_struct import (
BatchMultimodalOut,
BatchStrOut,
BatchTokenIDOut,
BatchTokenizedEmbeddingReqInput,
BatchTokenizedGenerateReqInput,
CloseSessionReqInput,
ConfigureLoggingReq,
EmbeddingReqInput,
@@ -768,6 +770,30 @@ class TokenizerManager:
self.rid_to_state[obj.rid] = state
return state
def _send_batch_request(
    self,
    obj: Union[GenerateReqInput, EmbeddingReqInput],
    tokenized_objs: List[
        Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]
    ],
    created_time: Optional[float] = None,
):
    """Send a batch of tokenized requests to the scheduler as one message.

    Registers a ``ReqState`` for every sub-request *before* the batched
    message is put on the wire, so a scheduler reply can never reference a
    rid that has no state entry yet.

    Args:
        obj: The original batched request; indexable (``obj[i]``) to recover
            the per-request view whose ``rid`` keys ``rid_to_state``.
        tokenized_objs: The tokenized requests. All entries are assumed to
            be of the same kind (generate vs. embedding) — only the first
            element is inspected to pick the batch wrapper type.
        created_time: Optional creation timestamp propagated to each
            ``ReqState`` — presumably for latency accounting; confirm
            against ``ReqState``'s definition.
    """
    # Create states for each individual request in the batch first, so the
    # reply path finds every rid in rid_to_state once the send completes.
    for i in range(len(tokenized_objs)):
        tmp_obj = obj[i]
        state = ReqState(
            [], False, asyncio.Event(), tmp_obj, created_time=created_time
        )
        self.rid_to_state[tmp_obj.rid] = state

    # Wrap all tokenized requests in a single batched message; the wrapper
    # type is chosen from the first element's type.
    if isinstance(tokenized_objs[0], TokenizedGenerateReqInput):
        batch_req = BatchTokenizedGenerateReqInput(batch=tokenized_objs)
    else:
        batch_req = BatchTokenizedEmbeddingReqInput(batch=tokenized_objs)
    self.send_to_scheduler.send_pyobj(batch_req)
async def _wait_one_response(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -870,10 +896,17 @@ class TokenizerManager:
tokenized_objs = await self._batch_tokenize_and_process(batch_size, obj)
for i, tokenized_obj in enumerate(tokenized_objs):
# Send as a single batched request
self._send_batch_request(obj, tokenized_objs, created_time)
# Set up generators for each request in the batch
for i in range(batch_size):
tmp_obj = obj[i]
state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
generators.append(self._wait_one_response(tmp_obj, state, request))
generators.append(
self._wait_one_response(
tmp_obj, self.rid_to_state[tmp_obj.rid], request
)
)
rids.append(tmp_obj.rid)
else:
# Sequential tokenization and processing