Fix grammar backend for tensor parallelism (#2020)
@@ -244,7 +244,7 @@ class Scheduler:

             self.grammar_backend = OutlinesGrammarBackend(
                 self.tokenizer,
-                whitespace_patterns=server_args.constrained_json_whitespace_pattern,
+                whitespace_pattern=server_args.constrained_json_whitespace_pattern,
                 allow_jump_forward=not server_args.disable_jump_forward,
             )
         elif server_args.grammar_backend == "xgrammar":
@@ -467,21 +467,6 @@ class Scheduler:
         # By default, only return the logprobs for output tokens
         req.logprob_start_len = len(recv_req.input_ids) - 1

-        # Init grammar cache for this request
-        if (
-            req.sampling_params.json_schema is not None
-            or req.sampling_params.regex is not None
-        ):
-            assert self.grammar_backend is not None
-            if req.sampling_params.json_schema is not None:
-                req.grammar = self.grammar_backend.query(
-                    ("json", req.sampling_params.json_schema),
-                )
-            elif req.sampling_params.regex is not None:
-                req.grammar = self.grammar_backend.query(
-                    ("regex", req.sampling_params.regex)
-                )
-
         # Truncate prompts that are too long
         if len(req.origin_input_ids) > self.max_req_input_len:
             logger.warning(
@@ -499,7 +484,24 @@ class Scheduler:
                 self.max_req_len - len(req.origin_input_ids) - 1,
             )

-        if req.grammar is not None:
+        # Init grammar cache for this request
+        add_to_grammar_queue = False
+        if (
+            req.sampling_params.json_schema is not None
+            or req.sampling_params.regex is not None
+        ):
+            assert self.grammar_backend is not None
+            if req.sampling_params.json_schema is not None:
+                key = ("json", req.sampling_params.json_schema)
+            elif req.sampling_params.regex is not None:
+                key = ("regex", req.sampling_params.regex)
+
+            req.grammar = self.grammar_backend.get_cached_value(key)
+            if not req.grammar:
+                req.grammar = self.grammar_backend.get_future_value(key)
+                add_to_grammar_queue = True
+
+        if add_to_grammar_queue:
             self.grammar_queue.append(req)
         else:
             self.waiting_queue.append(req)
@@ -650,14 +652,7 @@ class Scheduler:
     def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
         # Check if the grammar is ready in the grammar queue
         if self.grammar_queue:
-            new_grammar_queue = []
-            for req in self.grammar_queue:
-                try:
-                    req.grammar = req.grammar.result(timeout=0.05)
-                    self.waiting_queue.append(req)
-                except futures._base.TimeoutError:
-                    new_grammar_queue.append(req)
-            self.grammar_queue = new_grammar_queue
+            self.move_ready_grammar_requests()

         # Handle the cases where prefill is not allowed
         if (
@@ -1145,6 +1140,30 @@ class Scheduler:
                 )
             )

+    def move_ready_grammar_requests(self):
+        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
+        num_ready_reqs = 0
+        for req in self.grammar_queue:
+            try:
+                req.grammar = req.grammar.result(timeout=0.05)
+                num_ready_reqs += 1
+            except futures._base.TimeoutError:
+                break
+
+        if self.tp_size > 1:
+            # Sync across TP ranks to make sure they have the same number of ready requests
+            tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
+            torch.distributed.all_reduce(
+                tensor, op=torch.distributed.ReduceOp.MAX, group=self.tp_cpu_group
+            )
+            num_ready_reqs_max = tensor.item()
+            for i in range(num_ready_reqs, num_ready_reqs_max):
+                self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result()
+            num_ready_reqs = num_ready_reqs_max
+
+        self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs])
+        self.grammar_queue = self.grammar_queue[num_ready_reqs:]
+
     def flush_cache(self):
         """Flush the memory pool and cache."""
         if len(self.waiting_queue) == 0 and (
@@ -1152,9 +1171,8 @@ class Scheduler:
         ):
             self.tree_cache.reset()
             self.tree_cache_metrics = {"total": 0, "hit": 0}
-            if self.grammar_backend is not None:
+            if self.grammar_backend:
                 self.grammar_backend.reset()
-                # TODO(dark): reset the bnf cache
             self.req_to_token_pool.clear()
             self.token_to_kv_pool.clear()
             torch.cuda.empty_cache()
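
The core of the fix is the new move_ready_grammar_requests method: each tensor-parallel rank polls its grammar futures, then all-reduces the ready count with MAX over the CPU process group so that every rank moves the same requests into the waiting queue. Below is a minimal standalone sketch of that synchronization pattern, not part of the commit; the GrammarReq class, move_ready function, and cpu_group argument are illustrative names under assumed conditions, not SGLang APIs.

# Standalone sketch (illustrative only) of the TP-rank sync used by the fix.
from concurrent import futures
from dataclasses import dataclass
from typing import List, Optional

import torch
import torch.distributed as dist


@dataclass
class GrammarReq:
    # Stand-in for a request whose grammar object is compiled in a background thread.
    grammar: futures.Future
    ready_grammar: Optional[object] = None


def move_ready(grammar_queue: List[GrammarReq], waiting_queue: List[GrammarReq],
               cpu_group=None) -> None:
    # Count the futures that finish within a short timeout, stopping at the first slow one.
    num_ready = 0
    for req in grammar_queue:
        try:
            req.ready_grammar = req.grammar.result(timeout=0.05)
            num_ready += 1
        except futures.TimeoutError:
            break

    # With tensor parallelism, every rank must dequeue the same requests, so the counts
    # are synchronized with an all-reduce (MAX) over a CPU process group; ranks that
    # counted fewer block until their remaining futures are done.
    if dist.is_initialized() and dist.get_world_size(cpu_group) > 1:
        t = torch.tensor(num_ready, dtype=torch.int32)
        dist.all_reduce(t, op=dist.ReduceOp.MAX, group=cpu_group)
        target = int(t.item())
        for i in range(num_ready, target):
            grammar_queue[i].ready_grammar = grammar_queue[i].grammar.result()
        num_ready = target

    waiting_queue.extend(grammar_queue[:num_ready])
    del grammar_queue[:num_ready]


if __name__ == "__main__":
    # Single-process usage example: no process group is initialized, so the sync is skipped.
    pool = futures.ThreadPoolExecutor()
    grammar_queue = [GrammarReq(grammar=pool.submit(str.upper, "grammar")) for _ in range(3)]
    waiting_queue: List[GrammarReq] = []
    move_ready(grammar_queue, waiting_queue)
    print(f"moved {len(waiting_queue)} requests; {len(grammar_queue)} still compiling")
    pool.shutdown()

Taking the MAX rather than the MIN makes slower ranks wait on their remaining futures instead of building different batches, which matches the in-diff comment about keeping the number of ready requests identical across TP ranks.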