Fix grammar backend for tensor parallelism (#2020)

Lianmin Zheng
2024-11-13 01:49:45 -08:00
committed by GitHub
parent ba069a24d3
commit 54479d6f30
7 changed files with 250 additions and 328 deletions


@@ -244,7 +244,7 @@ class Scheduler:
             self.grammar_backend = OutlinesGrammarBackend(
                 self.tokenizer,
-                whitespace_patterns=server_args.constrained_json_whitespace_pattern,
+                whitespace_pattern=server_args.constrained_json_whitespace_pattern,
                 allow_jump_forward=not server_args.disable_jump_forward,
             )
         elif server_args.grammar_backend == "xgrammar":
@@ -467,21 +467,6 @@ class Scheduler:
         # By default, only return the logprobs for output tokens
         req.logprob_start_len = len(recv_req.input_ids) - 1

-        # Init grammar cache for this request
-        if (
-            req.sampling_params.json_schema is not None
-            or req.sampling_params.regex is not None
-        ):
-            assert self.grammar_backend is not None
-            if req.sampling_params.json_schema is not None:
-                req.grammar = self.grammar_backend.query(
-                    ("json", req.sampling_params.json_schema),
-                )
-            elif req.sampling_params.regex is not None:
-                req.grammar = self.grammar_backend.query(
-                    ("regex", req.sampling_params.regex)
-                )
-
         # Truncate prompts that are too long
         if len(req.origin_input_ids) > self.max_req_input_len:
             logger.warning(
@@ -499,7 +484,24 @@ class Scheduler:
             self.max_req_len - len(req.origin_input_ids) - 1,
         )

-        if req.grammar is not None:
+        # Init grammar cache for this request
+        add_to_grammar_queue = False
+        if (
+            req.sampling_params.json_schema is not None
+            or req.sampling_params.regex is not None
+        ):
+            assert self.grammar_backend is not None
+            if req.sampling_params.json_schema is not None:
+                key = ("json", req.sampling_params.json_schema)
+            elif req.sampling_params.regex is not None:
+                key = ("regex", req.sampling_params.regex)
+
+            req.grammar = self.grammar_backend.get_cached_value(key)
+            if not req.grammar:
+                req.grammar = self.grammar_backend.get_future_value(key)
+                add_to_grammar_queue = True
+
+        if add_to_grammar_queue:
             self.grammar_queue.append(req)
         else:
             self.waiting_queue.append(req)
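
The new request path above assumes a grammar backend that exposes a synchronous cache lookup (`get_cached_value`) plus an asynchronous fallback (`get_future_value`) that returns a future while compilation runs in the background: `req.grammar` ends up as either a finished grammar (cache hit, straight to `waiting_queue`) or a `Future` (cache miss, parked in `grammar_queue`). A minimal sketch of that contract, assuming a thread-pool implementation; the class name, placeholder compilation, and threading details are illustrative, not the actual sglang backend:

```python
import concurrent.futures
from typing import Optional, Tuple

class CachedGrammarBackendSketch:
    """Illustrative only: cache compiled grammars, compile misses in a worker thread."""

    def __init__(self):
        self.cache = {}  # (kind, value) -> compiled grammar
        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)

    def get_cached_value(self, key: Tuple[str, str]) -> Optional[object]:
        # Cache hit: return the finished grammar immediately. Miss: None.
        return self.cache.get(key)

    def get_future_value(self, key: Tuple[str, str]) -> concurrent.futures.Future:
        # Miss path: compile in the background and hand the scheduler a
        # Future to poll with result(timeout=...).
        def compile_and_store():
            kind, value = key  # e.g. ("json", schema) or ("regex", pattern)
            grammar = ("compiled", kind, value)  # stand-in for real compilation
            self.cache[key] = grammar
            return grammar

        return self.executor.submit(compile_and_store)
```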
@@ -650,14 +652,7 @@ class Scheduler:
     def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
         # Check if the grammar is ready in the grammar queue
         if self.grammar_queue:
-            new_grammar_queue = []
-            for req in self.grammar_queue:
-                try:
-                    req.grammar = req.grammar.result(timeout=0.05)
-                    self.waiting_queue.append(req)
-                except futures._base.TimeoutError:
-                    new_grammar_queue.append(req)
-            self.grammar_queue = new_grammar_queue
+            self.move_ready_grammar_requests()

         # Handle the cases where prefill is not allowed
         if (
@@ -1145,6 +1140,30 @@ class Scheduler:
             )
         )

+    def move_ready_grammar_requests(self):
+        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
+        num_ready_reqs = 0
+        for req in self.grammar_queue:
+            try:
+                req.grammar = req.grammar.result(timeout=0.05)
+                num_ready_reqs += 1
+            except futures._base.TimeoutError:
+                break
+
+        if self.tp_size > 1:
+            # Sync across TP ranks to make sure they have the same number of ready requests
+            tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
+            torch.distributed.all_reduce(
+                tensor, op=torch.distributed.ReduceOp.MAX, group=self.tp_cpu_group
+            )
+            num_ready_reqs_max = tensor.item()
+            for i in range(num_ready_reqs, num_ready_reqs_max):
+                self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result()
+            num_ready_reqs = num_ready_reqs_max
+
+        self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs])
+        self.grammar_queue = self.grammar_queue[num_ready_reqs:]
+
     def flush_cache(self):
         """Flush the memory pool and cache."""
         if len(self.waiting_queue) == 0 and (
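
The tensor-parallelism fix named in the commit title is the `all_reduce` above: each TP rank may see a different number of futures finish within the 0.05 s polls, and if every rank moved a different prefix of `grammar_queue` into `waiting_queue`, the ranks would assemble different batches and desynchronize. Reducing the local counts with `MAX` makes every rank adopt the largest count and block (`result()` with no timeout) on any futures it has not yet seen complete; `MIN` would avoid blocking but would let one slow rank hold back requests that are ready everywhere else. A standalone sketch of the consensus step (the function name is illustrative; the commit uses `self.tp_cpu_group`, a CPU process group, which matches the CPU-resident scalar tensor):

```python
import torch
import torch.distributed as dist

def sync_ready_count(num_ready_local: int, cpu_group: dist.ProcessGroup) -> int:
    # Every rank proposes how many grammar requests it saw become ready...
    t = torch.tensor(num_ready_local, dtype=torch.int32)
    # ...and all ranks adopt the maximum, so they dequeue the same prefix.
    dist.all_reduce(t, op=dist.ReduceOp.MAX, group=cpu_group)
    return int(t.item())
```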
@@ -1152,9 +1171,8 @@
         ):
             self.tree_cache.reset()
             self.tree_cache_metrics = {"total": 0, "hit": 0}
-            if self.grammar_backend is not None:
+            if self.grammar_backend:
                 self.grammar_backend.reset()
-                # TODO(dark): reset the bnf cache
             self.req_to_token_pool.clear()
             self.token_to_kv_pool.clear()
             torch.cuda.empty_cache()
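
For reference, the polling loop in `move_ready_grammar_requests` relies on `Future.result(timeout=...)` raising a timeout error while compilation is still running (the diff spells it `futures._base.TimeoutError`; `concurrent.futures.TimeoutError` is the public alias). A toy illustration, independent of sglang, of how the loop counts ready futures and stops at the first straggler, preserving FIFO order in the queue:

```python
import concurrent.futures
import time

executor = concurrent.futures.ThreadPoolExecutor()
# Simulate grammar compilations of varying cost; the third one is slow.
futs = [executor.submit(time.sleep, d) for d in (0.01, 0.01, 1.0, 0.01)]
time.sleep(0.1)  # let the quick ones finish

num_ready = 0
for fut in futs:
    try:
        fut.result(timeout=0.05)  # brief poll, as in the scheduler
        num_ready += 1
    except concurrent.futures.TimeoutError:
        break  # first unfinished future; everything after it stays queued

print(num_ready)  # expected: 2, since the slow job blocks later ready ones
```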