Adjust overlap event loop (#11507)
@@ -752,7 +752,6 @@ class SchedulerDisaggregationDecodeMixin:
         self.last_batch_in_queue = False  # last batch is modified in-place, so we need another variable to track if it's extend
 
         while True:
-            self.launch_last_batch_sample_if_needed()
 
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)
@@ -764,6 +763,7 @@ class SchedulerDisaggregationDecodeMixin:
 
             prepare_mlp_sync_flag = require_mlp_sync(self.server_args)
 
+            batch_result = None
             if batch:
                 # Generate fake extend output.
                 if batch.forward_mode.is_extend():
@@ -772,25 +772,25 @@ class SchedulerDisaggregationDecodeMixin:
                         batch.reqs, any(req.return_logprob for req in batch.reqs)
                     )
                     if prepare_mlp_sync_flag:
-                        batch_, result = self._prepare_idle_batch_and_run(
+                        batch_, batch_result = self._prepare_idle_batch_and_run(
                             None, delay_process=True
                         )
                         if batch_:
-                            self.result_queue.append((batch_.copy(), result))
+                            self.result_queue.append((batch_.copy(), batch_result))
                             last_batch_in_queue = True
                 else:
                     if prepare_mlp_sync_flag:
                         self.prepare_mlp_sync_batch(batch)
-                    result = self.run_batch(batch)
-                    self.result_queue.append((batch.copy(), result))
+                    batch_result = self.run_batch(batch)
+                    self.result_queue.append((batch.copy(), batch_result))
                     last_batch_in_queue = True
 
             elif prepare_mlp_sync_flag:
-                batch, result = self._prepare_idle_batch_and_run(
+                batch, batch_result = self._prepare_idle_batch_and_run(
                     None, delay_process=True
                 )
                 if batch:
-                    self.result_queue.append((batch.copy(), result))
+                    self.result_queue.append((batch.copy(), batch_result))
                     last_batch_in_queue = True
 
             # Process the results of the previous batch but skip if the last batch is extend
@@ -798,6 +798,8 @@ class SchedulerDisaggregationDecodeMixin:
                 tmp_batch, tmp_result = self.result_queue.popleft()
                 self.process_batch_result(tmp_batch, tmp_result)
 
+            self.launch_batch_sample_if_needed(batch_result)
+
             queue_size = (
                 len(self.waiting_queue)
                 + len(self.disagg_decode_transfer_queue.queue)
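Note: the decode hunks above, and the prefill and unified loops below, all converge on one loop shape. The result of `run_batch` is held in a local `batch_result`, queued for overlapped processing, and its deferred sample is launched at the tail of the same iteration via `launch_batch_sample_if_needed(batch_result)`, replacing the `launch_last_batch_sample_if_needed()` call that used to sit at the top of the next iteration. A condensed sketch of that shape, with the MLP-sync, disaggregation, and idle-check branches stripped out (illustrative only, not code from this commit):

```python
def event_loop_overlap(self):  # condensed sketch, not the full implementation
    while True:
        recv_reqs = self.recv_requests()
        self.process_input_requests(recv_reqs)

        batch = self.get_next_batch_to_run()
        self.cur_batch = batch

        batch_result = None
        if batch:
            # Launch the forward pass now; sampling may be packaged as a
            # closure inside batch_result (see delay_sample_func below).
            batch_result = self.run_batch(batch)
            self.result_queue.append((batch.copy(), batch_result))

        if self.last_batch:
            # Overlap: process the previous batch while the current
            # forward pass is in flight.
            tmp_batch, tmp_result = self.result_queue.popleft()
            self.process_batch_result(tmp_batch, tmp_result)

        # Tail call replaces launch_last_batch_sample_if_needed() at the
        # top of the next iteration.
        self.launch_batch_sample_if_needed(batch_result)
        self.last_batch = batch
```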
@@ -321,8 +321,6 @@ class SchedulerDisaggregationPrefillMixin:
         self.result_queue = deque()
 
         while True:
-            self.launch_last_batch_sample_if_needed()
-
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)
             self.waiting_queue.extend(
@@ -334,9 +332,11 @@ class SchedulerDisaggregationPrefillMixin:
             if require_mlp_sync(self.server_args):
                 batch = self.prepare_mlp_sync_batch(batch)
             self.cur_batch = batch
 
+            batch_result = None
             if batch:
-                result = self.run_batch(batch)
-                self.result_queue.append((batch.copy(), result))
+                batch_result = self.run_batch(batch)
+                self.result_queue.append((batch.copy(), batch_result))
+
             if self.last_batch:
                 tmp_batch, tmp_result = self.result_queue.popleft()
@@ -345,6 +345,8 @@ class SchedulerDisaggregationPrefillMixin:
             if len(self.disagg_prefill_inflight_queue) > 0:
                 self.process_disagg_prefill_inflight_queue()
 
+            self.launch_batch_sample_if_needed(batch_result)
+
             if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
                 self.self_check_during_idle()
@@ -1907,8 +1907,5 @@ class ModelWorkerBatch:
     capture_hidden_mode: CaptureHiddenMode = None
     hicache_consumer_index: int = -1
 
-    # Overlap scheduler related
-    delay_sample_launch: bool = False
-
     # Whether this batch is prefill-only (no token generation needed)
     is_prefill_only: bool = False
@@ -148,7 +148,7 @@ from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
 from sglang.srt.parser.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args
 from sglang.srt.speculative.eagle_info import EagleDraftInput
@@ -212,8 +212,7 @@ class GenerationBatchResult:
 
     # For overlap scheduling
     copy_done: Optional[torch.cuda.Event] = None
-    delay_sample_launch: bool = False
-    forward_batch: Optional[ForwardBatch] = None
+    delay_sample_func: Optional[callable] = None
     future_indices: Optional[FutureIndices] = None
 
     # FIXME(lsyin): maybe move to <BetterPlace> ?
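Note: on `GenerationBatchResult`, the `delay_sample_launch` flag and the stashed `forward_batch` collapse into a single `delay_sample_func` callable. Instead of the scheduler re-assembling the sample call from stored pieces, the worker hands it a ready-made closure. A minimal sketch of the equivalence, using hypothetical simplified types (not the actual sglang classes):

```python
from dataclasses import dataclass
from typing import Callable, Optional

@dataclass
class ResultSketch:
    next_token_ids: Optional[list] = None
    # One field replaces the old (delay_sample_launch, forward_batch) pair.
    delay_sample_func: Optional[Callable[[], "ResultSketch"]] = None

# Hypothetical helper, for illustration only.
def make_result(model_runner, logits_output, forward_batch) -> ResultSketch:
    result = ResultSketch()

    def sample_batch_func():
        # The closure captures logits_output and forward_batch, so neither
        # needs to live on the result object anymore.
        result.next_token_ids = model_runner.sample(logits_output, forward_batch)
        return result  # same object, so the caller can assert identity

    result.delay_sample_func = sample_batch_func
    return result
```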
@@ -1036,17 +1035,16 @@ class Scheduler(
         self.result_queue: Deque[Tuple[ScheduleBatch, GenerationBatchResult]] = deque()
 
         while True:
-            self.launch_last_batch_sample_if_needed()
-
             recv_reqs = self.recv_requests()
             self.process_input_requests(recv_reqs)
 
             batch = self.get_next_batch_to_run()
             self.cur_batch = batch
 
+            batch_result = None
             if batch:
-                result = self.run_batch(batch)
-                self.result_queue.append((batch.copy(), result))
+                batch_result = self.run_batch(batch)
+                self.result_queue.append((batch.copy(), batch_result))
 
             if self.last_batch:
                 # Process the results of the last batch
@@ -1056,6 +1054,7 @@ class Scheduler(
                 # When the server is idle, do self-check and re-init some states
                 self.self_check_during_idle()
 
+            self.launch_batch_sample_if_needed(batch_result)
             self.last_batch = batch
 
     @DynamicGradMode()
@@ -2207,8 +2206,6 @@ class Scheduler(
         with self.forward_stream_ctx:
             self.forward_stream.wait_stream(self.default_stream)
             self.future_map.resolve_future(model_worker_batch)
-            if batch.sampling_info.grammars is not None:
-                model_worker_batch.delay_sample_launch = True
             batch_result = self.model_worker.forward_batch_generation(
                 model_worker_batch
             )
@@ -2216,7 +2213,7 @@ class Scheduler(
             batch_result.copy_done = torch.get_device_module(
                 self.device
             ).Event()
-            if not model_worker_batch.delay_sample_launch:
+            if batch_result.delay_sample_func is None:
                 self.future_map.store_to_map(future_indices, batch_result)
                 batch_result.copy_to_cpu()
             else:
@@ -2280,29 +2277,20 @@ class Scheduler(
             ret = EmbeddingBatchResult(embeddings=embeddings)
             return ret
 
-    def launch_last_batch_sample_if_needed(
-        self,
+    def launch_batch_sample_if_needed(
+        self, batch_result: GenerationBatchResult
     ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
-        if len(self.result_queue) == 0:
-            return
-
-        tmp_batch, tmp_result = self.result_queue.popleft()
-
-        tmp_result: GenerationBatchResult
-        if not tmp_result.delay_sample_launch:
-            self.result_queue.appendleft((tmp_batch, tmp_result))
+        # TODO(lsyin): make the delayed sample a default behavior after
+        # unifying the forward_batch_generation interface (related to spec V2).
+        if batch_result is None or batch_result.delay_sample_func is None:
             return
 
         with self.forward_stream_ctx:
             self.forward_stream.wait_stream(self.default_stream)
-            tmp_result.next_token_ids = self.model_worker.model_runner.sample(
-                tmp_result.logits_output,
-                tmp_result.forward_batch,
-            )
-            future_indices = tmp_result.future_indices
-            self.future_map.store_to_map(future_indices, tmp_result)
-            tmp_result.copy_to_cpu()
-        self.result_queue.appendleft((tmp_batch, tmp_result))
+            _batch_result = batch_result.delay_sample_func()
+            assert _batch_result is batch_result
+            self.future_map.store_to_map(batch_result.future_indices, batch_result)
+            batch_result.copy_to_cpu()
 
     def process_batch_result(
         self,
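Note: the replacement keeps the stream discipline intact. The deferred sample still runs on the forward stream after waiting on the default stream, so futures resolve in submission order; what changes is that the method receives the result it just created instead of popping the result queue and scanning for a flagged entry. A toy illustration of the new caller-side contract, with hypothetical stand-ins and no real CUDA streams or FutureMap:

```python
def launch_batch_sample_if_needed_sketch(scheduler, batch_result) -> None:
    # Hypothetical stand-in; the real method runs under forward_stream_ctx.
    if batch_result is None or batch_result.delay_sample_func is None:
        return  # idle iteration, or sampling already ran eagerly in run_batch
    out = batch_result.delay_sample_func()
    # The closure mutates and returns the object already sitting in
    # result_queue, so no queue surgery (popleft/appendleft) is needed.
    assert out is batch_result
    scheduler.future_map.store_to_map(batch_result.future_indices, batch_result)
    batch_result.copy_to_cpu()
```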
@@ -168,6 +168,7 @@ class TpModelWorker:
         )[0]
         set_random_seed(self.random_seed)
 
+        self.enable_overlap = not server_args.disable_overlap_schedule
         self.hicache_layer_transfer_counter = None
 
     def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
@@ -266,9 +267,18 @@ class TpModelWorker:
             # Skip sampling and return logits for target forward
             return batch_result
 
-        if model_worker_batch.delay_sample_launch:
-            batch_result.delay_sample_launch = True
-            batch_result.forward_batch = forward_batch
+        if (
+            self.enable_overlap
+            and model_worker_batch.sampling_info.grammars is not None
+        ):
+
+            def sample_batch_func():
+                batch_result.next_token_ids = self.model_runner.sample(
+                    logits_output, forward_batch
+                )
+                return batch_result
+
+            batch_result.delay_sample_func = sample_batch_func
             return batch_result
 
         if model_worker_batch.is_prefill_only:
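Note: the decision to defer thus moves out of the scheduler (which previously set `delay_sample_launch` on the batch when grammars were present) and into the worker, which builds the sampling closure in place. The diff does not state the motivation, but the grammar condition suggests sampling must wait until grammar state is ready while the forward pass itself can still be overlapped. A self-contained toy showing the deferral pattern, with hypothetical simplified types (not the sglang classes):

```python
from dataclasses import dataclass
from typing import Callable, List, Optional

@dataclass
class ToyResult:
    logits: List[List[float]]
    next_token_ids: Optional[List[int]] = None
    delay_sample_func: Optional[Callable[[], "ToyResult"]] = None

def toy_forward(logits: List[List[float]], defer: bool) -> ToyResult:
    result = ToyResult(logits=logits)

    def argmax_sample() -> ToyResult:
        # Stand-in for model_runner.sample(logits_output, forward_batch).
        result.next_token_ids = [row.index(max(row)) for row in logits]
        return result  # the same object the scheduler already queued

    if defer:
        # Deferred path: hand back the closure; the scheduler invokes it later.
        result.delay_sample_func = argmax_sample
    else:
        # Eager path: sample immediately.
        argmax_sample()
    return result

# Usage: the "scheduler" side invokes the closure once it is ready to sample.
res = toy_forward([[0.1, 0.9], [0.7, 0.3]], defer=True)
assert res.next_token_ids is None
assert res.delay_sample_func() is res
assert res.next_token_ids == [1, 0]
```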