Correctly abort the failed grammar requests & Improve the handling of abort (#6803)
@@ -37,6 +37,7 @@ import hashlib
import logging
import threading
from enum import Enum, auto
from http import HTTPStatus
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union

import numpy as np
@@ -51,6 +52,7 @@ from sglang.srt.disaggregation.base import BaseKVSender
from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
    ScheduleBatchDisaggregationDecodeMixin,
)
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
from sglang.srt.layers.multimodal import gpu_tensor_hash
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
@@ -60,7 +62,7 @@ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, Forw
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import flatten_nested_list, get_compiler_backend, support_triton
from sglang.srt.utils import flatten_nested_list, support_triton

if TYPE_CHECKING:
    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
@@ -771,6 +773,16 @@ class Req:
        logger.info(f"{prefix}: {self.time_stats}")
        self.has_log_time_stats = True

    def set_finish_with_abort(self, error_msg: str):
        if get_tensor_model_parallel_rank() == 0:
            logger.error(f"{error_msg}, {self.rid=}")
        self.multimodal_inputs = None
        self.grammar = None
        self.origin_input_ids = [0]  # set it to one token to skip the long prefill
        self.finished_reason = FINISH_ABORT(
            error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
        )

    def __repr__(self):
        return (
            f"Req(rid={self.rid}, "
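For orientation, here is a minimal, self-contained sketch (not part of the diff) of what the new helper does and how call sites use it. DummyReq and the stand-in FINISH_ABORT class below are illustrative; the real definitions are the ones in this file.

    from http import HTTPStatus

    class FINISH_ABORT:  # stand-in for the FINISH_ABORT finish reason in schedule_batch.py
        def __init__(self, message, status_code=None, err_type=None):
            self.message, self.status_code, self.err_type = message, status_code, err_type

    class DummyReq:  # hypothetical, trimmed-down Req
        def __init__(self, rid, input_ids):
            self.rid = rid
            self.origin_input_ids = input_ids
            self.multimodal_inputs = {"pixel_values": "..."}
            self.grammar = object()
            self.finished_reason = None

        def set_finish_with_abort(self, error_msg: str):
            # Mirror of the new Req.set_finish_with_abort: drop heavy state and
            # shrink the prompt so the scheduler skips the long prefill.
            self.multimodal_inputs = None
            self.grammar = None
            self.origin_input_ids = [0]
            self.finished_reason = FINISH_ABORT(
                error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
            )

    req = DummyReq("req-123", list(range(4096)))
    req.set_finish_with_abort("Invalid grammar request")
    assert req.finished_reason.status_code == HTTPStatus.BAD_REQUEST

The later hunks in scheduler.py replace the hand-rolled FINISH_ABORT assignments with calls to this single helper.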
@@ -35,7 +35,10 @@ from torch.distributed import barrier

from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
from sglang.srt.constrained.base_grammar_backend import (
    INVALID_GRAMMAR_OBJ,
    create_grammar_backend,
)
from sglang.srt.disaggregation.decode import (
    DecodePreallocQueue,
    DecodeTransferQueue,
@@ -949,12 +952,12 @@ class Scheduler(
        if self.disaggregation_mode != DisaggregationMode.NULL:
            # Invalid request for disaggregated mode
            if recv_req.bootstrap_room is None:
                error_message = (
                error_msg = (
                    f"Invalid request: Disaggregated request received without "
                    f"boostrap room id. {req.rid=}"
                )
                logger.error(error_message)
                prepare_abort(req, error_message)
                logger.error(error_msg)
                prepare_abort(req, error_msg)
                self.stream_output([req], req.return_logprob)
                return
@@ -985,29 +988,23 @@ class Scheduler(
                req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                error_msg = (
                    "Multimodal prompt is too long after expanding multimodal tokens. "
                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                )
                logger.error(error_msg)
                req.origin_input_ids = [0]
                req.multimodal_inputs = None
                req.sampling_params.max_new_tokens = 0
                req.finished_reason = FINISH_ABORT(
                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
                req.set_finish_with_abort(
                    error_msg=(
                        "Multimodal prompt is too long after expanding multimodal tokens. "
                        f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                    )
                )
                self._add_request_to_queue(req)
                return

        # Validate prompts length
        # Validate prompt length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            req.origin_input_ids = [0]
            req.sampling_params.max_new_tokens = 0
            req.set_finish_with_abort(error_msg)
            self._add_request_to_queue(req)
            return
@@ -1019,12 +1016,9 @@ class Scheduler(
            req.logprob_start_len = recv_req.logprob_start_len

        if req.logprob_start_len >= len(req.origin_input_ids):
            req.finished_reason = FINISH_ABORT(
                f"logprob_start_len, ({req.logprob_start_len}) is higher than the number of input tokens ({len(req.origin_input_ids)}). Request with a lower logprob_start_len.",
                HTTPStatus.BAD_REQUEST,
                "BadRequestError",
            )
            error_msg = f"{req.logprob_start_len=} is higher than the number of input tokens {len(req.origin_input_ids)=}. Please use a smaller logprob_start_len."
            req.logprob_start_len = len(req.origin_input_ids) - 1
            req.set_finish_with_abort(error_msg)
            self._add_request_to_queue(req)
            return
@@ -1061,6 +1055,10 @@ class Scheduler(
            if not cache_hit:
                req.grammar_key = key
                add_to_grammar_queue = True
            else:
                if value is INVALID_GRAMMAR_OBJ:  # We hit a cached invalid grammar.
                    error_msg = f"Invalid grammar request with cache hit: {key=}"
                    req.set_finish_with_abort(error_msg)

        if add_to_grammar_queue:
            req.queue_time_start = time.perf_counter()
@@ -1108,19 +1106,13 @@ class Scheduler(
                req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                error_msg = (
                    "Multimodal prompt is too long after expanding multimodal tokens. "
                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                req.set_finish_with_abort(
                    error_msg=(
                        "Multimodal prompt is too long after expanding multimodal tokens. "
                        f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                    )
                )
                logger.error(error_msg)
                req.origin_input_ids = [0]
                req.multimodal_inputs = None
                req.sampling_params.max_new_tokens = 0
                req.finished_reason = FINISH_ABORT(
                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
                )
                req.queue_time_start = time.perf_counter()
                self.waiting_queue.append(req)
                self._add_request_to_queue(req)
                return

        # Validate prompts length
@@ -1785,17 +1777,25 @@ class Scheduler(
        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""

        num_ready_reqs = 0
        num_abort_reqs = 0
        num_timeout_reqs = 0
        for req in self.grammar_queue:
            try:
                if req.finished():  # It is aborted by AbortReq
                    num_ready_reqs += 1
                    continue
                req.grammar = req.grammar.result(timeout=0.03)
                if req.grammar:
                    self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                if req.grammar is INVALID_GRAMMAR_OBJ:
                    req.set_finish_with_abort(
                        f"Invalid grammar request: {req.grammar_key=}"
                    )
                num_ready_reqs += 1
            except futures._base.TimeoutError:
                req.grammar_wait_ct += 1
                # NOTE(lianmin): this timeout is the waiting time of the above line. It is
                # not the waiting time from it enters the grammar queue.
                if req.grammar_wait_ct > GRAMMAR_TIMEOUT / 0.03:
                    num_abort_reqs = 1
                    num_timeout_reqs = 1
                break

        if self.server_args.enable_dp_attention:
@@ -1807,28 +1807,33 @@ class Scheduler(

        if tp_size > 1:
            # Sync across TP ranks to make sure they have the same number of ready requests
            tensor = torch.tensor([num_ready_reqs, num_abort_reqs], dtype=torch.int32)
            tensor = torch.tensor([num_ready_reqs, num_timeout_reqs], dtype=torch.int32)
            torch.distributed.all_reduce(
                tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
            )
            num_ready_reqs_max, num_abort_reqs_max = tensor.tolist()
            num_ready_reqs_max, num_timeout_reqs_max = tensor.tolist()

            for i in range(num_ready_reqs, num_ready_reqs_max):
                req = self.grammar_queue[i]
                if req.finished():  # It is aborted by AbortReq
                    continue
                req.grammar = req.grammar.result()
                if req.grammar:
                    self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                if req.grammar is INVALID_GRAMMAR_OBJ:
                    req.set_finish_with_abort(
                        f"Invalid grammar request: {req.grammar_key=}"
                    )
        else:
            num_ready_reqs_max = num_ready_reqs
            num_timeout_reqs_max = num_timeout_reqs

        for i in range(num_ready_reqs, num_ready_reqs + num_abort_reqs_max):
            req = self.grammar_queue[i]
            req.grammar.cancel()
            req.grammar = None
            error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
            logger.error(error_msg)
            req.finished_reason = FINISH_ABORT(
                error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
            )
        num_ready_reqs = num_ready_reqs_max + num_abort_reqs_max
        for i in range(num_ready_reqs, num_ready_reqs + num_timeout_reqs_max):
            req = self.grammar_queue[i]
            req.grammar.cancel()
            error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
            req.set_finish_with_abort(error_msg)
            self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ)
        num_ready_reqs = num_ready_reqs_max + num_timeout_reqs_max

        self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
        self.grammar_queue = self.grammar_queue[num_ready_reqs:]
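The two hunks above can be read as one policy: a grammar compilation that times out is treated the same way as an invalid grammar. Below is a condensed, single-rank sketch of that policy; it deliberately ignores the TP all_reduce bookkeeping in the real method, and the queue, cache, and request attributes (e.g. grammar_future) are simplified stand-ins rather than the real names.

    from concurrent import futures

    INVALID_GRAMMAR_OBJ = object()  # stand-in for the cached "known bad grammar" sentinel
    GRAMMAR_TIMEOUT = 1.0           # illustrative; the real constant lives in scheduler.py

    def drain_grammar_queue(grammar_queue, grammar_cache):
        """Single-rank paraphrase of the hunks above; returns the requests to dequeue."""
        ready = []
        for req in grammar_queue:
            try:
                # In the real code the pending future is stored on req.grammar itself
                # and replaced by its result; here it is kept on a separate attribute.
                req.grammar = req.grammar_future.result(timeout=0.03)
                grammar_cache[req.grammar_key] = req.grammar
                if req.grammar is INVALID_GRAMMAR_OBJ:
                    req.set_finish_with_abort(f"Invalid grammar request: {req.grammar_key=}")
                ready.append(req)
            except futures.TimeoutError:
                req.grammar_wait_ct += 1
                if req.grammar_wait_ct > GRAMMAR_TIMEOUT / 0.03:
                    # Give up: cancel the background compilation, abort the request,
                    # and poison the cache so identical grammars fail fast next time.
                    req.grammar_future.cancel()
                    req.set_finish_with_abort(
                        f"Grammar preprocessing timed out for {req.grammar_key=}"
                    )
                    grammar_cache[req.grammar_key] = INVALID_GRAMMAR_OBJ
                    ready.append(req)
                break
        return ready

Caching INVALID_GRAMMAR_OBJ is what makes the earlier "cache hit on an invalid grammar" branch possible: a repeated bad grammar is rejected immediately instead of recompiling and timing out again.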
@@ -2024,8 +2029,6 @@ class Scheduler(
        )

    def abort_request(self, recv_req: AbortReq):
        # TODO(lmzheng): abort the requests in the grammar queue.

        # Delete requests in the waiting queue
        to_del = []
        for i, req in enumerate(self.waiting_queue):
@@ -2047,8 +2050,16 @@ class Scheduler(
            for req in reqs:
                if req.rid.startswith(recv_req.rid) and not req.finished():
                    logger.debug(f"Abort running request. {req.rid=}")
                    # We must use to_abort because it is in a running batch
                    req.to_abort = True

        # Delete the requests in the grammar queue
        for req in self.grammar_queue:
            if req.rid.startswith(recv_req.rid):
                logger.debug(f"Abort grammar queue request. {req.rid=}")
                req.grammar.cancel()
                req.set_finish_with_abort("Aborted by AbortReq.")

    def _pause_engine(self) -> Tuple[List[Req], int]:
        raise NotImplementedError()
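With this hunk the old TODO is resolved: an AbortReq now reaches a request in every stage of its life cycle. A rough, illustrative outline of the three paths follows; attribute names such as running_reqs are simplified stand-ins, and the real method also covers the disaggregation queues.

    def abort_request(scheduler, abort_rid: str):
        """Illustrative outline only; not the real Scheduler.abort_request."""
        # 1. Waiting queue: the request has not been scheduled yet, drop it directly.
        scheduler.waiting_queue = [
            r for r in scheduler.waiting_queue if not r.rid.startswith(abort_rid)
        ]
        # 2. Running batch: only flag it; the scheduler finishes it on the next step.
        for r in scheduler.running_reqs:
            if r.rid.startswith(abort_rid) and not r.finished():
                r.to_abort = True
        # 3. Grammar queue (new in this PR): cancel the pending compilation future
        #    and finish the request immediately with an abort reason.
        for r in scheduler.grammar_queue:
            if r.rid.startswith(abort_rid):
                r.grammar.cancel()
                r.set_finish_with_abort("Aborted by AbortReq.")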
@@ -221,7 +221,7 @@ class TokenizerManager:
                self.tokenizer = get_tokenizer_from_processor(self.processor)
                os.environ["TOKENIZERS_PARALLELISM"] = "false"
            else:
                self.mm_processor = get_dummy_processor()
                self.mm_processor = None

        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
@@ -425,8 +425,8 @@ class TokenizerManager:
        is_single = obj.is_single
        if is_single:
            tokenized_obj = await self._tokenize_one_request(obj)
            self._send_one_request(obj, tokenized_obj, created_time)
            async for response in self._wait_one_response(obj, request):
            state = self._send_one_request(obj, tokenized_obj, created_time)
            async for response in self._wait_one_response(obj, state, request):
                yield response
        else:
            async for response in self._handle_batch_request(
@@ -462,8 +462,7 @@ class TokenizerManager:
            )
            input_ids = self.tokenizer.encode(input_text)

        image_inputs: Optional[Dict] = None
        if obj.contains_mm_input():
        if self.mm_processor and obj.contains_mm_input():
            image_inputs = await self.mm_processor.process_mm_data_async(
                image_data=obj.image_data,
                input_text=input_text or input_ids,
@@ -472,6 +471,8 @@ class TokenizerManager:
            )
            if image_inputs and "input_ids" in image_inputs:
                input_ids = image_inputs["input_ids"]
        else:
            image_inputs: Optional[Dict] = None

        self._validate_token_len(obj, input_ids)
        return self._create_tokenized_object(
@@ -631,15 +632,15 @@ class TokenizerManager:
        self.send_to_scheduler.send_pyobj(tokenized_obj)
        state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
        self.rid_to_state[obj.rid] = state
        return state

    async def _wait_one_response(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        state: ReqState,
        request: Optional[fastapi.Request] = None,
    ):
        """Wait for the response of one request."""
        state = self.rid_to_state[obj.rid]

        while True:
            try:
                await asyncio.wait_for(state.event.wait(), timeout=4)
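The tokenizer-side refactor is mechanical: _send_one_request now returns the ReqState it registers, and _wait_one_response takes that state directly instead of looking it up again via self.rid_to_state[obj.rid]. Every call site in the hunks of this file then follows the same two-line pattern, e.g.:

    state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
    generators.append(self._wait_one_response(tmp_obj, state, request))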
@@ -709,16 +710,16 @@ class TokenizerManager:

                for i, tokenized_obj in enumerate(tokenized_objs):
                    tmp_obj = obj[i]
                    self._send_one_request(tmp_obj, tokenized_obj, created_time)
                    generators.append(self._wait_one_response(tmp_obj, request))
                    state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
                    generators.append(self._wait_one_response(tmp_obj, state, request))
                    rids.append(tmp_obj.rid)
            else:
                # Sequential tokenization and processing
                for i in range(batch_size):
                    tmp_obj = obj[i]
                    tokenized_obj = await self._tokenize_one_request(tmp_obj)
                    self._send_one_request(tmp_obj, tokenized_obj, created_time)
                    generators.append(self._wait_one_response(tmp_obj, request))
                    state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
                    generators.append(self._wait_one_response(tmp_obj, state, request))
                    rids.append(tmp_obj.rid)
        else:
            # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
@@ -743,8 +744,8 @@ class TokenizerManager:
                tokenized_obj.sampling_params = copy.copy(tokenized_obj.sampling_params)
                tokenized_obj.sampling_params.max_new_tokens = 0
                tokenized_obj.stream = False
                self._send_one_request(tmp_obj, tokenized_obj, created_time)
                await self._wait_one_response(tmp_obj, request).__anext__()
                state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
                await self._wait_one_response(tmp_obj, state, request).__anext__()

            # Expand requests, assign new rids for them, and send them
            for i in range(batch_size):
@@ -752,8 +753,8 @@ class TokenizerManager:
                tmp_obj = copy.copy(objs[i])
                tokenized_obj = copy.copy(tokenized_objs[i])
                tokenized_obj.rid = tmp_obj.regenerate_rid()
                self._send_one_request(tmp_obj, tokenized_obj, created_time)
                generators.append(self._wait_one_response(tmp_obj, request))
                state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
                generators.append(self._wait_one_response(tmp_obj, state, request))
                rids.append(tmp_obj.rid)

        # Wait for all requests
@@ -789,6 +790,9 @@ class TokenizerManager:
        req = AbortReq(rid)
        self.send_to_scheduler.send_pyobj(req)

        if self.enable_metrics:
            self.metrics_collector.observe_one_aborted_request()

    async def start_profile(
        self,
        output_dir: Optional[str] = None,
@@ -35,10 +35,6 @@ def validate_input_length(
            f"the maximum allowed length ({max_req_input_len} tokens). "
            f"Use a shorter input or enable --allow-auto-truncate."
        )
        logger.error(error_msg)
        req.finished_reason = FINISH_ABORT(
            error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
        )
        return error_msg

    return None
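After this last hunk, validate_input_length only reports the problem: it returns the error message, while logging happens inside set_finish_with_abort (on TP rank 0) and the abort bookkeeping moves to the caller, so over-long prompts go through the same rejection path as every other bad request. The resulting scheduler-side pattern, mirroring the earlier handle_generate_request hunk (req and the surrounding attributes are those of that method):

    error_msg = validate_input_length(
        req, self.max_req_input_len, self.server_args.allow_auto_truncate
    )
    if error_msg:
        req.origin_input_ids = [0]
        req.sampling_params.max_new_tokens = 0
        req.set_finish_with_abort(error_msg)
        self._add_request_to_queue(req)
        return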