[PD] Support structured output (#6560)
This commit is contained in:
@@ -45,19 +45,16 @@ from sglang.srt.disaggregation.utils import (
|
|||||||
poll_and_all_reduce,
|
poll_and_all_reduce,
|
||||||
prepare_abort,
|
prepare_abort,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import FINISH_ABORT
|
from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
|
||||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
||||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.configs.model_config import ModelConfig
|
from sglang.srt.managers.schedule_batch import Req
|
||||||
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
|
||||||
from sglang.srt.managers.scheduler import Scheduler
|
from sglang.srt.managers.scheduler import Scheduler
|
||||||
from sglang.srt.server_args import ServerArgs
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -531,7 +528,18 @@ class SchedulerDisaggregationDecodeMixin:
|
|||||||
self.prepare_dp_attn_batch(batch)
|
self.prepare_dp_attn_batch(batch)
|
||||||
result = self.run_batch(batch)
|
result = self.run_batch(batch)
|
||||||
result_queue.append((batch.copy(), result))
|
result_queue.append((batch.copy(), result))
|
||||||
|
|
||||||
|
if (self.last_batch is None) or (not self.last_batch_in_queue):
|
||||||
|
# Create a dummy first batch to start the pipeline for overlap schedule.
|
||||||
|
# It is now used for triggering the sampling_info_done event.
|
||||||
|
tmp_batch = ScheduleBatch(
|
||||||
|
reqs=None,
|
||||||
|
forward_mode=ForwardMode.DUMMY_FIRST,
|
||||||
|
next_batch_sampling_info=self.tp_worker.cur_sampling_info,
|
||||||
|
)
|
||||||
|
self.set_next_batch_sampling_info_done(tmp_batch)
|
||||||
last_batch_in_queue = True
|
last_batch_in_queue = True
|
||||||
|
|
||||||
elif prepare_dp_attn_flag:
|
elif prepare_dp_attn_flag:
|
||||||
batch, result = self._prepare_idle_batch_and_run(
|
batch, result = self._prepare_idle_batch_and_run(
|
||||||
None, delay_process=True
|
None, delay_process=True
|
||||||
@@ -543,6 +551,9 @@ class SchedulerDisaggregationDecodeMixin:
|
|||||||
# Process the results of the previous batch but skip if the last batch is extend
|
# Process the results of the previous batch but skip if the last batch is extend
|
||||||
if self.last_batch and self.last_batch_in_queue:
|
if self.last_batch and self.last_batch_in_queue:
|
||||||
tmp_batch, tmp_result = result_queue.popleft()
|
tmp_batch, tmp_result = result_queue.popleft()
|
||||||
|
tmp_batch.next_batch_sampling_info = (
|
||||||
|
self.tp_worker.cur_sampling_info if batch else None
|
||||||
|
)
|
||||||
self.process_batch_result(tmp_batch, tmp_result)
|
self.process_batch_result(tmp_batch, tmp_result)
|
||||||
|
|
||||||
if batch is None and (
|
if batch is None and (
|
||||||
@@ -591,6 +602,9 @@ class SchedulerDisaggregationDecodeMixin:
|
|||||||
|
|
||||||
def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]:
|
def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]:
|
||||||
"""Create a schedulebatch for fake completed prefill"""
|
"""Create a schedulebatch for fake completed prefill"""
|
||||||
|
if self.grammar_queue:
|
||||||
|
self.move_ready_grammar_requests()
|
||||||
|
|
||||||
if len(self.waiting_queue) == 0:
|
if len(self.waiting_queue) == 0:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -616,8 +630,6 @@ class SchedulerDisaggregationDecodeMixin:
|
|||||||
self.waiting_queue = waiting_queue
|
self.waiting_queue = waiting_queue
|
||||||
if len(can_run_list) == 0:
|
if len(can_run_list) == 0:
|
||||||
return None
|
return None
|
||||||
# local import to avoid circular import
|
|
||||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
|
||||||
|
|
||||||
# construct a schedule batch with those requests and mark as decode
|
# construct a schedule batch with those requests and mark as decode
|
||||||
new_batch = ScheduleBatch.init_new(
|
new_batch = ScheduleBatch.init_new(
|
||||||
|
|||||||
@@ -101,6 +101,9 @@ class ScheduleBatchDisaggregationDecodeMixin:
|
|||||||
for req in self.reqs:
|
for req in self.reqs:
|
||||||
self.output_ids.append(req.output_ids[-1])
|
self.output_ids.append(req.output_ids[-1])
|
||||||
self.tree_cache.cache_unfinished_req(req)
|
self.tree_cache.cache_unfinished_req(req)
|
||||||
|
if req.grammar is not None:
|
||||||
|
req.grammar.accept_token(req.output_ids[-1])
|
||||||
|
req.grammar.finished = req.finished()
|
||||||
self.output_ids = torch.tensor(self.output_ids, device=self.device)
|
self.output_ids = torch.tensor(self.output_ids, device=self.device)
|
||||||
|
|
||||||
# Simulate the eagle run. We add mock data to hidden states for the
|
# Simulate the eagle run. We add mock data to hidden states for the
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ from sglang.srt.disaggregation.utils import (
|
|||||||
prepare_abort,
|
prepare_abort,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
|
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from torch.distributed import ProcessGroup
|
from torch.distributed import ProcessGroup
|
||||||
@@ -143,6 +144,10 @@ class PrefillBootstrapQueue:
|
|||||||
self._process_req(req)
|
self._process_req(req)
|
||||||
self.queue.append(req)
|
self.queue.append(req)
|
||||||
|
|
||||||
|
def extend(self, reqs: List[Req]) -> None:
    """Add every request in ``reqs`` to this queue, one at a time and in order.

    Each request is handed to ``add``; an empty list is a no-op.
    """
    for request in reqs:
        self.add(request)
|
||||||
|
|
||||||
def _process_req(self, req: Req) -> None:
|
def _process_req(self, req: Req) -> None:
|
||||||
"""
|
"""
|
||||||
Set max_new_tokens = 1, so PrefillAdder memory estimation is accurate
|
Set max_new_tokens = 1, so PrefillAdder memory estimation is accurate
|
||||||
@@ -269,6 +274,16 @@ class SchedulerDisaggregationPrefillMixin:
|
|||||||
result = self.run_batch(batch)
|
result = self.run_batch(batch)
|
||||||
self.result_queue.append((batch.copy(), result))
|
self.result_queue.append((batch.copy(), result))
|
||||||
|
|
||||||
|
if self.last_batch is None:
|
||||||
|
# Create a dummy first batch to start the pipeline for overlap schedule.
|
||||||
|
# It is now used for triggering the sampling_info_done event.
|
||||||
|
tmp_batch = ScheduleBatch(
|
||||||
|
reqs=None,
|
||||||
|
forward_mode=ForwardMode.DUMMY_FIRST,
|
||||||
|
next_batch_sampling_info=self.tp_worker.cur_sampling_info,
|
||||||
|
)
|
||||||
|
self.set_next_batch_sampling_info_done(tmp_batch)
|
||||||
|
|
||||||
if self.last_batch:
|
if self.last_batch:
|
||||||
tmp_batch, tmp_result = self.result_queue.popleft()
|
tmp_batch, tmp_result = self.result_queue.popleft()
|
||||||
self.process_batch_result_disagg_prefill(tmp_batch, tmp_result)
|
self.process_batch_result_disagg_prefill(tmp_batch, tmp_result)
|
||||||
|
|||||||
@@ -1065,8 +1065,11 @@ class Scheduler(
|
|||||||
else:
|
else:
|
||||||
self.waiting_queue.append(req)
|
self.waiting_queue.append(req)
|
||||||
|
|
||||||
def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
|
def _extend_requests_to_queue(self, reqs: List[Req]):
|
||||||
if self.disaggregation_mode == DisaggregationMode.DECODE:
|
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||||
|
self.disagg_prefill_bootstrap_queue.extend(reqs)
|
||||||
|
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||||
|
# If this is a decode server, we put the request to the decode pending prealloc queue
|
||||||
self.disagg_decode_prealloc_queue.extend(reqs)
|
self.disagg_decode_prealloc_queue.extend(reqs)
|
||||||
else:
|
else:
|
||||||
self.waiting_queue.extend(reqs)
|
self.waiting_queue.extend(reqs)
|
||||||
|
|||||||
34
scripts/playground/disaggregation/cli-so.py
Normal file
34
scripts/playground/disaggregation/cli-so.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
"""Send one JSON-schema-constrained ``/generate`` request and print the reply.

The request is posted to ``http://localhost:<port>/generate`` with a
``json_schema`` sampling parameter, and the decoded JSON response is printed.
"""

import json

import requests

port = 8000

# requests waits forever by default; bound the wait so a stalled server
# cannot hang the script.
REQUEST_TIMEOUT_S = 120

json_schema = json.dumps(
    {
        "type": "object",
        "properties": {
            "name": {"type": "string", "pattern": "^[\\w]+$"},
            "population": {"type": "integer"},
        },
        "required": ["name", "population"],
    }
)


def main() -> None:
    """Post the structured-output request and print the decoded response."""
    # JSON
    response = requests.post(
        f"http://localhost:{port}/generate",
        json={
            "text": "Here is the information of the capital of France in the JSON format.\n",
            "sampling_params": {
                "temperature": 0,
                "max_new_tokens": 64,
                "json_schema": json_schema,
            },
        },
        timeout=REQUEST_TIMEOUT_S,
    )

    print(response.json())


# Guard the network call so importing this module has no side effects.
if __name__ == "__main__":
    main()
|
||||||
|
|
||||||
|
|
||||||
|
# python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --trust-remote-code --disaggregation-mode prefill --tp 2 --disaggregation-ib-device mlx5_roce0,mlx5_roce1 --speculative-algorithm EAGLE --speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 3 --speculative-eagle-topk 4 --speculative-num-draft-tokens 16 --cuda-graph-max-bs 8 --host 127.0.0.1 --port 8100
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
import time
|
import time
|
||||||
@@ -17,12 +18,9 @@ from sglang.test.test_utils import (
|
|||||||
DEFAULT_URL_FOR_TEST,
|
DEFAULT_URL_FOR_TEST,
|
||||||
CustomTestCase,
|
CustomTestCase,
|
||||||
popen_launch_pd_server,
|
popen_launch_pd_server,
|
||||||
run_with_timeout,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# skip the test because we have different_tp test
|
|
||||||
@unittest.skip("skip the test because we have different_tp test")
|
|
||||||
class TestDisaggregationAccuracy(CustomTestCase):
|
class TestDisaggregationAccuracy(CustomTestCase):
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
@@ -172,6 +170,34 @@ class TestDisaggregationAccuracy(CustomTestCase):
|
|||||||
len(input_logprobs) > 0
|
len(input_logprobs) > 0
|
||||||
), f"input_logprobs should have at least one token, but got {len(input_logprobs)}"
|
), f"input_logprobs should have at least one token, but got {len(input_logprobs)}"
|
||||||
|
|
||||||
|
def test_structured_output(self):
    """Check that a ``json_schema``-constrained request returns schema-shaped JSON.

    Posts one request to ``{self.lb_url}/generate`` and verifies that the
    returned text parses as a JSON object carrying the required ``name``
    (string) and ``population`` (integer) fields.
    """
    required_fields = ["name", "population"]
    json_schema = json.dumps(
        {
            "type": "object",
            "properties": {
                "name": {"type": "string", "pattern": "^[\\w]+$"},
                "population": {"type": "integer"},
            },
            "required": required_fields,
        }
    )

    # JSON
    response = requests.post(
        f"{self.lb_url}/generate",
        json={
            "text": "Here is the information of the capital of France in the JSON format.\n",
            "sampling_params": {
                "temperature": 0,
                "max_new_tokens": 64,
                "json_schema": json_schema,
            },
        },
    )
    # Fail with the server's own message instead of an opaque KeyError below.
    assert (
        response.status_code == 200
    ), f"generate request failed with {response.status_code}: {response.text}"

    output = response.json()["text"]
    # ensure the output is a valid JSON
    parsed = json.loads(output)

    # Valid JSON alone is not enough: the result must follow the schema.
    assert isinstance(
        parsed, dict
    ), f"structured output should be a JSON object, but got {output!r}"
    for field in required_fields:
        assert (
            field in parsed
        ), f"structured output is missing required field {field!r}: {output!r}"
    assert isinstance(
        parsed["name"], str
    ), f"name should be a string, but got {parsed['name']!r}"
    # bool is a subclass of int in Python, so exclude it explicitly.
    assert isinstance(parsed["population"], int) and not isinstance(
        parsed["population"], bool
    ), f"population should be an integer, but got {parsed['population']!r}"
|
||||||
|
|
||||||
|
|
||||||
class TestDisaggregationMooncakeFailure(CustomTestCase):
|
class TestDisaggregationMooncakeFailure(CustomTestCase):
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
Reference in New Issue
Block a user