Improve structured outputs: fix race condition, server crash, metrics and style (#6188)
@@ -149,6 +149,7 @@ logger = logging.getLogger(__name__)
 # Test retract decode for debugging purposes
 TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
 RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
+GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
 
 
 @dataclass
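The new GRAMMAR_TIMEOUT knob caps how long a request may wait for its grammar to compile before it is aborted; it is read once at import time and defaults to 300 seconds. A minimal sketch of the pattern, runnable on its own (the 0.03 s poll interval comes from the scheduler hunk further down; the MAX_POLLS name is illustrative):

    import os

    # Float-valued knob read from the environment, with a default.
    GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))

    # The scheduler polls each pending grammar future every 0.03 s, so the
    # timeout is effectively a poll budget: 300 / 0.03 ≈ 10000 polls.
    MAX_POLLS = GRAMMAR_TIMEOUT / 0.03
    print(f"{GRAMMAR_TIMEOUT=} {MAX_POLLS=:.0f}")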
@@ -1024,9 +1025,11 @@ class Scheduler(
         elif req.sampling_params.structural_tag:
             key = ("structural_tag", req.sampling_params.structural_tag)
 
-            req.grammar = self.grammar_backend.get_cached_value(key)
-            if not req.grammar:
-                req.grammar = self.grammar_backend.get_future_value(key)
+            value, cache_hit = self.grammar_backend.get_cached_or_future_value(key)
+            req.grammar = value
+
+            if not cache_hit:
+                req.grammar_key = key
                 add_to_grammar_queue = True
 
         if add_to_grammar_queue:
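Note on the race condition: the removed sequence checked the cache and, on a miss, asked for a future in two separate backend calls, so two requests with the same key (or a concurrent cache insertion) could interleave between the calls. The new get_cached_or_future_value performs the lookup-or-create as one operation and also reports cache_hit, which the scheduler uses to decide whether to record grammar_key and queue the request. A minimal sketch of what such a combined method might look like, assuming a lock-protected dict and a thread pool (the class and helper names are illustrative; only the method name and the (value, cache_hit) return shape come from the diff):

    import threading
    from concurrent.futures import ThreadPoolExecutor

    class GrammarBackendSketch:
        """Illustrative cache, not the real sglang backend."""

        def __init__(self):
            self.cache = {}  # key -> compiled grammar
            self.lock = threading.Lock()
            self.executor = ThreadPoolExecutor(max_workers=1)

        def _compile(self, key):
            return f"compiled({key})"  # stand-in for slow grammar construction

        def get_cached_or_future_value(self, key):
            # The hit test and the future creation happen under one lock,
            # closing the window the removed two-call sequence left open.
            with self.lock:
                if key in self.cache:
                    return self.cache[key], True
                return self.executor.submit(self._compile, key), False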
@@ -1208,6 +1211,7 @@ class Scheduler(
         self.stats.cache_hit_rate = 0.0
         self.stats.gen_throughput = self.last_gen_throughput
         self.stats.num_queue_reqs = len(self.waiting_queue)
+        self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
         self.stats.spec_accept_length = spec_accept_length
         self.metrics_collector.log_stats(self.stats)
 
@@ -1255,6 +1259,7 @@ class Scheduler(
         self.stats.token_usage = num_used / self.max_total_num_tokens
         self.stats.gen_throughput = 0
         self.stats.num_queue_reqs = len(self.waiting_queue)
+        self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
         self.metrics_collector.log_stats(self.stats)
 
     def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
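Both logging paths, the per-batch stats in the first metrics hunk and the idle stats here, now report the grammar queue depth next to the waiting queue depth, so requests stalled in grammar compilation show up in monitoring instead of being invisible. A sketch of how such a stat could be exported, assuming a prometheus_client-style gauge (the metric name and helper are illustrative; the real collector lives elsewhere in sglang):

    from prometheus_client import Gauge

    # Illustrative gauge mirroring the new stats field.
    grammar_queue_depth = Gauge(
        "num_grammar_queue_reqs",
        "Number of requests waiting on grammar compilation",
    )

    def log_stats(stats):
        # Called from both the decode path and the idle path, as in the
        # two hunks above.
        grammar_queue_depth.set(stats.num_grammar_queue_reqs)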
@@ -1715,11 +1720,17 @@ class Scheduler(
         """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
 
         num_ready_reqs = 0
+        num_abort_reqs = 0
         for req in self.grammar_queue:
             try:
-                req.grammar = req.grammar.result(timeout=0.05)
+                req.grammar = req.grammar.result(timeout=0.03)
+                if req.grammar:
+                    self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                 num_ready_reqs += 1
             except futures._base.TimeoutError:
+                req.grammar_wait_ct += 1
+                if req.grammar_wait_ct > GRAMMAR_TIMEOUT / 0.03:
+                    num_abort_reqs = 1
                 break
 
         if self.server_args.enable_dp_attention:
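The loop polls each pending future for at most 0.03 s and stops at the first one that is not ready, preserving queue order. grammar_wait_ct counts those short waits per request, so a request is flagged for abort only after more than GRAMMAR_TIMEOUT / 0.03 ≈ 10000 consecutive misses, i.e. roughly 300 s of cumulative waiting at the head of the queue. A self-contained sketch of the same poll-and-count pattern (everything except the 0.03 s interval and the budget formula is illustrative):

    import time
    from concurrent import futures

    GRAMMAR_TIMEOUT = 300.0
    POLL_INTERVAL = 0.03  # same short timeout as the diff

    executor = futures.ThreadPoolExecutor(max_workers=1)
    fut = executor.submit(time.sleep, 0.1)  # stand-in for grammar compilation

    wait_ct = 0
    while True:
        try:
            fut.result(timeout=POLL_INTERVAL)
            break  # ready: the scheduler would move this request on
        except futures.TimeoutError:
            wait_ct += 1
            if wait_ct > GRAMMAR_TIMEOUT / POLL_INTERVAL:
                fut.cancel()  # abort path: cancelled and FINISH_ABORTed
                break
    print(f"polled {wait_ct} times before the future resolved")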
@@ -1731,14 +1742,28 @@ class Scheduler(
 
         if tp_size > 1:
             # Sync across TP ranks to make sure they have the same number of ready requests
-            tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
+            tensor = torch.tensor([num_ready_reqs, num_abort_reqs], dtype=torch.int32)
             torch.distributed.all_reduce(
                 tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
             )
-            num_ready_reqs_max = tensor.item()
+            num_ready_reqs_max, num_abort_reqs_max = tensor.tolist()
+
             for i in range(num_ready_reqs, num_ready_reqs_max):
-                self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result()
-            num_ready_reqs = num_ready_reqs_max
+                req = self.grammar_queue[i]
+                req.grammar = req.grammar.result()
+                if req.grammar:
+                    self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
+
+            for i in range(num_ready_reqs_max, num_ready_reqs_max + num_abort_reqs_max):
+                req = self.grammar_queue[i]
+                req.grammar.cancel()
+                req.grammar = None
+                error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
+                logger.error(error_msg)
+                req.finished_reason = FINISH_ABORT(
+                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
+                )
+            num_ready_reqs = num_ready_reqs_max + num_abort_reqs_max
 
         self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
         self.grammar_queue = self.grammar_queue[num_ready_reqs:]
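Packing both counters into one tensor and MAX-reducing keeps every TP rank in lockstep: all ranks agree on how many requests leave the grammar queue (ready plus aborted), catch up on futures they had not finished polling via the blocking .result(), and slice grammar_queue identically; a rank-dependent slice is the kind of desynchronization that crashes the server. A runnable single-process sketch of the two-counter all_reduce, assuming torch with the gloo backend (port, rank layout, and counts are made up):

    import torch
    import torch.distributed as dist

    # World size 1 just to make the collective runnable; real TP groups
    # span multiple ranks.
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1
    )

    num_ready_reqs, num_abort_reqs = 3, 1  # made-up local counts
    tensor = torch.tensor([num_ready_reqs, num_abort_reqs], dtype=torch.int32)

    # MAX across ranks: every rank adopts the largest ready/abort counts.
    dist.all_reduce(tensor, op=dist.ReduceOp.MAX)
    num_ready_reqs_max, num_abort_reqs_max = tensor.tolist()

    print(f"{num_ready_reqs_max=} {num_abort_reqs_max=}")
    dist.destroy_process_group()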