[PD] Support structured output (#6560)
This commit is contained in:
@@ -45,19 +45,16 @@ from sglang.srt.disaggregation.utils import (
|
||||
poll_and_all_reduce,
|
||||
prepare_abort,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import FINISH_ABORT
|
||||
from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
||||
from sglang.srt.managers.schedule_batch import Req
|
||||
from sglang.srt.managers.scheduler import Scheduler
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -531,7 +528,18 @@ class SchedulerDisaggregationDecodeMixin:
|
||||
self.prepare_dp_attn_batch(batch)
|
||||
result = self.run_batch(batch)
|
||||
result_queue.append((batch.copy(), result))
|
||||
|
||||
if (self.last_batch is None) or (not self.last_batch_in_queue):
|
||||
# Create a dummy first batch to start the pipeline for overlap schedule.
|
||||
# It is now used for triggering the sampling_info_done event.
|
||||
tmp_batch = ScheduleBatch(
|
||||
reqs=None,
|
||||
forward_mode=ForwardMode.DUMMY_FIRST,
|
||||
next_batch_sampling_info=self.tp_worker.cur_sampling_info,
|
||||
)
|
||||
self.set_next_batch_sampling_info_done(tmp_batch)
|
||||
last_batch_in_queue = True
|
||||
|
||||
elif prepare_dp_attn_flag:
|
||||
batch, result = self._prepare_idle_batch_and_run(
|
||||
None, delay_process=True
|
||||
@@ -543,6 +551,9 @@ class SchedulerDisaggregationDecodeMixin:
|
||||
# Process the results of the previous batch but skip if the last batch is extend
|
||||
if self.last_batch and self.last_batch_in_queue:
|
||||
tmp_batch, tmp_result = result_queue.popleft()
|
||||
tmp_batch.next_batch_sampling_info = (
|
||||
self.tp_worker.cur_sampling_info if batch else None
|
||||
)
|
||||
self.process_batch_result(tmp_batch, tmp_result)
|
||||
|
||||
if batch is None and (
|
||||
@@ -591,6 +602,9 @@ class SchedulerDisaggregationDecodeMixin:
|
||||
|
||||
def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]:
|
||||
"""Create a schedulebatch for fake completed prefill"""
|
||||
if self.grammar_queue:
|
||||
self.move_ready_grammar_requests()
|
||||
|
||||
if len(self.waiting_queue) == 0:
|
||||
return None
|
||||
|
||||
@@ -616,8 +630,6 @@ class SchedulerDisaggregationDecodeMixin:
|
||||
self.waiting_queue = waiting_queue
|
||||
if len(can_run_list) == 0:
|
||||
return None
|
||||
# local import to avoid circular import
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||
|
||||
# construct a schedule batch with those requests and mark as decode
|
||||
new_batch = ScheduleBatch.init_new(
|
||||
|
||||
@@ -101,6 +101,9 @@ class ScheduleBatchDisaggregationDecodeMixin:
|
||||
for req in self.reqs:
|
||||
self.output_ids.append(req.output_ids[-1])
|
||||
self.tree_cache.cache_unfinished_req(req)
|
||||
if req.grammar is not None:
|
||||
req.grammar.accept_token(req.output_ids[-1])
|
||||
req.grammar.finished = req.finished()
|
||||
self.output_ids = torch.tensor(self.output_ids, device=self.device)
|
||||
|
||||
# Simulate the eagle run. We add mock data to hidden states for the
|
||||
|
||||
@@ -43,6 +43,7 @@ from sglang.srt.disaggregation.utils import (
|
||||
prepare_abort,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch.distributed import ProcessGroup
|
||||
@@ -143,6 +144,10 @@ class PrefillBootstrapQueue:
|
||||
self._process_req(req)
|
||||
self.queue.append(req)
|
||||
|
||||
def extend(self, reqs: List[Req]) -> None:
|
||||
for req in reqs:
|
||||
self.add(req)
|
||||
|
||||
def _process_req(self, req: Req) -> None:
|
||||
"""
|
||||
Set max_new_tokens = 1, so PrefillAdder memory estimation is accurate
|
||||
@@ -269,6 +274,16 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
result = self.run_batch(batch)
|
||||
self.result_queue.append((batch.copy(), result))
|
||||
|
||||
if self.last_batch is None:
|
||||
# Create a dummy first batch to start the pipeline for overlap schedule.
|
||||
# It is now used for triggering the sampling_info_done event.
|
||||
tmp_batch = ScheduleBatch(
|
||||
reqs=None,
|
||||
forward_mode=ForwardMode.DUMMY_FIRST,
|
||||
next_batch_sampling_info=self.tp_worker.cur_sampling_info,
|
||||
)
|
||||
self.set_next_batch_sampling_info_done(tmp_batch)
|
||||
|
||||
if self.last_batch:
|
||||
tmp_batch, tmp_result = self.result_queue.popleft()
|
||||
self.process_batch_result_disagg_prefill(tmp_batch, tmp_result)
|
||||
|
||||
@@ -1065,8 +1065,11 @@ class Scheduler(
|
||||
else:
|
||||
self.waiting_queue.append(req)
|
||||
|
||||
def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
|
||||
if self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||
def _extend_requests_to_queue(self, reqs: List[Req]):
|
||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||
self.disagg_prefill_bootstrap_queue.extend(reqs)
|
||||
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||
# If this is a decode server, we put the request to the decode pending prealloc queue
|
||||
self.disagg_decode_prealloc_queue.extend(reqs)
|
||||
else:
|
||||
self.waiting_queue.extend(reqs)
|
||||
|
||||
Reference in New Issue
Block a user