Improve overlap scheduling (#5788)
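Summary: the overlap scheduler used to signal batch launches through a single worker-global event (TpModelWorkerClient.launch_done) and to fetch results by batch id via resolve_batch_result(bid). This change attaches a per-batch threading.Event, launch_done, to ScheduleBatch and ModelWorkerBatch, threads it through every result-processing path, and renames the accessor to resolve_last_batch_result(launch_done), so each wait targets the launch of the batch that is actually in flight.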
@@ -20,6 +20,7 @@ Life cycle of a request in the prefill server
 from __future__ import annotations
 
 import logging
+import threading
 from collections import deque
 from typing import TYPE_CHECKING, List, Optional
 
@@ -256,7 +257,10 @@ class SchedulerDisaggregationPrefillMixin:
         self.running_batch.batch_is_full = False
 
     def process_batch_result_disagg_prefill(
-        self: Scheduler, batch: ScheduleBatch, result: GenerationBatchResult
+        self: Scheduler,
+        batch: ScheduleBatch,
+        result: GenerationBatchResult,
+        launch_done: Optional[threading.Event] = None,
     ) -> None:
         """
         Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
@@ -280,7 +284,7 @@ class SchedulerDisaggregationPrefillMixin:
         # Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
         if self.enable_overlap:
             # wait
-            _, next_token_ids = self.tp_worker.resolve_batch_result(bid)
+            _, next_token_ids = self.tp_worker.resolve_last_batch_result(launch_done)
         else:
             next_token_ids = result.next_token_ids.tolist()
 
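The disaggregated prefill path adopts the same contract as the regular decode and extend paths further down: process_batch_result_disagg_prefill now accepts the in-flight batch's launch_done event and hands it to resolve_last_batch_result instead of resolving by batch id.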
@@ -35,6 +35,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 import copy
 import dataclasses
 import logging
+import threading
 from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
 
 import numpy as np
@@ -724,6 +725,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # This is an optimization to reduce the overhead of the prefill check.
     batch_is_full: bool = False
 
+    # Events
+    launch_done: Optional[threading.Event] = None
+
     # Sampling info
     sampling_info: SamplingBatchInfo = None
     next_batch_sampling_info: SamplingBatchInfo = None
@@ -1565,6 +1569,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 )
             ),
             extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
+            launch_done=self.launch_done,
         )
 
     def copy(self):
@@ -1647,6 +1652,9 @@ class ModelWorkerBatch:
     # If set, the output of the batch contains the hidden states of the run.
     capture_hidden_mode: CaptureHiddenMode = None
 
+    # Overlap event
+    launch_done: Optional[threading.Event] = None
+
 
 @triton.jit
 def write_req_to_token_pool_triton(
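These hunks carry the event end to end: ScheduleBatch holds it, get_model_worker_batch copies it onto the ModelWorkerBatch that crosses into the worker thread, and the worker sets it once the batch is launched. A minimal sketch of the hand-off pattern, with a toy Batch class standing in for both dataclasses:

    import threading
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class Batch:
        # Toy stand-in for ScheduleBatch / ModelWorkerBatch.
        tokens: int
        launch_done: Optional[threading.Event] = None

    def worker(batch: Batch) -> None:
        # The worker signals on the batch's own event once the launch happened.
        if batch.launch_done is not None:
            batch.launch_done.set()

    batch = Batch(tokens=7, launch_done=threading.Event())
    t = threading.Thread(target=worker, args=(batch,))
    t.start()
    batch.launch_done.wait()    # scheduler side: block until the launch
    t.join()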
@@ -645,6 +645,7 @@ class Scheduler(
         self.cur_batch = batch
 
         if batch:
+            batch.launch_done = threading.Event()
             result = self.run_batch(batch)
             self.result_queue.append((batch.copy(), result))
 
@@ -656,7 +657,7 @@ class Scheduler(
                     forward_mode=ForwardMode.DUMMY_FIRST,
                     next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                 )
-                self.process_batch_result(tmp_batch, None)
+                self.process_batch_result(tmp_batch, None, batch.launch_done)
 
         if self.last_batch:
             # Process the results of the last batch
@@ -664,7 +665,10 @@ class Scheduler(
             tmp_batch.next_batch_sampling_info = (
                 self.tp_worker.cur_sampling_info if batch else None
            )
-            self.process_batch_result(tmp_batch, tmp_result)
+            # NOTE: use the launch_done event of the currently launched batch, not the last batch's
+            self.process_batch_result(
+                tmp_batch, tmp_result, batch.launch_done if batch else None
+            )
         elif batch is None:
             # When the server is idle, do self-check and re-init some states
             self.check_memory()
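Taken together, the event-loop hunks implement the pipeline below: launch batch N, then process batch N-1's result while N runs on the GPU, waiting on N's launch_done before blocking on N-1's completion. The following toy is a runnable sketch of that ordering; all names are invented and only the threading and queue primitives mirror the real code:

    import threading
    from queue import Queue

    input_queue: Queue = Queue()
    output_queue: Queue = Queue()

    def forward_thread() -> None:
        while True:
            batch = input_queue.get()
            if batch is None:
                return
            batch["launch_done"].set()              # "forward pass launched"
            output_queue.put(batch["tokens"] * 2)   # pretend GPU result

    def resolve_last_batch_result(launch_done=None):
        result = output_queue.get()                 # previous batch's result
        if launch_done is not None:
            launch_done.wait()                      # current batch has launched
        return result

    threading.Thread(target=forward_thread).start()
    last_batch = None
    for step in range(3):
        batch = {"tokens": step, "launch_done": threading.Event()}
        input_queue.put(batch)                      # launch batch N
        if last_batch is not None:
            print(resolve_last_batch_result(batch["launch_done"]))  # batch N-1
        last_batch = batch
    print(resolve_last_batch_result())              # drain the final batch
    input_queue.put(None)                           # stop the worker thread

The print order (0, 2, 4) shows each result being consumed one batch behind its launch.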
@@ -1417,14 +1421,15 @@ class Scheduler(
         self,
         batch: ScheduleBatch,
         result: Union[GenerationBatchResult, EmbeddingBatchResult],
+        launch_done: Optional[threading.Event] = None,
     ):
         if batch.forward_mode.is_decode():
-            self.process_batch_result_decode(batch, result)
+            self.process_batch_result_decode(batch, result, launch_done)
         elif batch.forward_mode.is_extend():
-            self.process_batch_result_prefill(batch, result)
+            self.process_batch_result_prefill(batch, result, launch_done)
         elif batch.forward_mode.is_idle():
             if self.enable_overlap:
-                self.tp_worker.resolve_batch_result(result.bid)
+                self.tp_worker.resolve_last_batch_result(launch_done)
                 if batch.next_batch_sampling_info:
                     batch.next_batch_sampling_info.update_regex_vocab_mask()
                     self.current_stream.synchronize()
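Note the idle branch: in overlap mode even an idle batch occupies a slot in the worker's output queue, so the dispatcher still calls resolve_last_batch_result for it. Because that queue is consumed in FIFO order, every launched batch has to be resolved exactly once, otherwise a later caller would pop a result belonging to a different batch.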
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import threading
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -11,6 +12,7 @@ if TYPE_CHECKING:
         EmbeddingBatchResult,
         GenerationBatchResult,
         ScheduleBatch,
+        Scheduler,
     )
 
 
@@ -21,9 +23,10 @@ class SchedulerOutputProcessorMixin:
     """
 
     def process_batch_result_prefill(
-        self,
+        self: Scheduler,
         batch: ScheduleBatch,
         result: Union[GenerationBatchResult, EmbeddingBatchResult],
+        launch_done: Optional[threading.Event] = None,
     ):
         skip_stream_req = None
 
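The self parameters in this mixin change from bare self to self: Scheduler, with Scheduler imported under TYPE_CHECKING. This is a standard trick for mixins: the annotation tells type checkers which concrete class the mixin will be composed into, so attributes like self.tp_worker resolve, while the guarded import avoids a runtime cycle. A minimal sketch of the pattern (module names are made up):

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Import only for the type checker; avoids a circular import at runtime.
        from scheduler import Scheduler  # hypothetical module path

    class OutputProcessorMixin:
        def process(self: Scheduler) -> None:
            # The annotation tells the checker that `self` is a full Scheduler,
            # so attributes the mixin never defines (e.g. tp_worker) resolve.
            self.tp_worker.resolve_last_batch_result(None)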
@@ -43,7 +46,11 @@ class SchedulerOutputProcessorMixin:
                 )
 
                 if self.enable_overlap:
-                    logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
+                    logits_output, next_token_ids = (
+                        self.tp_worker.resolve_last_batch_result(
+                            launch_done,
+                        )
+                    )
                 else:
                     # Move next_token_ids and logprobs to cpu
                     next_token_ids = next_token_ids.tolist()
@@ -175,9 +182,10 @@ class SchedulerOutputProcessorMixin:
         self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
 
     def process_batch_result_decode(
-        self,
+        self: Scheduler,
         batch: ScheduleBatch,
         result: GenerationBatchResult,
+        launch_done: Optional[threading.Event] = None,
     ):
         logits_output, next_token_ids, bid = (
             result.logits_output,
@@ -187,7 +195,9 @@ class SchedulerOutputProcessorMixin:
         self.num_generated_tokens += len(batch.reqs)
 
         if self.enable_overlap:
-            logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
+            logits_output, next_token_ids = self.tp_worker.resolve_last_batch_result(
+                launch_done
+            )
             next_token_logprobs = logits_output.next_token_logprobs
         elif batch.spec_algorithm.is_none():
             # spec decoding handles output logprobs inside verify process.
@@ -271,7 +281,7 @@ class SchedulerOutputProcessorMixin:
         self.log_decode_stats()
 
     def add_input_logprob_return_values(
-        self,
+        self: Scheduler,
         i: int,
         req: Req,
         output: LogitsProcessorOutput,
@@ -405,7 +415,7 @@ class SchedulerOutputProcessorMixin:
         assert len(req.input_token_ids_logprobs_idx) == relevant_tokens_len
 
     def add_logprob_return_values(
-        self,
+        self: Scheduler,
         i: int,
         req: Req,
         pt: int,
@@ -436,7 +446,10 @@ class SchedulerOutputProcessorMixin:
         return num_input_logprobs
 
     def stream_output(
-        self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
+        self: Scheduler,
+        reqs: List[Req],
+        return_logprob: bool,
+        skip_req: Optional[Req] = None,
     ):
         """Stream the output to detokenizer."""
         if self.is_generation:
@@ -445,7 +458,10 @@ class SchedulerOutputProcessorMixin:
             self.stream_output_embedding(reqs)
 
     def stream_output_generation(
-        self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
+        self: Scheduler,
+        reqs: List[Req],
+        return_logprob: bool,
+        skip_req: Optional[Req] = None,
     ):
         rids = []
         finished_reasons: List[BaseFinishReason] = []
@@ -593,7 +609,7 @@ class SchedulerOutputProcessorMixin:
             )
         )
 
-    def stream_output_embedding(self, reqs: List[Req]):
+    def stream_output_embedding(self: Scheduler, reqs: List[Req]):
         rids = []
         finished_reasons: List[BaseFinishReason] = []
 
@@ -170,13 +170,13 @@ class TpModelWorker:
     def forward_batch_generation(
         self,
         model_worker_batch: ModelWorkerBatch,
-        launch_done: Optional[threading.Event] = None,
         skip_sample: bool = False,
     ) -> Tuple[LogitsProcessorOutput, Optional[torch.Tensor]]:
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         logits_output = self.model_runner.forward(forward_batch)
-        if launch_done:
-            launch_done.set()
+        if model_worker_batch.launch_done is not None:
+            model_worker_batch.launch_done.set()
 
         if skip_sample:
             next_token_ids = None
|
|||||||
batch_pt += 1
|
batch_pt += 1
|
||||||
|
|
||||||
# Create event
|
# Create event
|
||||||
self.launch_done = threading.Event()
|
|
||||||
copy_done = torch.get_device_module(self.device).Event()
|
copy_done = torch.get_device_module(self.device).Event()
|
||||||
|
|
||||||
# Resolve future tokens in the input
|
# Resolve future tokens in the input
|
||||||
@@ -141,7 +140,7 @@ class TpModelWorkerClient:
|
|||||||
|
|
||||||
# Run forward
|
# Run forward
|
||||||
logits_output, next_token_ids = self.worker.forward_batch_generation(
|
logits_output, next_token_ids = self.worker.forward_batch_generation(
|
||||||
model_worker_batch, self.launch_done
|
model_worker_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update the future token ids map
|
# Update the future token ids map
|
||||||
@@ -168,10 +167,16 @@ class TpModelWorkerClient:
|
|||||||
|
|
||||||
self.output_queue.put((copy_done, logits_output, next_token_ids))
|
self.output_queue.put((copy_done, logits_output, next_token_ids))
|
||||||
|
|
||||||
def resolve_batch_result(self, bid: int):
|
def resolve_last_batch_result(self, launch_done: Optional[threading.Event] = None):
|
||||||
|
"""
|
||||||
|
This function is called to resolve the last batch result and
|
||||||
|
wait for the current batch to be launched. Used in overlap mode.
|
||||||
|
"""
|
||||||
copy_done, logits_output, next_token_ids = self.output_queue.get()
|
copy_done, logits_output, next_token_ids = self.output_queue.get()
|
||||||
|
|
||||||
|
if launch_done is not None:
|
||||||
|
launch_done.wait()
|
||||||
copy_done.synchronize()
|
copy_done.synchronize()
|
||||||
self.launch_done.wait()
|
|
||||||
|
|
||||||
if logits_output.next_token_logprobs is not None:
|
if logits_output.next_token_logprobs is not None:
|
||||||
logits_output.next_token_logprobs = (
|
logits_output.next_token_logprobs = (
|
||||||
|
|||||||
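Design note: forward_thread previously re-created the shared self.launch_done attribute for every batch, and resolve_batch_result waited on whatever that attribute pointed at when it ran, which could be a different batch's event than the caller intended. Passing the event explicitly, as resolve_last_batch_result now does, makes the wait target unambiguous, and callers that run without overlap can simply omit it.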