[Performance] Batch Send from Tokenizer Manager. (#9436)
This commit is contained in:
committed by
GitHub
parent
3aec3d4f8b
commit
ea0696b924
@@ -533,6 +533,21 @@ class TokenizedGenerateReqInput:
|
|||||||
dp_balance_id: int = -1
|
dp_balance_id: int = -1
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class BatchTokenizedGenerateReqInput:
    """Container that carries several tokenized generate requests as one object.

    The wrapper exposes ``len()``, indexing and iteration, each of which is
    delegated directly to the underlying ``batch`` list.
    """

    # Tokenized generate requests carried by this batch, in order.
    batch: List[TokenizedGenerateReqInput]

    def __iter__(self):
        """Iterate over the wrapped requests in list order."""
        return iter(self.batch)

    def __len__(self):
        """Return how many requests the batch holds."""
        return len(self.batch)

    def __getitem__(self, i):
        """Return ``self.batch[i]``; any index the list accepts is forwarded as is."""
        return self.batch[i]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EmbeddingReqInput:
|
class EmbeddingReqInput:
|
||||||
# The input prompt. It can be a single prompt or a batch of prompts.
|
# The input prompt. It can be a single prompt or a batch of prompts.
|
||||||
@@ -668,6 +683,21 @@ class TokenizedEmbeddingReqInput:
|
|||||||
dp_balance_id: int = -1
|
dp_balance_id: int = -1
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class BatchTokenizedEmbeddingReqInput:
    """Container that carries several tokenized embedding requests as one object.

    The wrapper exposes ``len()``, indexing and iteration, each of which is
    delegated directly to the underlying ``batch`` list.
    """

    # Tokenized embedding requests carried by this batch, in order.
    batch: List[TokenizedEmbeddingReqInput]

    def __iter__(self):
        """Iterate over the wrapped requests in list order."""
        return iter(self.batch)

    def __len__(self):
        """Return how many requests the batch holds."""
        return len(self.batch)

    def __getitem__(self, i):
        """Return ``self.batch[i]``; any index the list accepts is forwarded as is."""
        return self.batch[i]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BatchTokenIDOut:
|
class BatchTokenIDOut:
|
||||||
# The request id
|
# The request id
|
||||||
|
|||||||
@@ -67,6 +67,8 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
|||||||
from sglang.srt.layers.moe import initialize_moe_config
|
from sglang.srt.layers.moe import initialize_moe_config
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
AbortReq,
|
AbortReq,
|
||||||
|
BatchTokenizedEmbeddingReqInput,
|
||||||
|
BatchTokenizedGenerateReqInput,
|
||||||
CloseSessionReqInput,
|
CloseSessionReqInput,
|
||||||
ExpertDistributionReq,
|
ExpertDistributionReq,
|
||||||
ExpertDistributionReqOutput,
|
ExpertDistributionReqOutput,
|
||||||
@@ -510,6 +512,8 @@ class Scheduler(
|
|||||||
[
|
[
|
||||||
(TokenizedGenerateReqInput, self.handle_generate_request),
|
(TokenizedGenerateReqInput, self.handle_generate_request),
|
||||||
(TokenizedEmbeddingReqInput, self.handle_embedding_request),
|
(TokenizedEmbeddingReqInput, self.handle_embedding_request),
|
||||||
|
(BatchTokenizedGenerateReqInput, self.handle_batch_generate_request),
|
||||||
|
(BatchTokenizedEmbeddingReqInput, self.handle_batch_embedding_request),
|
||||||
(FlushCacheReqInput, self.flush_cache_wrapped),
|
(FlushCacheReqInput, self.flush_cache_wrapped),
|
||||||
(AbortReq, self.abort_request),
|
(AbortReq, self.abort_request),
|
||||||
(OpenSessionReqInput, self.open_session),
|
(OpenSessionReqInput, self.open_session),
|
||||||
@@ -1018,14 +1022,26 @@ class Scheduler(
|
|||||||
req
|
req
|
||||||
for req in recv_reqs
|
for req in recv_reqs
|
||||||
if isinstance(
|
if isinstance(
|
||||||
req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
|
req,
|
||||||
|
(
|
||||||
|
TokenizedGenerateReqInput,
|
||||||
|
TokenizedEmbeddingReqInput,
|
||||||
|
BatchTokenizedGenerateReqInput,
|
||||||
|
BatchTokenizedEmbeddingReqInput,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
control_reqs = [
|
control_reqs = [
|
||||||
req
|
req
|
||||||
for req in recv_reqs
|
for req in recv_reqs
|
||||||
if not isinstance(
|
if not isinstance(
|
||||||
req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
|
req,
|
||||||
|
(
|
||||||
|
TokenizedGenerateReqInput,
|
||||||
|
TokenizedEmbeddingReqInput,
|
||||||
|
BatchTokenizedGenerateReqInput,
|
||||||
|
BatchTokenizedEmbeddingReqInput,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
@@ -1253,6 +1269,17 @@ class Scheduler(
|
|||||||
else:
|
else:
|
||||||
self._add_request_to_queue(req)
|
self._add_request_to_queue(req)
|
||||||
|
|
||||||
|
def handle_batch_generate_request(
    self,
    recv_req: BatchTokenizedGenerateReqInput,
):
    """Unpack a batched generate input and handle its requests one at a time.

    Args:
        recv_req: Batch wrapper whose elements are each passed, in iteration
            order, to ``self.handle_generate_request``.
    """
    logger.debug(f"Processing batch generate request with {len(recv_req)} requests")

    # The batch wrapper is iterable; forward every contained request to the
    # single-request handler.
    for single_req in recv_req:
        self.handle_generate_request(single_req)
|
||||||
|
|
||||||
def _add_request_to_queue(self, req: Req):
|
def _add_request_to_queue(self, req: Req):
|
||||||
req.queue_time_start = time.perf_counter()
|
req.queue_time_start = time.perf_counter()
|
||||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||||
@@ -1335,6 +1362,19 @@ class Scheduler(
|
|||||||
req.logprob_start_len = len(req.origin_input_ids) - 1
|
req.logprob_start_len = len(req.origin_input_ids) - 1
|
||||||
self._add_request_to_queue(req)
|
self._add_request_to_queue(req)
|
||||||
|
|
||||||
|
def handle_batch_embedding_request(
    self,
    recv_req: BatchTokenizedEmbeddingReqInput,
):
    """Unpack a batched embedding input and handle its requests one at a time.

    Args:
        recv_req: Batch wrapper whose elements are each passed, in iteration
            order, to ``self.handle_embedding_request``.
    """
    logger.debug(
        f"Processing batch embedding request with {len(recv_req)} requests"
    )

    # The batch wrapper is iterable; forward every contained request to the
    # single-request handler.
    for single_req in recv_req:
        self.handle_embedding_request(single_req)
|
||||||
|
|
||||||
def self_check_during_idle(self):
|
def self_check_during_idle(self):
|
||||||
self.check_memory()
|
self.check_memory()
|
||||||
self.check_tree_cache()
|
self.check_tree_cache()
|
||||||
@@ -2513,7 +2553,15 @@ def is_health_check_generate_req(recv_req):
|
|||||||
|
|
||||||
|
|
||||||
def is_work_request(recv_req):
    """Tell whether ``recv_req`` is a tokenized generate/embedding request.

    Both the single-request input types and their batched wrappers count as
    work requests; any other object yields ``False``.
    """
    work_request_types = (
        TokenizedGenerateReqInput,
        TokenizedEmbeddingReqInput,
        BatchTokenizedGenerateReqInput,
        BatchTokenizedEmbeddingReqInput,
    )
    return isinstance(recv_req, work_request_types)
|
||||||
|
|
||||||
|
|
||||||
def run_scheduler_process(
|
def run_scheduler_process(
|
||||||
|
|||||||
@@ -71,6 +71,8 @@ from sglang.srt.managers.io_struct import (
|
|||||||
BatchMultimodalOut,
|
BatchMultimodalOut,
|
||||||
BatchStrOut,
|
BatchStrOut,
|
||||||
BatchTokenIDOut,
|
BatchTokenIDOut,
|
||||||
|
BatchTokenizedEmbeddingReqInput,
|
||||||
|
BatchTokenizedGenerateReqInput,
|
||||||
CloseSessionReqInput,
|
CloseSessionReqInput,
|
||||||
ConfigureLoggingReq,
|
ConfigureLoggingReq,
|
||||||
EmbeddingReqInput,
|
EmbeddingReqInput,
|
||||||
@@ -768,6 +770,30 @@ class TokenizerManager:
|
|||||||
self.rid_to_state[obj.rid] = state
|
self.rid_to_state[obj.rid] = state
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
def _send_batch_request(
    self,
    obj: Union[GenerateReqInput, EmbeddingReqInput],
    tokenized_objs: List[
        Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]
    ],
    created_time: Optional[float] = None,
):
    """Send a batch of tokenized requests as a single batched request to the scheduler.

    The batch wrapper type is chosen from the type of the first tokenized
    request. After the send, one ``ReqState`` per request is stored in
    ``self.rid_to_state`` under that request's ``rid``.

    Args:
        obj: The original batched request; ``obj[i]`` is the i-th
            sub-request, matching ``tokenized_objs[i]``.
        tokenized_objs: Tokenized requests to ship together.
        created_time: Value passed as ``created_time`` to every ``ReqState``.

    Returns:
        None. An empty ``tokenized_objs`` is a no-op: nothing is sent and no
        state is registered.
    """
    # Guard the empty batch: tokenized_objs[0] below would raise IndexError,
    # and there is nothing to send anyway.
    if not tokenized_objs:
        return

    if isinstance(tokenized_objs[0], TokenizedGenerateReqInput):
        batch_req = BatchTokenizedGenerateReqInput(batch=tokenized_objs)
    else:
        batch_req = BatchTokenizedEmbeddingReqInput(batch=tokenized_objs)

    self.send_to_scheduler.send_pyobj(batch_req)

    # Create states for each individual request in the batch. Only the index
    # is needed here: the state is built from the original sub-request
    # obj[i], not from the tokenized object.
    for i in range(len(tokenized_objs)):
        tmp_obj = obj[i]
        self.rid_to_state[tmp_obj.rid] = ReqState(
            [], False, asyncio.Event(), tmp_obj, created_time=created_time
        )
|
||||||
|
|
||||||
async def _wait_one_response(
|
async def _wait_one_response(
|
||||||
self,
|
self,
|
||||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||||
@@ -870,10 +896,17 @@ class TokenizerManager:
|
|||||||
|
|
||||||
tokenized_objs = await self._batch_tokenize_and_process(batch_size, obj)
|
tokenized_objs = await self._batch_tokenize_and_process(batch_size, obj)
|
||||||
|
|
||||||
for i, tokenized_obj in enumerate(tokenized_objs):
|
# Send as a single batched request
|
||||||
|
self._send_batch_request(obj, tokenized_objs, created_time)
|
||||||
|
|
||||||
|
# Set up generators for each request in the batch
|
||||||
|
for i in range(batch_size):
|
||||||
tmp_obj = obj[i]
|
tmp_obj = obj[i]
|
||||||
state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
|
generators.append(
|
||||||
generators.append(self._wait_one_response(tmp_obj, state, request))
|
self._wait_one_response(
|
||||||
|
tmp_obj, self.rid_to_state[tmp_obj.rid], request
|
||||||
|
)
|
||||||
|
)
|
||||||
rids.append(tmp_obj.rid)
|
rids.append(tmp_obj.rid)
|
||||||
else:
|
else:
|
||||||
# Sequential tokenization and processing
|
# Sequential tokenization and processing
|
||||||
|
|||||||
Reference in New Issue
Block a user