Improve overlap scheduling (#5788)
This commit is contained in:
@@ -20,6 +20,7 @@ Life cycle of a request in the prefill server
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from collections import deque
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
@@ -256,7 +257,10 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
self.running_batch.batch_is_full = False
|
||||
|
||||
def process_batch_result_disagg_prefill(
|
||||
self: Scheduler, batch: ScheduleBatch, result: GenerationBatchResult
|
||||
self: Scheduler,
|
||||
batch: ScheduleBatch,
|
||||
result: GenerationBatchResult,
|
||||
launch_done: Optional[threading.Event] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
|
||||
@@ -280,7 +284,7 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
# Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
|
||||
if self.enable_overlap:
|
||||
# wait
|
||||
_, next_token_ids = self.tp_worker.resolve_batch_result(bid)
|
||||
_, next_token_ids = self.tp_worker.resolve_last_batch_result(launch_done)
|
||||
else:
|
||||
next_token_ids = result.next_token_ids.tolist()
|
||||
|
||||
|
||||
@@ -35,6 +35,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
|
||||
import copy
|
||||
import dataclasses
|
||||
import logging
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -724,6 +725,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
# This is an optimization to reduce the overhead of the prefill check.
|
||||
batch_is_full: bool = False
|
||||
|
||||
# Events
|
||||
launch_done: Optional[threading.Event] = None
|
||||
|
||||
# Sampling info
|
||||
sampling_info: SamplingBatchInfo = None
|
||||
next_batch_sampling_info: SamplingBatchInfo = None
|
||||
@@ -1565,6 +1569,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
)
|
||||
),
|
||||
extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
|
||||
launch_done=self.launch_done,
|
||||
)
|
||||
|
||||
def copy(self):
|
||||
@@ -1647,6 +1652,9 @@ class ModelWorkerBatch:
|
||||
# If set, the output of the batch contains the hidden states of the run.
|
||||
capture_hidden_mode: CaptureHiddenMode = None
|
||||
|
||||
# Overlap event
|
||||
launch_done: Optional[threading.Event] = None
|
||||
|
||||
|
||||
@triton.jit
|
||||
def write_req_to_token_pool_triton(
|
||||
|
||||
@@ -645,6 +645,7 @@ class Scheduler(
|
||||
self.cur_batch = batch
|
||||
|
||||
if batch:
|
||||
batch.launch_done = threading.Event()
|
||||
result = self.run_batch(batch)
|
||||
self.result_queue.append((batch.copy(), result))
|
||||
|
||||
@@ -656,7 +657,7 @@ class Scheduler(
|
||||
forward_mode=ForwardMode.DUMMY_FIRST,
|
||||
next_batch_sampling_info=self.tp_worker.cur_sampling_info,
|
||||
)
|
||||
self.process_batch_result(tmp_batch, None)
|
||||
self.process_batch_result(tmp_batch, None, batch.launch_done)
|
||||
|
||||
if self.last_batch:
|
||||
# Process the results of the last batch
|
||||
@@ -664,7 +665,10 @@ class Scheduler(
|
||||
tmp_batch.next_batch_sampling_info = (
|
||||
self.tp_worker.cur_sampling_info if batch else None
|
||||
)
|
||||
self.process_batch_result(tmp_batch, tmp_result)
|
||||
# NOTE: we should use current launched batch's launch_done event Instead of the last batch's
|
||||
self.process_batch_result(
|
||||
tmp_batch, tmp_result, batch.launch_done if batch else None
|
||||
)
|
||||
elif batch is None:
|
||||
# When the server is idle, do self-check and re-init some states
|
||||
self.check_memory()
|
||||
@@ -1417,14 +1421,15 @@ class Scheduler(
|
||||
self,
|
||||
batch: ScheduleBatch,
|
||||
result: Union[GenerationBatchResult, EmbeddingBatchResult],
|
||||
launch_done: Optional[threading.Event] = None,
|
||||
):
|
||||
if batch.forward_mode.is_decode():
|
||||
self.process_batch_result_decode(batch, result)
|
||||
self.process_batch_result_decode(batch, result, launch_done)
|
||||
elif batch.forward_mode.is_extend():
|
||||
self.process_batch_result_prefill(batch, result)
|
||||
self.process_batch_result_prefill(batch, result, launch_done)
|
||||
elif batch.forward_mode.is_idle():
|
||||
if self.enable_overlap:
|
||||
self.tp_worker.resolve_batch_result(result.bid)
|
||||
self.tp_worker.resolve_last_batch_result(launch_done)
|
||||
if batch.next_batch_sampling_info:
|
||||
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
||||
self.current_stream.synchronize()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
@@ -11,6 +12,7 @@ if TYPE_CHECKING:
|
||||
EmbeddingBatchResult,
|
||||
GenerationBatchResult,
|
||||
ScheduleBatch,
|
||||
Scheduler,
|
||||
)
|
||||
|
||||
|
||||
@@ -21,9 +23,10 @@ class SchedulerOutputProcessorMixin:
|
||||
"""
|
||||
|
||||
def process_batch_result_prefill(
|
||||
self,
|
||||
self: Scheduler,
|
||||
batch: ScheduleBatch,
|
||||
result: Union[GenerationBatchResult, EmbeddingBatchResult],
|
||||
launch_done: Optional[threading.Event] = None,
|
||||
):
|
||||
skip_stream_req = None
|
||||
|
||||
@@ -43,7 +46,11 @@ class SchedulerOutputProcessorMixin:
|
||||
)
|
||||
|
||||
if self.enable_overlap:
|
||||
logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
|
||||
logits_output, next_token_ids = (
|
||||
self.tp_worker.resolve_last_batch_result(
|
||||
launch_done,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Move next_token_ids and logprobs to cpu
|
||||
next_token_ids = next_token_ids.tolist()
|
||||
@@ -175,9 +182,10 @@ class SchedulerOutputProcessorMixin:
|
||||
self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
|
||||
|
||||
def process_batch_result_decode(
|
||||
self,
|
||||
self: Scheduler,
|
||||
batch: ScheduleBatch,
|
||||
result: GenerationBatchResult,
|
||||
launch_done: Optional[threading.Event] = None,
|
||||
):
|
||||
logits_output, next_token_ids, bid = (
|
||||
result.logits_output,
|
||||
@@ -187,7 +195,9 @@ class SchedulerOutputProcessorMixin:
|
||||
self.num_generated_tokens += len(batch.reqs)
|
||||
|
||||
if self.enable_overlap:
|
||||
logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
|
||||
logits_output, next_token_ids = self.tp_worker.resolve_last_batch_result(
|
||||
launch_done
|
||||
)
|
||||
next_token_logprobs = logits_output.next_token_logprobs
|
||||
elif batch.spec_algorithm.is_none():
|
||||
# spec decoding handles output logprobs inside verify process.
|
||||
@@ -271,7 +281,7 @@ class SchedulerOutputProcessorMixin:
|
||||
self.log_decode_stats()
|
||||
|
||||
def add_input_logprob_return_values(
|
||||
self,
|
||||
self: Scheduler,
|
||||
i: int,
|
||||
req: Req,
|
||||
output: LogitsProcessorOutput,
|
||||
@@ -405,7 +415,7 @@ class SchedulerOutputProcessorMixin:
|
||||
assert len(req.input_token_ids_logprobs_idx) == relevant_tokens_len
|
||||
|
||||
def add_logprob_return_values(
|
||||
self,
|
||||
self: Scheduler,
|
||||
i: int,
|
||||
req: Req,
|
||||
pt: int,
|
||||
@@ -436,7 +446,10 @@ class SchedulerOutputProcessorMixin:
|
||||
return num_input_logprobs
|
||||
|
||||
def stream_output(
|
||||
self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
|
||||
self: Scheduler,
|
||||
reqs: List[Req],
|
||||
return_logprob: bool,
|
||||
skip_req: Optional[Req] = None,
|
||||
):
|
||||
"""Stream the output to detokenizer."""
|
||||
if self.is_generation:
|
||||
@@ -445,7 +458,10 @@ class SchedulerOutputProcessorMixin:
|
||||
self.stream_output_embedding(reqs)
|
||||
|
||||
def stream_output_generation(
|
||||
self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
|
||||
self: Scheduler,
|
||||
reqs: List[Req],
|
||||
return_logprob: bool,
|
||||
skip_req: Optional[Req] = None,
|
||||
):
|
||||
rids = []
|
||||
finished_reasons: List[BaseFinishReason] = []
|
||||
@@ -593,7 +609,7 @@ class SchedulerOutputProcessorMixin:
|
||||
)
|
||||
)
|
||||
|
||||
def stream_output_embedding(self, reqs: List[Req]):
|
||||
def stream_output_embedding(self: Scheduler, reqs: List[Req]):
|
||||
rids = []
|
||||
finished_reasons: List[BaseFinishReason] = []
|
||||
|
||||
|
||||
@@ -170,13 +170,13 @@ class TpModelWorker:
|
||||
def forward_batch_generation(
|
||||
self,
|
||||
model_worker_batch: ModelWorkerBatch,
|
||||
launch_done: Optional[threading.Event] = None,
|
||||
skip_sample: bool = False,
|
||||
) -> Tuple[LogitsProcessorOutput, Optional[torch.Tensor]]:
|
||||
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
||||
logits_output = self.model_runner.forward(forward_batch)
|
||||
if launch_done:
|
||||
launch_done.set()
|
||||
|
||||
if model_worker_batch.launch_done is not None:
|
||||
model_worker_batch.launch_done.set()
|
||||
|
||||
if skip_sample:
|
||||
next_token_ids = None
|
||||
|
||||
@@ -132,7 +132,6 @@ class TpModelWorkerClient:
|
||||
batch_pt += 1
|
||||
|
||||
# Create event
|
||||
self.launch_done = threading.Event()
|
||||
copy_done = torch.get_device_module(self.device).Event()
|
||||
|
||||
# Resolve future tokens in the input
|
||||
@@ -141,7 +140,7 @@ class TpModelWorkerClient:
|
||||
|
||||
# Run forward
|
||||
logits_output, next_token_ids = self.worker.forward_batch_generation(
|
||||
model_worker_batch, self.launch_done
|
||||
model_worker_batch
|
||||
)
|
||||
|
||||
# Update the future token ids map
|
||||
@@ -168,10 +167,16 @@ class TpModelWorkerClient:
|
||||
|
||||
self.output_queue.put((copy_done, logits_output, next_token_ids))
|
||||
|
||||
def resolve_batch_result(self, bid: int):
|
||||
def resolve_last_batch_result(self, launch_done: Optional[threading.Event] = None):
|
||||
"""
|
||||
This function is called to resolve the last batch result and
|
||||
wait for the current batch to be launched. Used in overlap mode.
|
||||
"""
|
||||
copy_done, logits_output, next_token_ids = self.output_queue.get()
|
||||
|
||||
if launch_done is not None:
|
||||
launch_done.wait()
|
||||
copy_done.synchronize()
|
||||
self.launch_done.wait()
|
||||
|
||||
if logits_output.next_token_logprobs is not None:
|
||||
logits_output.next_token_logprobs = (
|
||||
|
||||
Reference in New Issue
Block a user