Frontend: better error message handling for FINISH_ABORT in scheduler.py (#2956)
This commit is contained in:
@@ -115,14 +115,18 @@ class FINISH_LENGTH(BaseFinishReason):
|
|||||||
|
|
||||||
|
|
||||||
class FINISH_ABORT(BaseFinishReason):
|
class FINISH_ABORT(BaseFinishReason):
|
||||||
def __init__(self, message="Unknown error"):
|
def __init__(self, message="Unknown error", status_code=None, err_type=None):
|
||||||
super().__init__(is_error=True)
|
super().__init__(is_error=True)
|
||||||
self.message = message
|
self.message = message
|
||||||
|
self.status_code = status_code
|
||||||
|
self.err_type = err_type
|
||||||
|
|
||||||
def to_json(self):
|
def to_json(self):
|
||||||
return {
|
return {
|
||||||
"type": "abort",
|
"type": "abort",
|
||||||
"message": self.message,
|
"message": self.message,
|
||||||
|
"status_code": self.status_code,
|
||||||
|
"err_type": self.err_type,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import warnings
|
|||||||
from collections import deque
|
from collections import deque
|
||||||
from concurrent import futures
|
from concurrent import futures
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from http import HTTPStatus
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
@@ -672,15 +673,16 @@ class Scheduler:
|
|||||||
req.extend_image_inputs(image_inputs)
|
req.extend_image_inputs(image_inputs)
|
||||||
|
|
||||||
if len(req.origin_input_ids) >= self.max_req_input_len:
|
if len(req.origin_input_ids) >= self.max_req_input_len:
|
||||||
logger.error(
|
error_msg = (
|
||||||
"Multimodal prompt is too long after expanding multimodal tokens. "
|
"Multimodal prompt is too long after expanding multimodal tokens. "
|
||||||
f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}. "
|
f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
|
||||||
)
|
)
|
||||||
|
logger.error(error_msg)
|
||||||
req.origin_input_ids = [0]
|
req.origin_input_ids = [0]
|
||||||
req.image_inputs = None
|
req.image_inputs = None
|
||||||
req.sampling_params.max_new_tokens = 0
|
req.sampling_params.max_new_tokens = 0
|
||||||
req.finished_reason = FINISH_ABORT(
|
req.finished_reason = FINISH_ABORT(
|
||||||
"Multimodal prompt is too long. Check server logs for details."
|
error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
|
||||||
)
|
)
|
||||||
self.waiting_queue.append(req)
|
self.waiting_queue.append(req)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ import threading
|
|||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from http import HTTPStatus
|
||||||
from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union
|
from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union
|
||||||
|
|
||||||
import fastapi
|
import fastapi
|
||||||
@@ -384,6 +385,16 @@ class TokenizerManager:
|
|||||||
msg = f"Finish: obj={dataclass_to_string_truncated(obj)}, out={dataclass_to_string_truncated(out)}"
|
msg = f"Finish: obj={dataclass_to_string_truncated(obj)}, out={dataclass_to_string_truncated(out)}"
|
||||||
logger.info(msg)
|
logger.info(msg)
|
||||||
del self.rid_to_state[obj.rid]
|
del self.rid_to_state[obj.rid]
|
||||||
|
|
||||||
|
# Check if this was an abort/error created by scheduler
|
||||||
|
if isinstance(out["meta_info"].get("finish_reason"), dict):
|
||||||
|
finish_reason = out["meta_info"]["finish_reason"]
|
||||||
|
if (
|
||||||
|
finish_reason.get("type") == "abort"
|
||||||
|
and finish_reason.get("status_code") == HTTPStatus.BAD_REQUEST
|
||||||
|
):
|
||||||
|
raise ValueError(finish_reason["message"])
|
||||||
|
|
||||||
yield out
|
yield out
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
|
from http import HTTPStatus
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
|
from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
|
||||||
@@ -35,7 +36,9 @@ def validate_input_length(
|
|||||||
f"Use a shorter input or enable --allow-auto-truncate."
|
f"Use a shorter input or enable --allow-auto-truncate."
|
||||||
)
|
)
|
||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
req.finished_reason = FINISH_ABORT(error_msg)
|
req.finished_reason = FINISH_ABORT(
|
||||||
|
error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
|
||||||
|
)
|
||||||
return error_msg
|
return error_msg
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -392,34 +392,33 @@ class TestQWen2VLServerContextLengthIssue(unittest.TestCase):
|
|||||||
def test_chat_completion(self):
|
def test_chat_completion(self):
|
||||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||||
|
|
||||||
response = client.chat.completions.create(
|
with self.assertRaises(openai.BadRequestError) as cm:
|
||||||
model="default",
|
client.chat.completions.create(
|
||||||
messages=[
|
model="default",
|
||||||
{
|
messages=[
|
||||||
"role": "user",
|
{
|
||||||
"content": [
|
"role": "user",
|
||||||
{
|
"content": [
|
||||||
"type": "image_url",
|
{
|
||||||
"image_url": {
|
"type": "image_url",
|
||||||
"url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
|
"image_url": {
|
||||||
|
"url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
{
|
||||||
{
|
"type": "text",
|
||||||
"type": "text",
|
"text": "Give a lengthy description of this picture",
|
||||||
"text": "Give a lengthy description of this picture",
|
},
|
||||||
},
|
],
|
||||||
],
|
},
|
||||||
},
|
],
|
||||||
],
|
temperature=0,
|
||||||
temperature=0,
|
)
|
||||||
)
|
|
||||||
|
|
||||||
assert response.choices[0].finish_reason == "abort"
|
self.assertIn(
|
||||||
assert response.id
|
"Multimodal prompt is too long after expanding multimodal tokens.",
|
||||||
assert response.created
|
str(cm.exception),
|
||||||
assert response.usage.prompt_tokens > 0
|
)
|
||||||
assert response.usage.completion_tokens > 0
|
|
||||||
assert response.usage.total_tokens > 0
|
|
||||||
|
|
||||||
|
|
||||||
class TestMllamaServer(TestOpenAIVisionServer):
|
class TestMllamaServer(TestOpenAIVisionServer):
|
||||||
|
|||||||
Reference in New Issue
Block a user