Fix the overlap for xgrammar (#2377)
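
This change makes grammar-constrained decoding work with the overlap scheduler, most importantly for the xgrammar backend. The main pieces: grammar objects now track a `finished` flag so the sampling code can skip mask filling for requests that are already done; the device transfer of the vocab mask is split out of `apply_vocab_mask` into a separate `move_vocab_mask` step (a non-blocking copy for xgrammar, a no-op for Outlines); `SamplingBatchInfo.update_regex_vocab_mask` drops its try/except around `fill_vocab_mask` in favor of the explicit `finished` check plus a single mask move; the scheduler and the overlap TP worker cache their device stream once instead of re-resolving it per batch; and the JSON-constrained tests are refactored around a shared `setup_class` helper, with a dedicated jump-forward test class that disables the overlap schedule.
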
@@ -106,4 +106,4 @@ def import_new_model_classes():
 ModelRegistry.models.update(import_new_model_classes())
 
 launch_server(server_args)
 ```
@@ -42,6 +42,7 @@ class OutlinesGrammar(BaseGrammarObject):
         self.guide = guide
         self.jump_forward_map = jump_forward_map
         self.state = 0
+        self.finished = False
 
     def accept_token(self, token: int):
         self.state = self.guide.get_next_state(self.state, token)
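
Both grammar backends now initialize `self.finished = False`; the scheduler flips it to `True` once the owning request finishes (see the scheduler hunks below). A minimal sketch of the grammar-object contract these hunks rely on — the real `BaseGrammarObject` is not shown in this diff, so everything here beyond the attribute and method names used above is an assumption:

```python
import torch


class BaseGrammarObject:
    """Sketch of the shared contract; the actual base class is not in this diff."""

    def __init__(self) -> None:
        # Set by the scheduler when the owning request finishes, so the
        # sampler can skip fill_vocab_mask() for this grammar afterwards.
        self.finished = False

    @staticmethod
    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
        # Default: the mask already lives on the right device.
        return vocab_mask
```
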
@@ -84,6 +85,10 @@ class OutlinesGrammar(BaseGrammarObject):
     ) -> torch.Tensor:
         return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)
 
+    @staticmethod
+    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
+        return vocab_mask
+
     def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
         tokens = torch.tensor(
             self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
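
For Outlines, `allocate_vocab_mask` already creates the mask on the target device (`torch.zeros(..., device=device)` above), so its new `move_vocab_mask` is deliberately a no-op; the non-trivial transfer only exists in the xgrammar backend below, which is why its `move_vocab_mask` performs an actual copy.
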
@@ -45,6 +45,7 @@ class XGrammarGrammar(BaseGrammarObject):
         self.matcher = matcher
         self.vocab_size = vocab_size
         self.ctx = ctx
+        self.finished = False
 
     def accept_token(self, token: int):
         assert self.matcher.accept_token(token)
@@ -85,12 +86,11 @@ class XGrammarGrammar(BaseGrammarObject):
         self.matcher.fill_next_token_bitmask(vocab_mask, idx)
 
     @staticmethod
-    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
-        if vocab_mask.device.type != logits.device.type:
-            # vocab_mask must then be on the same device as logits
-            # when applying the token bitmask, so we check and move if needed
-            vocab_mask = vocab_mask.to(logits.device)
+    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
+        return vocab_mask.to(device, non_blocking=True)
 
+    @staticmethod
+    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
         apply_token_bitmask_inplace(logits, vocab_mask)
 
     def copy(self):
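
Previously, xgrammar's `apply_vocab_mask` checked the mask's device against the logits and moved it on the fly, every time it was applied. Splitting the move into a dedicated `move_vocab_mask` lets the batch-level code fill the bitmask for all active grammars first and then pay for one explicit, non-blocking host-to-device copy (see the `SamplingBatchInfo` hunk further down). A small runnable illustration of that transfer step — the pinned-memory detail is only about when `non_blocking=True` actually overlaps with compute, not something this diff shows:

```python
import torch

if torch.cuda.is_available():
    # A host-side bitmask like the one xgrammar fills (sizes are made up).
    # non_blocking=True only overlaps with compute when the source is pinned.
    bitmask = torch.zeros(8, 32_000, dtype=torch.bool).pin_memory()
    moved = bitmask.to("cuda", non_blocking=True)  # what move_vocab_mask does
    torch.cuda.current_stream().synchronize()      # copy is guaranteed done here
```
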
@@ -114,9 +114,6 @@ class Scheduler:
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
         self.enable_metrics = server_args.enable_metrics
 
-        # Session info
-        self.sessions = {}
-
         # Init inter-process communication
         context = zmq.Context(2)
 
@@ -259,6 +256,10 @@ class Scheduler:
         self.num_generated_tokens = 0
         self.last_decode_stats_tic = time.time()
         self.stream_interval = server_args.stream_interval
+        self.current_stream = torch.get_device_module(self.device).current_stream()
 
+        # Session info
+        self.sessions = {}
+
         # Init chunked prefill
         self.chunked_prefill_size = server_args.chunked_prefill_size
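
Two things happen in this hunk: the scheduler resolves its device stream once (`self.current_stream`) instead of calling `torch.get_device_module(self.device).current_stream()` at every synchronization point (the hunks at -993, -1049, and -1127 below switch the call sites over), and the `# Session info` state deleted at -114 above is re-created here, after the decode-statistics fields.
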
@@ -356,6 +357,7 @@ class Scheduler:
         )
 
     def watchdog_thread(self):
+        """A watch dog thread that will try to kill the server itself if one batch takes too long."""
        self.watchdog_last_forward_ct = 0
        self.watchdog_last_time = time.time()
 
@@ -433,61 +435,6 @@ class Scheduler:
 
         self.last_batch = batch
 
-    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
-        # Check if other DP workers have running batches
-        if local_batch is None:
-            num_tokens = 0
-        elif local_batch.forward_mode.is_decode():
-            num_tokens = local_batch.batch_size()
-        else:
-            num_tokens = local_batch.extend_num_tokens
-
-        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
-        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
-        torch.distributed.all_gather_into_tensor(
-            global_num_tokens,
-            local_num_tokens,
-            group=self.tp_cpu_group,
-        )
-
-        if local_batch is None and global_num_tokens.max().item() > 0:
-            local_batch = self.get_idle_batch()
-
-        if local_batch is not None:
-            local_batch.global_num_tokens = global_num_tokens.tolist()
-
-            # Check forward mode for cuda graph
-            if not self.server_args.disable_cuda_graph:
-                forward_mode_state = torch.tensor(
-                    (
-                        1
-                        if local_batch.forward_mode.is_decode()
-                        or local_batch.forward_mode.is_idle()
-                        else 0
-                    ),
-                    dtype=torch.int32,
-                )
-                torch.distributed.all_reduce(
-                    forward_mode_state,
-                    op=torch.distributed.ReduceOp.MIN,
-                    group=self.tp_cpu_group,
-                )
-                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
-
-        return local_batch
-
-    def get_idle_batch(self):
-        idle_batch = ScheduleBatch.init_new(
-            [],
-            self.req_to_token_pool,
-            self.token_to_kv_pool,
-            self.tree_cache,
-            self.model_config,
-            self.enable_overlap,
-        )
-        idle_batch.prepare_for_idle()
-        return idle_batch
-
     def recv_requests(self):
         if self.tp_rank == 0 or self.server_args.enable_dp_attention:
             recv_reqs = []
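
This large deletion is a pure relocation: `prepare_dp_attn_batch` and `get_idle_batch` are removed here and re-added verbatim later in the class (the +1277 hunk below), so the event-loop section of `Scheduler` no longer has the DP-attention helpers wedged between the `self.last_batch` bookkeeping and `recv_requests`.
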
@@ -993,7 +940,7 @@ class Scheduler:
             self.process_batch_result_prefill(batch, result)
         elif batch.forward_mode.is_dummy_first():
             batch.next_batch_sampling_info.update_regex_vocab_mask()
-            torch.get_device_module(self.device).current_stream().synchronize()
+            self.current_stream.synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()
 
     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
@@ -1049,13 +996,14 @@ class Scheduler:
 
                     if req.grammar is not None:
                         req.grammar.accept_token(next_token_id)
+                        req.grammar.finished = req.finished()
                 else:
                     # being chunked reqs' prefill is not finished
                     req.is_being_chunked -= 1
 
             if batch.next_batch_sampling_info:
                 batch.next_batch_sampling_info.update_regex_vocab_mask()
-                torch.get_device_module(self.device).current_stream().synchronize()
+                self.current_stream.synchronize()
                 batch.next_batch_sampling_info.sampling_info_done.set()
 
         else:  # embedding or reward model
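
Alongside the cached-stream switch, the prefill result path now records `req.grammar.finished = req.finished()` immediately after accepting the sampled token. With the overlap scheduler, the vocab mask for the next batch is prepared while the current batch may still be finishing; the flag lets `update_regex_vocab_mask` skip grammars whose request is already done — presumably the case the old try/except around `fill_vocab_mask` (removed in the -158 hunk below) was papering over.
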
@@ -1127,10 +1075,11 @@ class Scheduler:
 
             if req.grammar is not None:
                 req.grammar.accept_token(next_token_id)
+                req.grammar.finished = req.finished()
 
         if batch.next_batch_sampling_info:
             batch.next_batch_sampling_info.update_regex_vocab_mask()
-            torch.get_device_module(self.device).current_stream().synchronize()
+            self.current_stream.synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()
 
         self.stream_output(batch.reqs)
@@ -1328,6 +1277,61 @@ class Scheduler:
             )
         )
 
+    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
+        # Check if other DP workers have running batches
+        if local_batch is None:
+            num_tokens = 0
+        elif local_batch.forward_mode.is_decode():
+            num_tokens = local_batch.batch_size()
+        else:
+            num_tokens = local_batch.extend_num_tokens
+
+        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
+        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
+        torch.distributed.all_gather_into_tensor(
+            global_num_tokens,
+            local_num_tokens,
+            group=self.tp_cpu_group,
+        )
+
+        if local_batch is None and global_num_tokens.max().item() > 0:
+            local_batch = self.get_idle_batch()
+
+        if local_batch is not None:
+            local_batch.global_num_tokens = global_num_tokens.tolist()
+
+            # Check forward mode for cuda graph
+            if not self.server_args.disable_cuda_graph:
+                forward_mode_state = torch.tensor(
+                    (
+                        1
+                        if local_batch.forward_mode.is_decode()
+                        or local_batch.forward_mode.is_idle()
+                        else 0
+                    ),
+                    dtype=torch.int32,
+                )
+                torch.distributed.all_reduce(
+                    forward_mode_state,
+                    op=torch.distributed.ReduceOp.MIN,
+                    group=self.tp_cpu_group,
+                )
+                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
+
+        return local_batch
+
+    def get_idle_batch(self):
+        idle_batch = ScheduleBatch.init_new(
+            [],
+            self.req_to_token_pool,
+            self.token_to_kv_pool,
+            self.tree_cache,
+            self.model_config,
+            self.enable_overlap,
+        )
+        idle_batch.prepare_for_idle()
+        return idle_batch
+
     def move_ready_grammar_requests(self):
         """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
         num_ready_reqs = 0
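
These are the two methods deleted at -433 above, re-inserted without modification; only their position within `Scheduler` changes.
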
@@ -1469,10 +1473,6 @@ def run_scheduler_process(
     dp_rank: Optional[int],
     pipe_writer,
 ):
-    # set cpu affinity to this gpu process
-    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
-        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
-
     # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
     if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
         dp_rank = int(os.environ["SGLANG_DP_RANK"])
@@ -1482,6 +1482,10 @@ def run_scheduler_process(
     else:
         configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")
 
+    # set cpu affinity to this gpu process
+    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
+        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
+
     suppress_other_loggers()
     parent_process = psutil.Process().parent()
 
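
Together with the previous hunk, this moves the CPU-affinity setup in `run_scheduler_process` from the top of the function to after `dp_rank` resolution and logger configuration; the `SGLANG_SET_CPU_AFFINITY` guard and the `set_gpu_proc_affinity(...)` call itself are unchanged.
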
@@ -80,6 +80,7 @@ class TpModelWorkerClient:
         )
         self.forward_thread.start()
         self.parent_process = psutil.Process().parent()
+        self.scheduler_stream = torch.get_device_module(self.device).current_stream()
 
     def get_worker_info(self):
         return self.worker.get_worker_info()
@@ -191,7 +192,7 @@ class TpModelWorkerClient:
         )
 
         # A cuda stream sync here to avoid the cuda illegal memory access error.
-        torch.get_device_module(self.device).current_stream().synchronize()
+        self.scheduler_stream.synchronize()
 
         # Push a new batch to the queue
         self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
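
`TpModelWorkerClient` gets the same stream-caching treatment as the scheduler: the stream used for the sync guarded by the "cuda illegal memory access" comment is captured once in `__init__` as `self.scheduler_stream` and reused at the call site.
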
@@ -158,22 +158,23 @@ class SamplingBatchInfo:
             return
 
         # find a grammar from the list
-        grammar = next(grammar for grammar in self.grammars if grammar)
+        first_grammar = next(grammar for grammar in self.grammars if grammar)
 
         # maybe we can reuse the existing mask?
-        self.vocab_mask = grammar.allocate_vocab_mask(
+        self.vocab_mask = first_grammar.allocate_vocab_mask(
             vocab_size=self.vocab_size,
             batch_size=len(self.temperatures),
             device=self.device,
         )
-        self.apply_mask = type(grammar).apply_vocab_mask  # force to use static method
+        self.apply_mask = first_grammar.apply_vocab_mask  # force to use static method
 
+        # Apply the mask
         for i, grammar in enumerate(self.grammars):
-            if grammar is not None:
-                try:
-                    grammar.fill_vocab_mask(self.vocab_mask, i)
-                except RuntimeError:
-                    continue
+            if grammar and not grammar.finished:
+                grammar.fill_vocab_mask(self.vocab_mask, i)
+
+        # Move the mask to the device if needed
+        self.vocab_mask = first_grammar.move_vocab_mask(self.vocab_mask, self.device)
 
     def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
         self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
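
This hunk is the core behavioral fix. The old loop filled the mask for every non-None grammar and swallowed any `RuntimeError` (for xgrammar, asking a finished matcher for its next-token bitmask can fail), and it left the device transfer to each `apply_vocab_mask` call. The new code skips finished grammars explicitly and moves the mask exactly once via the backend's `move_vocab_mask`. A self-contained toy mirroring the new fill-then-move-then-apply orchestration — `DummyGrammar` and its mask semantics (`True` = allowed) are invented for the demo and do not match any particular backend:

```python
import torch


class DummyGrammar:
    """Illustrative stand-in for a backend grammar object."""

    finished = False

    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
        vocab_mask[idx, :5] = True  # pretend only the first 5 tokens are legal

    @staticmethod
    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
        return vocab_mask.to(device, non_blocking=True)

    @staticmethod
    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
        logits.masked_fill_(~vocab_mask, float("-inf"))


device = "cuda" if torch.cuda.is_available() else "cpu"
grammars = [DummyGrammar(), None]  # one constrained row, one unconstrained row

# 1) Fill on the allocation device, skipping finished or absent grammars.
vocab_mask = torch.zeros(2, 10, dtype=torch.bool)
for i, grammar in enumerate(grammars):
    if grammar and not grammar.finished:
        grammar.fill_vocab_mask(vocab_mask, i)
vocab_mask[1, :] = True  # rows without a grammar allow every token

# 2) Move once.  3) Apply in place on the target device.
vocab_mask = DummyGrammar.move_vocab_mask(vocab_mask, device)
logits = torch.randn(2, 10, device=device)
DummyGrammar.apply_vocab_mask(logits, vocab_mask)
assert torch.isinf(logits[0, 5:]).all()  # disallowed tokens can't be sampled
```
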
@@ -1,5 +1,6 @@
 """
-python3 -m unittest test_json_constrained.TestJSONConstrained.test_json_generate
+python3 -m unittest test_json_constrained.TestJSONConstrainedOutlinesBackend.test_json_generate
+python3 -m unittest test_json_constrained.TestJSONConstrainedXGrammarBackend.test_json_generate
 """
 
 import json
@@ -11,38 +12,50 @@ import requests
 
 from sglang.srt.utils import kill_process_tree
 from sglang.test.test_utils import (
-    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     popen_launch_server,
 )
 
 
+def setup_class(cls, backend: str, disable_overlap: bool):
+    cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+    cls.base_url = DEFAULT_URL_FOR_TEST
+    cls.json_schema = json.dumps(
+        {
+            "type": "object",
+            "properties": {
+                "name": {"type": "string", "pattern": "^[\\w]+$"},
+                "population": {"type": "integer"},
+            },
+            "required": ["name", "population"],
+        }
+    )
+
+    other_args = [
+        "--max-running-requests",
+        "10",
+        "--grammar-backend",
+        backend,
+    ]
+
+    if disable_overlap:
+        other_args += ["--disable-overlap-schedule"]
+
+    cls.process = popen_launch_server(
+        cls.model,
+        cls.base_url,
+        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+        other_args=other_args,
+    )
+
+
 class TestJSONConstrainedOutlinesBackend(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
-        cls.base_url = DEFAULT_URL_FOR_TEST
-        cls.json_schema = json.dumps(
-            {
-                "type": "object",
-                "properties": {
-                    "name": {"type": "string", "pattern": "^[\\w]+$"},
-                    "population": {"type": "integer"},
-                },
-                "required": ["name", "population"],
-            }
-        )
-        cls.process = popen_launch_server(
-            cls.model,
-            cls.base_url,
-            timeout=300,
-            other_args=[
-                "--max-running-requests",
-                "10",
-                "--grammar-backend",
-                "outlines",
-            ],
-        )
+        setup_class(cls, backend="outlines", disable_overlap=False)
+        cls.check_jump_forward = False
 
     @classmethod
     def tearDownClass(cls):
@@ -83,11 +96,13 @@ class TestJSONConstrainedOutlinesBackend(unittest.TestCase):
         self.assertIsInstance(js_obj["population"], int)
 
         # Make sure jump forward is triggered
-        # NOTE: This is skipped because overlap scheduler does not support jump forward
-        # self.assertGreater(
-        #     ret["meta_info"]["completion_tokens"],
-        #     ret["meta_info"]["completion_tokens_wo_jump_forward"],
-        # )
+        # NOTE: The overlap scheduler does not support jump forward so we only do this test
+        # when --disable-overlap-schedule is set.
+        if self.check_jump_forward:
+            self.assertGreater(
+                ret["meta_info"]["completion_tokens"],
+                ret["meta_info"]["completion_tokens_wo_jump_forward"],
+            )
 
     def test_json_generate(self):
         self.run_decode(json_schema=self.json_schema)
@@ -126,32 +141,18 @@ class TestJSONConstrainedOutlinesBackend(unittest.TestCase):
             list(executor.map(self.run_decode, json_schemas))
 
 
+class TestJumpForwardOutlinesBackend(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        setup_class(cls, backend="outlines", disable_overlap=True)
+        cls.check_jump_forward = True
+
+
 class TestJSONConstrainedXGrammarBackend(TestJSONConstrainedOutlinesBackend):
     @classmethod
     def setUpClass(cls):
-        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
-        cls.base_url = DEFAULT_URL_FOR_TEST
-        cls.json_schema = json.dumps(
-            {
-                "type": "object",
-                "properties": {
-                    "name": {"type": "string"},
-                    "population": {"type": "integer"},
-                },
-                "required": ["name", "population"],
-            }
-        )
-        cls.process = popen_launch_server(
-            cls.model,
-            cls.base_url,
-            timeout=300,
-            other_args=[
-                "--max-running-requests",
-                "10",
-                "--grammar-backend",
-                "xgrammar",
-            ],
-        )
+        setup_class(cls, backend="xgrammar", disable_overlap=False)
+        cls.check_jump_forward = False
 
 
 if __name__ == "__main__":
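
The test refactor replaces the duplicated `setUpClass` bodies with the shared `setup_class(cls, backend=..., disable_overlap=...)` helper, switches to the small default test model and the standard launch timeout, and turns the previously commented-out jump-forward assertion into a real check gated by `cls.check_jump_forward`. Only the new `TestJumpForwardOutlinesBackend` sets that flag, since it is the one class launched with `--disable-overlap-schedule`; the xgrammar variant reuses the Outlines test cases against a server started with `--grammar-backend xgrammar` and can be run on its own via the `python3 -m unittest test_json_constrained.TestJSONConstrainedXGrammarBackend.test_json_generate` command from the module docstring.
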