From 747dd45077c57db11455b3d9071ebc0d357f97de Mon Sep 17 00:00:00 2001 From: harrisonlimh <97203667+harrisonlimh@users.noreply.github.com> Date: Mon, 28 Jul 2025 07:32:33 -0700 Subject: [PATCH] feat: throttle requests at scheduler based on --max_queued_requests (#7565) --- python/sglang/srt/entrypoints/http_server.py | 14 ++- .../srt/entrypoints/openai/serving_base.py | 7 +- python/sglang/srt/managers/io_struct.py | 2 + python/sglang/srt/managers/scheduler.py | 19 ++++ .../sglang/srt/managers/tokenizer_manager.py | 28 +++++- python/sglang/srt/managers/tp_worker.py | 5 ++ python/sglang/srt/server_args.py | 8 ++ python/sglang/test/test_utils.py | 53 +++++++++++ test/srt/run_suite.py | 1 + test/srt/test_request_queue_validation.py | 87 +++++++++++++++++++ 10 files changed, 218 insertions(+), 6 deletions(-) create mode 100644 test/srt/test_request_queue_validation.py diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index e2ce86847..586a26495 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -38,7 +38,7 @@ import orjson import requests import uvicorn import uvloop -from fastapi import Depends, FastAPI, Request, UploadFile +from fastapi import Depends, FastAPI, HTTPException, Request, UploadFile from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import ORJSONResponse, Response, StreamingResponse @@ -174,6 +174,18 @@ app.add_middleware( ) +@app.exception_handler(HTTPException) +async def validation_exception_handler(request: Request, exc: HTTPException): + """Enrich HTTP exception with status code and other details""" + error = ErrorResponse( + object="error", + message=exc.detail, + type=str(exc.status_code), + code=exc.status_code, + ) + return ORJSONResponse(content=error.model_dump(), status_code=exc.status_code) + + # Custom exception handlers to change validation error status codes @app.exception_handler(RequestValidationError) async def validation_exception_handler(request: Request, exc: RequestValidationError): diff --git a/python/sglang/srt/entrypoints/openai/serving_base.py b/python/sglang/srt/entrypoints/openai/serving_base.py index ba7514f0d..ad7c35f20 100644 --- a/python/sglang/srt/entrypoints/openai/serving_base.py +++ b/python/sglang/srt/entrypoints/openai/serving_base.py @@ -4,7 +4,7 @@ import uuid from abc import ABC, abstractmethod from typing import Any, Optional, Union -from fastapi import Request +from fastapi import HTTPException, Request from fastapi.responses import ORJSONResponse, StreamingResponse from sglang.srt.entrypoints.openai.protocol import ErrorResponse, OpenAIServingRequest @@ -45,7 +45,10 @@ class OpenAIServingBase(ABC): return await self._handle_non_streaming_request( adapted_request, processed_request, raw_request ) - + except HTTPException as e: + return self.create_error_response( + message=e.detail, err_type=str(e.status_code), status_code=e.status_code + ) except Exception as e: logger.exception(f"Error in request: {e}") return self.create_error_response( diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 3d18e1af4..377205e67 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -911,6 +911,8 @@ class AbortReq: rid: str = "" # Whether to abort all requests abort_all: bool = False + # The finished reason data + finished_reason: Optional[Dict[str, Any]] = None @dataclass diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index ecfce1392..5d3d115e2 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -24,6 +24,7 @@ import time from collections import defaultdict, deque from concurrent import futures from dataclasses import dataclass +from http import HTTPStatus from pathlib import Path from types import SimpleNamespace from typing import Dict, List, Optional, Tuple, Union @@ -370,6 +371,7 @@ class Scheduler( self.max_total_num_tokens, self.max_prefill_tokens, self.max_running_requests, + self.max_queued_requests, self.max_req_len, self.max_req_input_len, self.random_seed, @@ -1086,6 +1088,19 @@ class Scheduler( self.return_health_check_ct += 1 continue + # If it is a work request, accept or reject the request based on the request queue size. + if is_work_request(recv_req): + if len(self.waiting_queue) + 1 > self.max_queued_requests: + abort_req = AbortReq( + recv_req.rid, + finished_reason={ + "type": "abort", + "status_code": HTTPStatus.SERVICE_UNAVAILABLE, + "message": "The request queue is full.", + }, + ) + self.send_to_tokenizer.send_pyobj(abort_req) + continue output = self._request_dispatcher(recv_req) if output is not None: if isinstance(output, RpcReqOutput): @@ -2902,6 +2917,10 @@ def is_health_check_generate_req(recv_req): return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK") +def is_work_request(recv_req): + return isinstance(recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)) + + def _export_static_state(model): return dict( buffers=[ diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index cb4df6b65..c998b51c9 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -766,6 +766,19 @@ class TokenizerManager: ): raise ValueError(finish_reason["message"]) + if ( + finish_reason.get("type") == "abort" + and finish_reason.get("status_code") + == HTTPStatus.SERVICE_UNAVAILABLE + ): + # This is an abort request initiated by scheduler. + # Delete the key to prevent resending abort request to the scheduler and + # to ensure aborted request state is cleaned up. + del self.rid_to_state[state.obj.rid] + raise fastapi.HTTPException( + status_code=finish_reason["status_code"], + detail=finish_reason["message"], + ) yield out break @@ -1705,8 +1718,15 @@ class TokenizerManager: def _handle_abort_req(self, recv_obj): state = self.rid_to_state[recv_obj.rid] state.finished = True - state.out_list.append( - { + if recv_obj.finished_reason: + out = { + "meta_info": { + "id": recv_obj.rid, + "finish_reason": recv_obj.finished_reason, + }, + } + else: + out = { "text": "", "meta_info": { "id": recv_obj.rid, @@ -1718,7 +1738,7 @@ class TokenizerManager: "completion_tokens": 0, }, } - ) + state.out_list.append(out) state.event.set() def _handle_open_session_req_output(self, recv_obj): @@ -1910,8 +1930,10 @@ class _Communicator(Generic[T]): # # | entrypoint | is_streaming | status | abort engine | cancel asyncio task | rid_to_state | # | ---------- | ------------ | --------------- | --------------- | --------------------- | --------------------------- | +# | http | yes | validation | background task | fast api | del in _handle_abort_req | # | http | yes | waiting queue | background task | fast api | del in _handle_abort_req | # | http | yes | running | background task | fast api | del in _handle_batch_output | +# | http | no | validation | http exception | http exception | del in _handle_abort_req | # | http | no | waiting queue | type 1 | type 1 exception | del in _handle_abort_req | # | http | no | running | type 3 | type 3 exception | del in _handle_batch_output | # diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index e6d3c9a24..42ed45949 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -130,6 +130,10 @@ class TpModelWorker: self.model_runner.req_to_token_pool.size, ) assert self.max_running_requests > 0, "max_running_request is zero" + self.max_queued_requests = server_args.max_queued_requests + assert ( + self.max_running_requests > 0 + ), "max_queued_requests is zero. We need to be at least 1 to schedule a request." self.max_req_len = min( self.model_config.context_len - 1, self.max_total_num_tokens - 1, @@ -165,6 +169,7 @@ class TpModelWorker: self.max_total_num_tokens, self.max_prefill_tokens, self.max_running_requests, + self.max_queued_requests, self.max_req_len, self.max_req_input_len, self.random_seed, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 54dc76ed7..dc0c6cd1a 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -19,6 +19,7 @@ import json import logging import os import random +import sys import tempfile from typing import List, Literal, Optional, Union @@ -74,6 +75,7 @@ class ServerArgs: # Memory and scheduling mem_fraction_static: Optional[float] = None max_running_requests: Optional[int] = None + max_queued_requests: Optional[int] = sys.maxsize max_total_tokens: Optional[int] = None chunked_prefill_size: Optional[int] = None max_prefill_tokens: int = 16384 @@ -805,6 +807,12 @@ class ServerArgs: default=ServerArgs.max_running_requests, help="The maximum number of running requests.", ) + parser.add_argument( + "--max-queued-requests", + type=int, + default=ServerArgs.max_queued_requests, + help="The maximum number of queued requests. This option is ignored when using disaggregation-mode.", + ) parser.add_argument( "--max-total-tokens", type=int, diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 65d989eab..c155a4d6d 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -19,6 +19,7 @@ from pathlib import Path from types import SimpleNamespace from typing import Awaitable, Callable, List, Optional, Tuple +import aiohttp import numpy as np import requests import torch @@ -1303,6 +1304,58 @@ def run_logprob_check(self: unittest.TestCase, arg: Tuple): raise +def send_generate_requests(base_url: str, num_requests: int) -> List[str]: + """Sends generate request serially and returns status codes. Max concurrency is 1.""" + + def generate(): + prompt = """ + System: You are a helpful assistant. + User: What is the capital of France? + Assistant: The capital of France is + """ + response = requests.post( + f"{base_url}/generate", + json={ + "text": prompt, + "sampling_params": { + "temperature": 0, + "max_new_tokens": 50, + }, + }, + ) + return response.status_code + + return [generate() for _ in range(num_requests)] + + +async def send_concurrent_generate_requests( + base_url: str, num_requests: int +) -> List[str]: + """Sends generate request concurrently and returns status codes. Max concurrency is num_requests.""" + + async def async_generate(): + async with aiohttp.ClientSession() as session: + prompt = """ + System: You are a helpful assistant. + User: What is the capital of France? + Assistant: The capital of France is + """ + async with session.post( + f"{base_url}/generate", + json={ + "text": prompt, + "sampling_params": { + "temperature": 0, + "max_new_tokens": 50, + }, + }, + ) as response: + return response.status + + tasks = [asyncio.create_task(async_generate()) for _ in range(num_requests)] + return await asyncio.gather(*tasks) + + class CustomTestCase(unittest.TestCase): def _callTestMethod(self, method): max_retry = int( diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index c9876e161..7b43d5175 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -86,6 +86,7 @@ suites = { TestFile("test_radix_attention.py", 105), TestFile("test_regex_constrained.py", 64), TestFile("test_retract_decode.py", 54), + TestFile("test_request_queue_validation.py", 30), TestFile("test_server_args.py", 1), TestFile("test_skip_tokenizer_init.py", 117), TestFile("test_srt_engine.py", 261), diff --git a/test/srt/test_request_queue_validation.py b/test/srt/test_request_queue_validation.py new file mode 100644 index 000000000..2a9739a1c --- /dev/null +++ b/test/srt/test_request_queue_validation.py @@ -0,0 +1,87 @@ +import asyncio +import os +import re +import unittest +from concurrent.futures import ThreadPoolExecutor + +from sglang.srt.utils import kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + STDERR_FILENAME, + STDOUT_FILENAME, + CustomTestCase, + popen_launch_server, + send_concurrent_generate_requests, + send_generate_requests, +) + + +class TestMaxQueuedRequests(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + + cls.stdout = open(STDOUT_FILENAME, "w") + cls.stderr = open(STDERR_FILENAME, "w") + + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=( + "--max-running-requests", # Enforce max request concurrency is 1 + "1", + "--max-queued-requests", # Enforce max queued request number is 1 + "1", + ), + return_stdout_stderr=(cls.stdout, cls.stderr), + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + cls.stdout.close() + cls.stderr.close() + os.remove(STDOUT_FILENAME) + os.remove(STDERR_FILENAME) + + def test_max_queued_requests_validation_with_serial_requests(self): + """Verify request is not throttled when the max concurrency is 1.""" + status_codes = send_generate_requests( + self.base_url, + num_requests=10, + ) + + for status_code in status_codes: + assert status_code == 200 # request shouldn't be throttled + + def test_max_queued_requests_validation_with_concurrent_requests(self): + """Verify request throttling with concurrent requests.""" + status_codes = asyncio.run( + send_concurrent_generate_requests(self.base_url, num_requests=10) + ) + + assert 200 in status_codes + assert 503 in status_codes + assert all(status_code in [200, 503] for status_code in status_codes) + + def test_max_running_requests_and_max_queued_request_validation(self): + """Verify running request and queued request numbers based on server logs.""" + rr_pattern = re.compile(r"#running-req:\s*(\d+)") + qr_pattern = re.compile(r"#queue-req:\s*(\d+)") + + with open(STDERR_FILENAME) as lines: + for line in lines: + rr_match, qr_match = rr_pattern.search(line), qr_pattern.search(line) + if rr_match: + assert int(rr_match.group(1)) <= 1 + if qr_match: + assert int(qr_match.group(1)) <= 1 + + +if __name__ == "__main__": + unittest.main()