Adjust overlap event loop (#11507)

This commit is contained in:
Liangsheng Yin
2025-10-14 00:33:19 +08:00
committed by GitHub
parent 9cc1e065f1
commit bfadb5ea5f
5 changed files with 44 additions and 45 deletions

View File

@@ -148,7 +148,7 @@ from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
from sglang.srt.parser.reasoning_parser import ReasoningParser
from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args
from sglang.srt.speculative.eagle_info import EagleDraftInput
@@ -212,8 +212,7 @@ class GenerationBatchResult:
# For overlap scheduling
copy_done: Optional[torch.cuda.Event] = None
delay_sample_launch: bool = False
forward_batch: Optional[ForwardBatch] = None
delay_sample_func: Optional[callable] = None
future_indices: Optional[FutureIndices] = None
# FIXME(lsyin): maybe move to <BetterPlace> ?
@@ -1036,17 +1035,16 @@ class Scheduler(
self.result_queue: Deque[Tuple[ScheduleBatch, GenerationBatchResult]] = deque()
while True:
self.launch_last_batch_sample_if_needed()
recv_reqs = self.recv_requests()
self.process_input_requests(recv_reqs)
batch = self.get_next_batch_to_run()
self.cur_batch = batch
batch_result = None
if batch:
result = self.run_batch(batch)
self.result_queue.append((batch.copy(), result))
batch_result = self.run_batch(batch)
self.result_queue.append((batch.copy(), batch_result))
if self.last_batch:
# Process the results of the last batch
@@ -1056,6 +1054,7 @@ class Scheduler(
# When the server is idle, do self-check and re-init some states
self.self_check_during_idle()
self.launch_batch_sample_if_needed(batch_result)
self.last_batch = batch
@DynamicGradMode()
@@ -2207,8 +2206,6 @@ class Scheduler(
with self.forward_stream_ctx:
self.forward_stream.wait_stream(self.default_stream)
self.future_map.resolve_future(model_worker_batch)
if batch.sampling_info.grammars is not None:
model_worker_batch.delay_sample_launch = True
batch_result = self.model_worker.forward_batch_generation(
model_worker_batch
)
@@ -2216,7 +2213,7 @@ class Scheduler(
batch_result.copy_done = torch.get_device_module(
self.device
).Event()
if not model_worker_batch.delay_sample_launch:
if batch_result.delay_sample_func is None:
self.future_map.store_to_map(future_indices, batch_result)
batch_result.copy_to_cpu()
else:
@@ -2280,29 +2277,20 @@ class Scheduler(
ret = EmbeddingBatchResult(embeddings=embeddings)
return ret
def launch_last_batch_sample_if_needed(
self,
def launch_batch_sample_if_needed(
self, batch_result: GenerationBatchResult
) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
if len(self.result_queue) == 0:
return
tmp_batch, tmp_result = self.result_queue.popleft()
tmp_result: GenerationBatchResult
if not tmp_result.delay_sample_launch:
self.result_queue.appendleft((tmp_batch, tmp_result))
# TODO(lsyin): make the delayed sample a default behavior after
# unifying the forward_batch_generation interface (related to spec V2).
if batch_result is None or batch_result.delay_sample_func is None:
return
with self.forward_stream_ctx:
self.forward_stream.wait_stream(self.default_stream)
tmp_result.next_token_ids = self.model_worker.model_runner.sample(
tmp_result.logits_output,
tmp_result.forward_batch,
)
future_indices = tmp_result.future_indices
self.future_map.store_to_map(future_indices, tmp_result)
tmp_result.copy_to_cpu()
self.result_queue.appendleft((tmp_batch, tmp_result))
_batch_result = batch_result.delay_sample_func()
assert _batch_result is batch_result
self.future_map.store_to_map(batch_result.future_indices, batch_result)
batch_result.copy_to_cpu()
def process_batch_result(
self,