feat: throttle requests at scheduler based on --max_queued_requests (#7565)
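
This change adds scheduler-side admission control: when a generate or embedding work request arrives and the waiting queue already holds --max-queued-requests entries, the scheduler aborts it immediately with finished_reason {"type": "abort", "status_code": 503, "message": "The request queue is full."}. The tokenizer manager converts that reason into a fastapi.HTTPException, and a new HTTPException handler in the HTTP server serializes it as a structured ErrorResponse.

A minimal client-side sketch of what a caller sees under load; the server address and prompt are placeholders, not part of this diff:

import requests

base_url = "http://127.0.0.1:30000"  # placeholder address, not part of this diff
resp = requests.post(
    f"{base_url}/generate",
    json={"text": "The capital of France is", "sampling_params": {"max_new_tokens": 8}},
)
if resp.status_code == 503:
    # The scheduler rejected the request because the waiting queue was full;
    # the body is the ErrorResponse built by the new HTTPException handler.
    print("throttled:", resp.json())
else:
    print(resp.json())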
python/sglang/srt/entrypoints/http_server.py
@@ -38,7 +38,7 @@ import orjson
 import requests
 import uvicorn
 import uvloop
-from fastapi import Depends, FastAPI, Request, UploadFile
+from fastapi import Depends, FastAPI, HTTPException, Request, UploadFile
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import ORJSONResponse, Response, StreamingResponse
@@ -174,6 +174,18 @@ app.add_middleware(
 )


+@app.exception_handler(HTTPException)
+async def validation_exception_handler(request: Request, exc: HTTPException):
+    """Enrich HTTP exception with status code and other details"""
+    error = ErrorResponse(
+        object="error",
+        message=exc.detail,
+        type=str(exc.status_code),
+        code=exc.status_code,
+    )
+    return ORJSONResponse(content=error.model_dump(), status_code=exc.status_code)
+
+
 # Custom exception handlers to change validation error status codes
 @app.exception_handler(RequestValidationError)
 async def validation_exception_handler(request: Request, exc: RequestValidationError):
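
For reference, a sketch of the JSON body this handler would return for a queue-full rejection; the field values follow the ErrorResponse construction above, and the example is illustrative rather than output captured from a server:

{
    "object": "error",
    "message": "The request queue is full.",
    "type": "503",
    "code": 503,
}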
python/sglang/srt/entrypoints/openai/serving_base.py
@@ -4,7 +4,7 @@ import uuid
 from abc import ABC, abstractmethod
 from typing import Any, Optional, Union

-from fastapi import Request
+from fastapi import HTTPException, Request
 from fastapi.responses import ORJSONResponse, StreamingResponse

 from sglang.srt.entrypoints.openai.protocol import ErrorResponse, OpenAIServingRequest
@@ -45,7 +45,10 @@ class OpenAIServingBase(ABC):
                 return await self._handle_non_streaming_request(
                     adapted_request, processed_request, raw_request
                 )
+        except HTTPException as e:
+            return self.create_error_response(
+                message=e.detail, err_type=str(e.status_code), status_code=e.status_code
+            )
         except Exception as e:
             logger.exception(f"Error in request: {e}")
             return self.create_error_response(
python/sglang/srt/managers/io_struct.py
@@ -911,6 +911,8 @@ class AbortReq:
     rid: str = ""
     # Whether to abort all requests
     abort_all: bool = False
+    # The finished reason data
+    finished_reason: Optional[Dict[str, Any]] = None


 @dataclass
python/sglang/srt/managers/scheduler.py
@@ -24,6 +24,7 @@ import time
 from collections import defaultdict, deque
 from concurrent import futures
 from dataclasses import dataclass
+from http import HTTPStatus
 from pathlib import Path
 from types import SimpleNamespace
 from typing import Dict, List, Optional, Tuple, Union
@@ -370,6 +371,7 @@ class Scheduler(
             self.max_total_num_tokens,
             self.max_prefill_tokens,
             self.max_running_requests,
+            self.max_queued_requests,
             self.max_req_len,
             self.max_req_input_len,
             self.random_seed,
@@ -1086,6 +1088,19 @@ class Scheduler(
                 self.return_health_check_ct += 1
                 continue

+            # If it is a work request, accept or reject the request based on the request queue size.
+            if is_work_request(recv_req):
+                if len(self.waiting_queue) + 1 > self.max_queued_requests:
+                    abort_req = AbortReq(
+                        recv_req.rid,
+                        finished_reason={
+                            "type": "abort",
+                            "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
+                            "message": "The request queue is full.",
+                        },
+                    )
+                    self.send_to_tokenizer.send_pyobj(abort_req)
+                    continue
             output = self._request_dispatcher(recv_req)
             if output is not None:
                 if isinstance(output, RpcReqOutput):
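
For clarity, the admission decision above isolated as a standalone sketch; admit_or_reject is a hypothetical helper, while the threshold logic and the 503 payload mirror the hunk:

from http import HTTPStatus
from typing import Optional, Tuple


def admit_or_reject(queue_len: int, max_queued_requests: int) -> Tuple[bool, Optional[dict]]:
    """Hypothetical helper mirroring the scheduler's new admission check."""
    if queue_len + 1 > max_queued_requests:
        # A rejected work request is aborted back to the tokenizer manager
        # with this finished_reason, which later surfaces as an HTTP 503.
        return False, {
            "type": "abort",
            "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
            "message": "The request queue is full.",
        }
    return True, None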
@@ -2902,6 +2917,10 @@ def is_health_check_generate_req(recv_req):
     return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")


+def is_work_request(recv_req):
+    return isinstance(recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput))
+
+
 def _export_static_state(model):
     return dict(
         buffers=[
python/sglang/srt/managers/tokenizer_manager.py
@@ -766,6 +766,19 @@ class TokenizerManager:
                ):
                    raise ValueError(finish_reason["message"])

+                if (
+                    finish_reason.get("type") == "abort"
+                    and finish_reason.get("status_code")
+                    == HTTPStatus.SERVICE_UNAVAILABLE
+                ):
+                    # This is an abort initiated by the scheduler.
+                    # Delete the key to prevent resending the abort request to the
+                    # scheduler and to ensure the aborted request state is cleaned up.
+                    del self.rid_to_state[state.obj.rid]
+                    raise fastapi.HTTPException(
+                        status_code=finish_reason["status_code"],
+                        detail=finish_reason["message"],
+                    )
            yield out
            break

@@ -1705,8 +1718,15 @@ class TokenizerManager:
     def _handle_abort_req(self, recv_obj):
         state = self.rid_to_state[recv_obj.rid]
         state.finished = True
-        state.out_list.append(
-            {
+        if recv_obj.finished_reason:
+            out = {
+                "meta_info": {
+                    "id": recv_obj.rid,
+                    "finish_reason": recv_obj.finished_reason,
+                },
+            }
+        else:
+            out = {
                 "text": "",
                 "meta_info": {
                     "id": recv_obj.rid,
@@ -1718,7 +1738,7 @@ class TokenizerManager:
                     "completion_tokens": 0,
                 },
             }
-        )
+        state.out_list.append(out)
         state.event.set()

     def _handle_open_session_req_output(self, recv_obj):
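
With this change, _handle_abort_req propagates the scheduler's finished_reason instead of a generic empty output. A sketch of the out payload a throttled request would yield; the request id is a placeholder:

out = {
    "meta_info": {
        "id": "example-rid",  # placeholder request id
        "finish_reason": {
            "type": "abort",
            "status_code": 503,  # HTTPStatus.SERVICE_UNAVAILABLE
            "message": "The request queue is full.",
        },
    },
}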
@@ -1910,8 +1930,10 @@ class _Communicator(Generic[T]):
 #
 # | entrypoint | is_streaming | status        | abort engine    | cancel asyncio task | rid_to_state                |
 # | ---------- | ------------ | ------------- | --------------- | ------------------- | --------------------------- |
+# | http       | yes          | validation    | background task | fast api            | del in _handle_abort_req    |
 # | http       | yes          | waiting queue | background task | fast api            | del in _handle_abort_req    |
 # | http       | yes          | running       | background task | fast api            | del in _handle_batch_output |
+# | http       | no           | validation    | http exception  | http exception      | del in _handle_abort_req    |
 # | http       | no           | waiting queue | type 1          | type 1 exception    | del in _handle_abort_req    |
 # | http       | no           | running       | type 3          | type 3 exception    | del in _handle_batch_output |
 #
python/sglang/srt/managers/tp_worker.py
@@ -130,6 +130,10 @@ class TpModelWorker:
             self.model_runner.req_to_token_pool.size,
         )
         assert self.max_running_requests > 0, "max_running_request is zero"
+        self.max_queued_requests = server_args.max_queued_requests
+        assert (
+            self.max_queued_requests > 0
+        ), "max_queued_requests is zero. It must be at least 1 to schedule a request."
         self.max_req_len = min(
             self.model_config.context_len - 1,
             self.max_total_num_tokens - 1,
@@ -165,6 +169,7 @@ class TpModelWorker:
             self.max_total_num_tokens,
             self.max_prefill_tokens,
             self.max_running_requests,
+            self.max_queued_requests,
             self.max_req_len,
             self.max_req_input_len,
             self.random_seed,
python/sglang/srt/server_args.py
@@ -19,6 +19,7 @@ import json
 import logging
 import os
 import random
+import sys
 import tempfile
 from typing import List, Literal, Optional, Union

@@ -74,6 +75,7 @@ class ServerArgs:
     # Memory and scheduling
     mem_fraction_static: Optional[float] = None
     max_running_requests: Optional[int] = None
+    max_queued_requests: Optional[int] = sys.maxsize
     max_total_tokens: Optional[int] = None
     chunked_prefill_size: Optional[int] = None
     max_prefill_tokens: int = 16384
@@ -805,6 +807,12 @@ class ServerArgs:
             default=ServerArgs.max_running_requests,
             help="The maximum number of running requests.",
         )
+        parser.add_argument(
+            "--max-queued-requests",
+            type=int,
+            default=ServerArgs.max_queued_requests,
+            help="The maximum number of queued requests. This option is ignored when using disaggregation-mode.",
+        )
         parser.add_argument(
             "--max-total-tokens",
             type=int,
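
Usage note: a typical launch line exercising the new flag, assuming the standard sglang.launch_server entry point; the model path is a placeholder:

python3 -m sglang.launch_server --model-path <model> --max-running-requests 1 --max-queued-requests 1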
python/sglang/test/test_utils.py
@@ -19,6 +19,7 @@ from pathlib import Path
 from types import SimpleNamespace
 from typing import Awaitable, Callable, List, Optional, Tuple

+import aiohttp
 import numpy as np
 import requests
 import torch
@@ -1303,6 +1304,58 @@ def run_logprob_check(self: unittest.TestCase, arg: Tuple):
             raise


+def send_generate_requests(base_url: str, num_requests: int) -> List[int]:
+    """Sends generate requests serially and returns status codes. Max concurrency is 1."""
+
+    def generate():
+        prompt = """
+        System: You are a helpful assistant.
+        User: What is the capital of France?
+        Assistant: The capital of France is
+        """
+        response = requests.post(
+            f"{base_url}/generate",
+            json={
+                "text": prompt,
+                "sampling_params": {
+                    "temperature": 0,
+                    "max_new_tokens": 50,
+                },
+            },
+        )
+        return response.status_code
+
+    return [generate() for _ in range(num_requests)]
+
+
+async def send_concurrent_generate_requests(
+    base_url: str, num_requests: int
+) -> List[int]:
+    """Sends generate requests concurrently and returns status codes. Max concurrency is num_requests."""
+
+    async def async_generate():
+        async with aiohttp.ClientSession() as session:
+            prompt = """
+            System: You are a helpful assistant.
+            User: What is the capital of France?
+            Assistant: The capital of France is
+            """
+            async with session.post(
+                f"{base_url}/generate",
+                json={
+                    "text": prompt,
+                    "sampling_params": {
+                        "temperature": 0,
+                        "max_new_tokens": 50,
+                    },
+                },
+            ) as response:
+                return response.status
+
+    tasks = [asyncio.create_task(async_generate()) for _ in range(num_requests)]
+    return await asyncio.gather(*tasks)
+
+
 class CustomTestCase(unittest.TestCase):
     def _callTestMethod(self, method):
         max_retry = int(
test/srt/run_suite.py
@@ -86,6 +86,7 @@ suites = {
         TestFile("test_radix_attention.py", 105),
         TestFile("test_regex_constrained.py", 64),
         TestFile("test_retract_decode.py", 54),
+        TestFile("test_request_queue_validation.py", 30),
         TestFile("test_server_args.py", 1),
         TestFile("test_skip_tokenizer_init.py", 117),
         TestFile("test_srt_engine.py", 261),
test/srt/test_request_queue_validation.py (new file, 87 lines)
@@ -0,0 +1,87 @@
+import asyncio
+import os
+import re
+import unittest
+from concurrent.futures import ThreadPoolExecutor
+
+from sglang.srt.utils import kill_process_tree
+from sglang.test.test_utils import (
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    STDERR_FILENAME,
+    STDOUT_FILENAME,
+    CustomTestCase,
+    popen_launch_server,
+    send_concurrent_generate_requests,
+    send_generate_requests,
+)
+
+
+class TestMaxQueuedRequests(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+
+        cls.stdout = open(STDOUT_FILENAME, "w")
+        cls.stderr = open(STDERR_FILENAME, "w")
+
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=(
+                "--max-running-requests",  # Enforce that max request concurrency is 1
+                "1",
+                "--max-queued-requests",  # Enforce that the max queued request number is 1
+                "1",
+            ),
+            return_stdout_stderr=(cls.stdout, cls.stderr),
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+        cls.stdout.close()
+        cls.stderr.close()
+        os.remove(STDOUT_FILENAME)
+        os.remove(STDERR_FILENAME)
+
+    def test_max_queued_requests_validation_with_serial_requests(self):
+        """Verify requests are not throttled when the max concurrency is 1."""
+        status_codes = send_generate_requests(
+            self.base_url,
+            num_requests=10,
+        )
+
+        for status_code in status_codes:
+            assert status_code == 200  # requests shouldn't be throttled
+
+    def test_max_queued_requests_validation_with_concurrent_requests(self):
+        """Verify request throttling with concurrent requests."""
+        status_codes = asyncio.run(
+            send_concurrent_generate_requests(self.base_url, num_requests=10)
+        )
+
+        assert 200 in status_codes
+        assert 503 in status_codes
+        assert all(status_code in [200, 503] for status_code in status_codes)
+
+    def test_max_running_requests_and_max_queued_request_validation(self):
+        """Verify running request and queued request numbers based on server logs."""
+        rr_pattern = re.compile(r"#running-req:\s*(\d+)")
+        qr_pattern = re.compile(r"#queue-req:\s*(\d+)")
+
+        with open(STDERR_FILENAME) as lines:
+            for line in lines:
+                rr_match, qr_match = rr_pattern.search(line), qr_pattern.search(line)
+                if rr_match:
+                    assert int(rr_match.group(1)) <= 1
+                if qr_match:
+                    assert int(qr_match.group(1)) <= 1
+
+
+if __name__ == "__main__":
+    unittest.main()
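
The new test file can be run on its own through the unittest entry point above, e.g.:

python3 test/srt/test_request_queue_validation.py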