diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index 92531aa73..4d6ac1b6f 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -45,19 +45,16 @@ from sglang.srt.disaggregation.utils import ( poll_and_all_reduce, prepare_abort, ) -from sglang.srt.managers.schedule_batch import FINISH_ABORT +from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator -from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode -from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo +from sglang.srt.model_executor.forward_batch_info import ForwardMode logger = logging.getLogger(__name__) if TYPE_CHECKING: - from sglang.srt.configs.model_config import ModelConfig - from sglang.srt.managers.schedule_batch import Req, ScheduleBatch + from sglang.srt.managers.schedule_batch import Req from sglang.srt.managers.scheduler import Scheduler - from sglang.srt.server_args import ServerArgs @dataclass @@ -531,7 +528,18 @@ class SchedulerDisaggregationDecodeMixin: self.prepare_dp_attn_batch(batch) result = self.run_batch(batch) result_queue.append((batch.copy(), result)) + + if (self.last_batch is None) or (not self.last_batch_in_queue): + # Create a dummy first batch to start the pipeline for overlap schedule. + # It is now used for triggering the sampling_info_done event. + tmp_batch = ScheduleBatch( + reqs=None, + forward_mode=ForwardMode.DUMMY_FIRST, + next_batch_sampling_info=self.tp_worker.cur_sampling_info, + ) + self.set_next_batch_sampling_info_done(tmp_batch) last_batch_in_queue = True + elif prepare_dp_attn_flag: batch, result = self._prepare_idle_batch_and_run( None, delay_process=True @@ -543,6 +551,9 @@ class SchedulerDisaggregationDecodeMixin: # Process the results of the previous batch but skip if the last batch is extend if self.last_batch and self.last_batch_in_queue: tmp_batch, tmp_result = result_queue.popleft() + tmp_batch.next_batch_sampling_info = ( + self.tp_worker.cur_sampling_info if batch else None + ) self.process_batch_result(tmp_batch, tmp_result) if batch is None and ( @@ -591,6 +602,9 @@ class SchedulerDisaggregationDecodeMixin: def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]: """Create a schedulebatch for fake completed prefill""" + if self.grammar_queue: + self.move_ready_grammar_requests() + if len(self.waiting_queue) == 0: return None @@ -616,8 +630,6 @@ class SchedulerDisaggregationDecodeMixin: self.waiting_queue = waiting_queue if len(can_run_list) == 0: return None - # local import to avoid circular import - from sglang.srt.managers.schedule_batch import ScheduleBatch # construct a schedule batch with those requests and mark as decode new_batch = ScheduleBatch.init_new( diff --git a/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py b/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py index c05e8231d..38e936106 100644 --- a/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py +++ b/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py @@ -101,6 +101,9 @@ class ScheduleBatchDisaggregationDecodeMixin: for req in self.reqs: self.output_ids.append(req.output_ids[-1]) self.tree_cache.cache_unfinished_req(req) + if req.grammar is not None: + req.grammar.accept_token(req.output_ids[-1]) + req.grammar.finished = req.finished() self.output_ids = torch.tensor(self.output_ids, device=self.device) # Simulate the eagle run. We add mock data to hidden states for the diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 0ed04f06a..8b325811e 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -43,6 +43,7 @@ from sglang.srt.disaggregation.utils import ( prepare_abort, ) from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch +from sglang.srt.model_executor.forward_batch_info import ForwardMode if TYPE_CHECKING: from torch.distributed import ProcessGroup @@ -143,6 +144,10 @@ class PrefillBootstrapQueue: self._process_req(req) self.queue.append(req) + def extend(self, reqs: List[Req]) -> None: + for req in reqs: + self.add(req) + def _process_req(self, req: Req) -> None: """ Set max_new_tokens = 1, so PrefillAdder memory estimation is accurate @@ -269,6 +274,16 @@ class SchedulerDisaggregationPrefillMixin: result = self.run_batch(batch) self.result_queue.append((batch.copy(), result)) + if self.last_batch is None: + # Create a dummy first batch to start the pipeline for overlap schedule. + # It is now used for triggering the sampling_info_done event. + tmp_batch = ScheduleBatch( + reqs=None, + forward_mode=ForwardMode.DUMMY_FIRST, + next_batch_sampling_info=self.tp_worker.cur_sampling_info, + ) + self.set_next_batch_sampling_info_done(tmp_batch) + if self.last_batch: tmp_batch, tmp_result = self.result_queue.popleft() self.process_batch_result_disagg_prefill(tmp_batch, tmp_result) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 2d9b840ae..3c449fbea 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -1065,8 +1065,11 @@ class Scheduler( else: self.waiting_queue.append(req) - def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False): - if self.disaggregation_mode == DisaggregationMode.DECODE: + def _extend_requests_to_queue(self, reqs: List[Req]): + if self.disaggregation_mode == DisaggregationMode.PREFILL: + self.disagg_prefill_bootstrap_queue.extend(reqs) + elif self.disaggregation_mode == DisaggregationMode.DECODE: + # If this is a decode server, we put the request to the decode pending prealloc queue self.disagg_decode_prealloc_queue.extend(reqs) else: self.waiting_queue.extend(reqs) diff --git a/scripts/playground/disaggregation/cli-so.py b/scripts/playground/disaggregation/cli-so.py new file mode 100644 index 000000000..7ccafc7ed --- /dev/null +++ b/scripts/playground/disaggregation/cli-so.py @@ -0,0 +1,34 @@ +import json + +import requests + +port = 8000 + +json_schema = json.dumps( + { + "type": "object", + "properties": { + "name": {"type": "string", "pattern": "^[\\w]+$"}, + "population": {"type": "integer"}, + }, + "required": ["name", "population"], + } +) + +# JSON +response = requests.post( + f"http://localhost:{port}/generate", + json={ + "text": "Here is the information of the capital of France in the JSON format.\n", + "sampling_params": { + "temperature": 0, + "max_new_tokens": 64, + "json_schema": json_schema, + }, + }, +) + +print(response.json()) + + +# python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --trust-remote-code --disaggregation-mode prefill --tp 2 --disaggregation-ib-device mlx5_roce0,mlx5_roce1 --speculative-algorithm EAGLE --speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 3 --speculative-eagle-topk 4 --speculative-num-draft-tokens 16 --cuda-graph-max-bs 8 --host 127.0.0.1 --port 8100 diff --git a/test/srt/test_disaggregation.py b/test/srt/test_disaggregation.py index fda00e249..0ae99547e 100644 --- a/test/srt/test_disaggregation.py +++ b/test/srt/test_disaggregation.py @@ -1,3 +1,4 @@ +import json import os import subprocess import time @@ -17,12 +18,9 @@ from sglang.test.test_utils import ( DEFAULT_URL_FOR_TEST, CustomTestCase, popen_launch_pd_server, - run_with_timeout, ) -# skip the test because we have different_tp test -@unittest.skip("skip the test because we have different_tp test") class TestDisaggregationAccuracy(CustomTestCase): @classmethod def setUpClass(cls): @@ -172,6 +170,34 @@ class TestDisaggregationAccuracy(CustomTestCase): len(input_logprobs) > 0 ), f"input_logprobs should have at least one token, but got {len(input_logprobs)}" + def test_structured_output(self): + json_schema = json.dumps( + { + "type": "object", + "properties": { + "name": {"type": "string", "pattern": "^[\\w]+$"}, + "population": {"type": "integer"}, + }, + "required": ["name", "population"], + } + ) + + # JSON + response = requests.post( + f"{self.lb_url}/generate", + json={ + "text": "Here is the information of the capital of France in the JSON format.\n", + "sampling_params": { + "temperature": 0, + "max_new_tokens": 64, + "json_schema": json_schema, + }, + }, + ) + output = response.json()["text"] + # ensure the output is a valid JSON + json.loads(output) + class TestDisaggregationMooncakeFailure(CustomTestCase): @classmethod