[PD] Support structured output (#6560)
This commit is contained in:
@@ -45,19 +45,16 @@ from sglang.srt.disaggregation.utils import (
|
||||
poll_and_all_reduce,
|
||||
prepare_abort,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import FINISH_ABORT
|
||||
from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
||||
from sglang.srt.managers.schedule_batch import Req
|
||||
from sglang.srt.managers.scheduler import Scheduler
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -531,7 +528,18 @@ class SchedulerDisaggregationDecodeMixin:
|
||||
self.prepare_dp_attn_batch(batch)
|
||||
result = self.run_batch(batch)
|
||||
result_queue.append((batch.copy(), result))
|
||||
|
||||
if (self.last_batch is None) or (not self.last_batch_in_queue):
|
||||
# Create a dummy first batch to start the pipeline for overlap schedule.
|
||||
# It is now used for triggering the sampling_info_done event.
|
||||
tmp_batch = ScheduleBatch(
|
||||
reqs=None,
|
||||
forward_mode=ForwardMode.DUMMY_FIRST,
|
||||
next_batch_sampling_info=self.tp_worker.cur_sampling_info,
|
||||
)
|
||||
self.set_next_batch_sampling_info_done(tmp_batch)
|
||||
last_batch_in_queue = True
|
||||
|
||||
elif prepare_dp_attn_flag:
|
||||
batch, result = self._prepare_idle_batch_and_run(
|
||||
None, delay_process=True
|
||||
@@ -543,6 +551,9 @@ class SchedulerDisaggregationDecodeMixin:
|
||||
# Process the results of the previous batch but skip if the last batch is extend
|
||||
if self.last_batch and self.last_batch_in_queue:
|
||||
tmp_batch, tmp_result = result_queue.popleft()
|
||||
tmp_batch.next_batch_sampling_info = (
|
||||
self.tp_worker.cur_sampling_info if batch else None
|
||||
)
|
||||
self.process_batch_result(tmp_batch, tmp_result)
|
||||
|
||||
if batch is None and (
|
||||
@@ -591,6 +602,9 @@ class SchedulerDisaggregationDecodeMixin:
|
||||
|
||||
def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]:
|
||||
"""Create a schedulebatch for fake completed prefill"""
|
||||
if self.grammar_queue:
|
||||
self.move_ready_grammar_requests()
|
||||
|
||||
if len(self.waiting_queue) == 0:
|
||||
return None
|
||||
|
||||
@@ -616,8 +630,6 @@ class SchedulerDisaggregationDecodeMixin:
|
||||
self.waiting_queue = waiting_queue
|
||||
if len(can_run_list) == 0:
|
||||
return None
|
||||
# local import to avoid circular import
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||
|
||||
# construct a schedule batch with those requests and mark as decode
|
||||
new_batch = ScheduleBatch.init_new(
|
||||
|
||||
@@ -101,6 +101,9 @@ class ScheduleBatchDisaggregationDecodeMixin:
|
||||
for req in self.reqs:
|
||||
self.output_ids.append(req.output_ids[-1])
|
||||
self.tree_cache.cache_unfinished_req(req)
|
||||
if req.grammar is not None:
|
||||
req.grammar.accept_token(req.output_ids[-1])
|
||||
req.grammar.finished = req.finished()
|
||||
self.output_ids = torch.tensor(self.output_ids, device=self.device)
|
||||
|
||||
# Simulate the eagle run. We add mock data to hidden states for the
|
||||
|
||||
@@ -43,6 +43,7 @@ from sglang.srt.disaggregation.utils import (
|
||||
prepare_abort,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch.distributed import ProcessGroup
|
||||
@@ -143,6 +144,10 @@ class PrefillBootstrapQueue:
|
||||
self._process_req(req)
|
||||
self.queue.append(req)
|
||||
|
||||
def extend(self, reqs: List[Req]) -> None:
|
||||
for req in reqs:
|
||||
self.add(req)
|
||||
|
||||
def _process_req(self, req: Req) -> None:
|
||||
"""
|
||||
Set max_new_tokens = 1, so PrefillAdder memory estimation is accurate
|
||||
@@ -269,6 +274,16 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
result = self.run_batch(batch)
|
||||
self.result_queue.append((batch.copy(), result))
|
||||
|
||||
if self.last_batch is None:
|
||||
# Create a dummy first batch to start the pipeline for overlap schedule.
|
||||
# It is now used for triggering the sampling_info_done event.
|
||||
tmp_batch = ScheduleBatch(
|
||||
reqs=None,
|
||||
forward_mode=ForwardMode.DUMMY_FIRST,
|
||||
next_batch_sampling_info=self.tp_worker.cur_sampling_info,
|
||||
)
|
||||
self.set_next_batch_sampling_info_done(tmp_batch)
|
||||
|
||||
if self.last_batch:
|
||||
tmp_batch, tmp_result = self.result_queue.popleft()
|
||||
self.process_batch_result_disagg_prefill(tmp_batch, tmp_result)
|
||||
|
||||
@@ -1065,8 +1065,11 @@ class Scheduler(
|
||||
else:
|
||||
self.waiting_queue.append(req)
|
||||
|
||||
def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
|
||||
if self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||
def _extend_requests_to_queue(self, reqs: List[Req]):
|
||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||
self.disagg_prefill_bootstrap_queue.extend(reqs)
|
||||
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||
# If this is a decode server, we put the request to the decode pending prealloc queue
|
||||
self.disagg_decode_prealloc_queue.extend(reqs)
|
||||
else:
|
||||
self.waiting_queue.extend(reqs)
|
||||
|
||||
34
scripts/playground/disaggregation/cli-so.py
Normal file
34
scripts/playground/disaggregation/cli-so.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
port = 8000
|
||||
|
||||
json_schema = json.dumps(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string", "pattern": "^[\\w]+$"},
|
||||
"population": {"type": "integer"},
|
||||
},
|
||||
"required": ["name", "population"],
|
||||
}
|
||||
)
|
||||
|
||||
# JSON
|
||||
response = requests.post(
|
||||
f"http://localhost:{port}/generate",
|
||||
json={
|
||||
"text": "Here is the information of the capital of France in the JSON format.\n",
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 64,
|
||||
"json_schema": json_schema,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
print(response.json())
|
||||
|
||||
|
||||
# python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --trust-remote-code --disaggregation-mode prefill --tp 2 --disaggregation-ib-device mlx5_roce0,mlx5_roce1 --speculative-algorithm EAGLE --speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 3 --speculative-eagle-topk 4 --speculative-num-draft-tokens 16 --cuda-graph-max-bs 8 --host 127.0.0.1 --port 8100
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
@@ -17,12 +18,9 @@ from sglang.test.test_utils import (
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_pd_server,
|
||||
run_with_timeout,
|
||||
)
|
||||
|
||||
|
||||
# skip the test because we have different_tp test
|
||||
@unittest.skip("skip the test because we have different_tp test")
|
||||
class TestDisaggregationAccuracy(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@@ -172,6 +170,34 @@ class TestDisaggregationAccuracy(CustomTestCase):
|
||||
len(input_logprobs) > 0
|
||||
), f"input_logprobs should have at least one token, but got {len(input_logprobs)}"
|
||||
|
||||
def test_structured_output(self):
|
||||
json_schema = json.dumps(
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string", "pattern": "^[\\w]+$"},
|
||||
"population": {"type": "integer"},
|
||||
},
|
||||
"required": ["name", "population"],
|
||||
}
|
||||
)
|
||||
|
||||
# JSON
|
||||
response = requests.post(
|
||||
f"{self.lb_url}/generate",
|
||||
json={
|
||||
"text": "Here is the information of the capital of France in the JSON format.\n",
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 64,
|
||||
"json_schema": json_schema,
|
||||
},
|
||||
},
|
||||
)
|
||||
output = response.json()["text"]
|
||||
# ensure the output is a valid JSON
|
||||
json.loads(output)
|
||||
|
||||
|
||||
class TestDisaggregationMooncakeFailure(CustomTestCase):
|
||||
@classmethod
|
||||
|
||||
Reference in New Issue
Block a user