Frontend: better error message handling for FINISH_ABORT in scheduler.py (#2956)

This commit is contained in:
Chang Su
2025-01-18 19:37:30 -08:00
committed by GitHub
parent 2bd18e2d76
commit 4d4cdb3fe7
5 changed files with 50 additions and 31 deletions

View File

@@ -115,14 +115,18 @@ class FINISH_LENGTH(BaseFinishReason):
class FINISH_ABORT(BaseFinishReason): class FINISH_ABORT(BaseFinishReason):
def __init__(self, message="Unknown error"): def __init__(self, message="Unknown error", status_code=None, err_type=None):
super().__init__(is_error=True) super().__init__(is_error=True)
self.message = message self.message = message
self.status_code = status_code
self.err_type = err_type
def to_json(self): def to_json(self):
return { return {
"type": "abort", "type": "abort",
"message": self.message, "message": self.message,
"status_code": self.status_code,
"err_type": self.err_type,
} }

View File

@@ -23,6 +23,7 @@ import warnings
from collections import deque from collections import deque
from concurrent import futures from concurrent import futures
from dataclasses import dataclass from dataclasses import dataclass
from http import HTTPStatus
from types import SimpleNamespace from types import SimpleNamespace
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
@@ -672,15 +673,16 @@ class Scheduler:
req.extend_image_inputs(image_inputs) req.extend_image_inputs(image_inputs)
if len(req.origin_input_ids) >= self.max_req_input_len: if len(req.origin_input_ids) >= self.max_req_input_len:
logger.error( error_msg = (
"Multimodal prompt is too long after expanding multimodal tokens. " "Multimodal prompt is too long after expanding multimodal tokens. "
f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}. " f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
) )
logger.error(error_msg)
req.origin_input_ids = [0] req.origin_input_ids = [0]
req.image_inputs = None req.image_inputs = None
req.sampling_params.max_new_tokens = 0 req.sampling_params.max_new_tokens = 0
req.finished_reason = FINISH_ABORT( req.finished_reason = FINISH_ABORT(
"Multimodal prompt is too long. Check server logs for details." error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
) )
self.waiting_queue.append(req) self.waiting_queue.append(req)
return return

View File

@@ -25,6 +25,7 @@ import threading
import time import time
import uuid import uuid
from datetime import datetime from datetime import datetime
from http import HTTPStatus
from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union
import fastapi import fastapi
@@ -384,6 +385,16 @@ class TokenizerManager:
msg = f"Finish: obj={dataclass_to_string_truncated(obj)}, out={dataclass_to_string_truncated(out)}" msg = f"Finish: obj={dataclass_to_string_truncated(obj)}, out={dataclass_to_string_truncated(out)}"
logger.info(msg) logger.info(msg)
del self.rid_to_state[obj.rid] del self.rid_to_state[obj.rid]
# Check if this was an abort/error created by scheduler
if isinstance(out["meta_info"].get("finish_reason"), dict):
finish_reason = out["meta_info"]["finish_reason"]
if (
finish_reason.get("type") == "abort"
and finish_reason.get("status_code") == HTTPStatus.BAD_REQUEST
):
raise ValueError(finish_reason["message"])
yield out yield out
break break

View File

@@ -1,4 +1,5 @@
import logging import logging
from http import HTTPStatus
from typing import Optional from typing import Optional
from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
@@ -35,7 +36,9 @@ def validate_input_length(
f"Use a shorter input or enable --allow-auto-truncate." f"Use a shorter input or enable --allow-auto-truncate."
) )
logger.error(error_msg) logger.error(error_msg)
req.finished_reason = FINISH_ABORT(error_msg) req.finished_reason = FINISH_ABORT(
error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
)
return error_msg return error_msg
return None return None

View File

@@ -392,34 +392,33 @@ class TestQWen2VLServerContextLengthIssue(unittest.TestCase):
def test_chat_completion(self): def test_chat_completion(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url) client = openai.Client(api_key=self.api_key, base_url=self.base_url)
response = client.chat.completions.create( with self.assertRaises(openai.BadRequestError) as cm:
model="default", client.chat.completions.create(
messages=[ model="default",
{ messages=[
"role": "user", {
"content": [ "role": "user",
{ "content": [
"type": "image_url", {
"image_url": { "type": "image_url",
"url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true" "image_url": {
"url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
},
}, },
}, {
{ "type": "text",
"type": "text", "text": "Give a lengthy description of this picture",
"text": "Give a lengthy description of this picture", },
}, ],
], },
}, ],
], temperature=0,
temperature=0, )
)
assert response.choices[0].finish_reason == "abort" self.assertIn(
assert response.id "Multimodal prompt is too long after expanding multimodal tokens.",
assert response.created str(cm.exception),
assert response.usage.prompt_tokens > 0 )
assert response.usage.completion_tokens > 0
assert response.usage.total_tokens > 0
class TestMllamaServer(TestOpenAIVisionServer): class TestMllamaServer(TestOpenAIVisionServer):