Correctly abort the failed grammar requests & Improve the handling of abort (#6803)

This commit is contained in:
Lianmin Zheng
2025-06-01 19:00:07 -07:00
committed by GitHub
parent 6a47b73024
commit 20fd53b8f6
16 changed files with 199 additions and 142 deletions

View File

@@ -35,7 +35,10 @@ from torch.distributed import barrier
from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
from sglang.srt.constrained.base_grammar_backend import (
INVALID_GRAMMAR_OBJ,
create_grammar_backend,
)
from sglang.srt.disaggregation.decode import (
DecodePreallocQueue,
DecodeTransferQueue,
@@ -949,12 +952,12 @@ class Scheduler(
if self.disaggregation_mode != DisaggregationMode.NULL:
# Invalid request for disaggregated mode
if recv_req.bootstrap_room is None:
error_message = (
error_msg = (
f"Invalid request: Disaggregated request received without "
f"boostrap room id. {req.rid=}"
)
logger.error(error_message)
prepare_abort(req, error_message)
logger.error(error_msg)
prepare_abort(req, error_msg)
self.stream_output([req], req.return_logprob)
return
@@ -985,29 +988,23 @@ class Scheduler(
req.extend_image_inputs(image_inputs)
if len(req.origin_input_ids) >= self.max_req_input_len:
error_msg = (
"Multimodal prompt is too long after expanding multimodal tokens. "
f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
)
logger.error(error_msg)
req.origin_input_ids = [0]
req.multimodal_inputs = None
req.sampling_params.max_new_tokens = 0
req.finished_reason = FINISH_ABORT(
error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
req.set_finish_with_abort(
error_msg=(
"Multimodal prompt is too long after expanding multimodal tokens. "
f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
)
)
self._add_request_to_queue(req)
return
# Validate prompts length
# Validate prompt length
error_msg = validate_input_length(
req,
self.max_req_input_len,
self.server_args.allow_auto_truncate,
)
if error_msg:
req.origin_input_ids = [0]
req.sampling_params.max_new_tokens = 0
req.set_finish_with_abort(error_msg)
self._add_request_to_queue(req)
return
@@ -1019,12 +1016,9 @@ class Scheduler(
req.logprob_start_len = recv_req.logprob_start_len
if req.logprob_start_len >= len(req.origin_input_ids):
req.finished_reason = FINISH_ABORT(
f"logprob_start_len, ({req.logprob_start_len}) is higher than the number of input tokens ({len(req.origin_input_ids)}). Request with a lower logprob_start_len.",
HTTPStatus.BAD_REQUEST,
"BadRequestError",
)
error_msg = f"{req.logprob_start_len=} is higher than the number of input tokens {len(req.origin_input_ids)=}. Please use a smaller logprob_start_len."
req.logprob_start_len = len(req.origin_input_ids) - 1
req.set_finish_with_abort(error_msg)
self._add_request_to_queue(req)
return
@@ -1061,6 +1055,10 @@ class Scheduler(
if not cache_hit:
req.grammar_key = key
add_to_grammar_queue = True
else:
if value is INVALID_GRAMMAR_OBJ: # We hit a cached invalid grammar.
error_msg = f"Invalid grammar request with cache hit: {key=}"
req.set_finish_with_abort(error_msg)
if add_to_grammar_queue:
req.queue_time_start = time.perf_counter()
@@ -1108,19 +1106,13 @@ class Scheduler(
req.extend_image_inputs(image_inputs)
if len(req.origin_input_ids) >= self.max_req_input_len:
error_msg = (
"Multimodal prompt is too long after expanding multimodal tokens. "
f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
req.set_finish_with_abort(
error_msg=(
"Multimodal prompt is too long after expanding multimodal tokens. "
f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
)
)
logger.error(error_msg)
req.origin_input_ids = [0]
req.multimodal_inputs = None
req.sampling_params.max_new_tokens = 0
req.finished_reason = FINISH_ABORT(
error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
)
req.queue_time_start = time.perf_counter()
self.waiting_queue.append(req)
self._add_request_to_queue(req)
return
# Validate prompts length
@@ -1785,17 +1777,25 @@ class Scheduler(
"""Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
num_ready_reqs = 0
num_abort_reqs = 0
num_timeout_reqs = 0
for req in self.grammar_queue:
try:
if req.finished(): # It is aborted by AbortReq
num_ready_reqs += 1
continue
req.grammar = req.grammar.result(timeout=0.03)
if req.grammar:
self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
if req.grammar is INVALID_GRAMMAR_OBJ:
req.set_finish_with_abort(
f"Invalid grammar request: {req.grammar_key=}"
)
num_ready_reqs += 1
except futures._base.TimeoutError:
req.grammar_wait_ct += 1
# NOTE(lianmin): this timeout is the waiting time of the above line. It is
# not the waiting time from it enters the grammar queue.
if req.grammar_wait_ct > GRAMMAR_TIMEOUT / 0.03:
num_abort_reqs = 1
num_timeout_reqs = 1
break
if self.server_args.enable_dp_attention:
@@ -1807,28 +1807,33 @@ class Scheduler(
if tp_size > 1:
# Sync across TP ranks to make sure they have the same number of ready requests
tensor = torch.tensor([num_ready_reqs, num_abort_reqs], dtype=torch.int32)
tensor = torch.tensor([num_ready_reqs, num_timeout_reqs], dtype=torch.int32)
torch.distributed.all_reduce(
tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
)
num_ready_reqs_max, num_abort_reqs_max = tensor.tolist()
num_ready_reqs_max, num_timeout_reqs_max = tensor.tolist()
for i in range(num_ready_reqs, num_ready_reqs_max):
req = self.grammar_queue[i]
if req.finished(): # It is aborted by AbortReq
continue
req.grammar = req.grammar.result()
if req.grammar:
self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
if req.grammar is INVALID_GRAMMAR_OBJ:
req.set_finish_with_abort(
f"Invalid grammar request: {req.grammar_key=}"
)
else:
num_ready_reqs_max = num_ready_reqs
num_timeout_reqs_max = num_timeout_reqs
for i in range(num_ready_reqs, num_ready_reqs + num_abort_reqs_max):
req = self.grammar_queue[i]
req.grammar.cancel()
req.grammar = None
error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
logger.error(error_msg)
req.finished_reason = FINISH_ABORT(
error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
)
num_ready_reqs = num_ready_reqs_max + num_abort_reqs_max
for i in range(num_ready_reqs, num_ready_reqs + num_timeout_reqs_max):
req = self.grammar_queue[i]
req.grammar.cancel()
error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
req.set_finish_with_abort(error_msg)
self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ)
num_ready_reqs = num_ready_reqs_max + num_timeout_reqs_max
self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
self.grammar_queue = self.grammar_queue[num_ready_reqs:]
@@ -2024,8 +2029,6 @@ class Scheduler(
)
def abort_request(self, recv_req: AbortReq):
# TODO(lmzheng): abort the requests in the grammar queue.
# Delete requests in the waiting queue
to_del = []
for i, req in enumerate(self.waiting_queue):
@@ -2047,8 +2050,16 @@ class Scheduler(
for req in reqs:
if req.rid.startswith(recv_req.rid) and not req.finished():
logger.debug(f"Abort running request. {req.rid=}")
# We must use to_abort because it is in a running batch
req.to_abort = True
# Delete the requests in the grammar queue
for req in self.grammar_queue:
if req.rid.startswith(recv_req.rid):
logger.debug(f"Abort grammar queue request. {req.rid=}")
req.grammar.cancel()
req.set_finish_with_abort("Aborted by AbortReq.")
def _pause_engine(self) -> Tuple[List[Req], int]:
raise NotImplementedError()