feat: throttle requests at scheduler based on --max_queued_requests (#7565)

This commit is contained in:
harrisonlimh
2025-07-28 07:32:33 -07:00
committed by GitHub
parent b582159246
commit 747dd45077
10 changed files with 218 additions and 6 deletions

View File

@@ -38,7 +38,7 @@ import orjson
import requests
import uvicorn
import uvloop
from fastapi import Depends, FastAPI, Request, UploadFile
from fastapi import Depends, FastAPI, HTTPException, Request, UploadFile
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
@@ -174,6 +174,18 @@ app.add_middleware(
)
@app.exception_handler(HTTPException)
async def validation_exception_handler(request: Request, exc: HTTPException):
    """Translate a raised HTTPException into a structured ErrorResponse body.

    The response carries the exception's status code both as the HTTP status
    and inside the JSON payload (as ``type``/``code``) so clients can inspect
    it without parsing headers.
    """
    # NOTE(review): this module-level name collides with the
    # RequestValidationError handler defined below; both register fine at
    # decoration time, but renaming one would aid debugging — confirm intent.
    status = exc.status_code
    payload = ErrorResponse(
        object="error",
        message=exc.detail,
        type=str(status),
        code=status,
    )
    return ORJSONResponse(content=payload.model_dump(), status_code=status)
# Custom exception handlers to change validation error status codes
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):

View File

@@ -4,7 +4,7 @@ import uuid
from abc import ABC, abstractmethod
from typing import Any, Optional, Union
from fastapi import Request
from fastapi import HTTPException, Request
from fastapi.responses import ORJSONResponse, StreamingResponse
from sglang.srt.entrypoints.openai.protocol import ErrorResponse, OpenAIServingRequest
@@ -45,7 +45,10 @@ class OpenAIServingBase(ABC):
return await self._handle_non_streaming_request(
adapted_request, processed_request, raw_request
)
except HTTPException as e:
return self.create_error_response(
message=e.detail, err_type=str(e.status_code), status_code=e.status_code
)
except Exception as e:
logger.exception(f"Error in request: {e}")
return self.create_error_response(

View File

@@ -911,6 +911,8 @@ class AbortReq:
rid: str = ""
# Whether to abort all requests
abort_all: bool = False
# The finished reason data
finished_reason: Optional[Dict[str, Any]] = None
@dataclass

View File

@@ -24,6 +24,7 @@ import time
from collections import defaultdict, deque
from concurrent import futures
from dataclasses import dataclass
from http import HTTPStatus
from pathlib import Path
from types import SimpleNamespace
from typing import Dict, List, Optional, Tuple, Union
@@ -370,6 +371,7 @@ class Scheduler(
self.max_total_num_tokens,
self.max_prefill_tokens,
self.max_running_requests,
self.max_queued_requests,
self.max_req_len,
self.max_req_input_len,
self.random_seed,
@@ -1086,6 +1088,19 @@ class Scheduler(
self.return_health_check_ct += 1
continue
# If it is a work request, accept or reject the request based on the request queue size.
if is_work_request(recv_req):
if len(self.waiting_queue) + 1 > self.max_queued_requests:
abort_req = AbortReq(
recv_req.rid,
finished_reason={
"type": "abort",
"status_code": HTTPStatus.SERVICE_UNAVAILABLE,
"message": "The request queue is full.",
},
)
self.send_to_tokenizer.send_pyobj(abort_req)
continue
output = self._request_dispatcher(recv_req)
if output is not None:
if isinstance(output, RpcReqOutput):
@@ -2902,6 +2917,10 @@ def is_health_check_generate_req(recv_req):
return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
def is_work_request(recv_req):
    """Return True if recv_req is actual inference work (generate or embedding input)."""
    work_request_types = (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
    return isinstance(recv_req, work_request_types)
def _export_static_state(model):
return dict(
buffers=[

View File

@@ -766,6 +766,19 @@ class TokenizerManager:
):
raise ValueError(finish_reason["message"])
if (
finish_reason.get("type") == "abort"
and finish_reason.get("status_code")
== HTTPStatus.SERVICE_UNAVAILABLE
):
# This is an abort request initiated by scheduler.
# Delete the key to prevent resending abort request to the scheduler and
# to ensure aborted request state is cleaned up.
del self.rid_to_state[state.obj.rid]
raise fastapi.HTTPException(
status_code=finish_reason["status_code"],
detail=finish_reason["message"],
)
yield out
break
@@ -1705,8 +1718,15 @@ class TokenizerManager:
def _handle_abort_req(self, recv_obj):
state = self.rid_to_state[recv_obj.rid]
state.finished = True
state.out_list.append(
{
if recv_obj.finished_reason:
out = {
"meta_info": {
"id": recv_obj.rid,
"finish_reason": recv_obj.finished_reason,
},
}
else:
out = {
"text": "",
"meta_info": {
"id": recv_obj.rid,
@@ -1718,7 +1738,7 @@ class TokenizerManager:
"completion_tokens": 0,
},
}
)
state.out_list.append(out)
state.event.set()
def _handle_open_session_req_output(self, recv_obj):
@@ -1910,8 +1930,10 @@ class _Communicator(Generic[T]):
#
# | entrypoint | is_streaming | status | abort engine | cancel asyncio task | rid_to_state |
# | ---------- | ------------ | --------------- | --------------- | --------------------- | --------------------------- |
# | http | yes | validation | background task | fast api | del in _handle_abort_req |
# | http | yes | waiting queue | background task | fast api | del in _handle_abort_req |
# | http | yes | running | background task | fast api | del in _handle_batch_output |
# | http | no | validation | http exception | http exception | del in _handle_abort_req |
# | http | no | waiting queue | type 1 | type 1 exception | del in _handle_abort_req |
# | http | no | running | type 3 | type 3 exception | del in _handle_batch_output |
#

View File

@@ -130,6 +130,10 @@ class TpModelWorker:
self.model_runner.req_to_token_pool.size,
)
assert self.max_running_requests > 0, "max_running_request is zero"
# Cap on the scheduler's waiting queue; requests beyond it are rejected
# with 503 SERVICE_UNAVAILABLE (defaults to sys.maxsize, i.e. unbounded).
self.max_queued_requests = server_args.max_queued_requests
# BUGFIX: the assert previously re-checked max_running_requests (a
# duplicate of the assert above) while the message talked about
# max_queued_requests — validate the value the message describes.
assert (
    self.max_queued_requests > 0
), "max_queued_requests is zero. We need to be at least 1 to schedule a request."
self.max_req_len = min(
self.model_config.context_len - 1,
self.max_total_num_tokens - 1,
@@ -165,6 +169,7 @@ class TpModelWorker:
self.max_total_num_tokens,
self.max_prefill_tokens,
self.max_running_requests,
self.max_queued_requests,
self.max_req_len,
self.max_req_input_len,
self.random_seed,

View File

@@ -19,6 +19,7 @@ import json
import logging
import os
import random
import sys
import tempfile
from typing import List, Literal, Optional, Union
@@ -74,6 +75,7 @@ class ServerArgs:
# Memory and scheduling
mem_fraction_static: Optional[float] = None
max_running_requests: Optional[int] = None
max_queued_requests: Optional[int] = sys.maxsize
max_total_tokens: Optional[int] = None
chunked_prefill_size: Optional[int] = None
max_prefill_tokens: int = 16384
@@ -805,6 +807,12 @@ class ServerArgs:
default=ServerArgs.max_running_requests,
help="The maximum number of running requests.",
)
parser.add_argument(
"--max-queued-requests",
type=int,
default=ServerArgs.max_queued_requests,
help="The maximum number of queued requests. This option is ignored when using disaggregation-mode.",
)
parser.add_argument(
"--max-total-tokens",
type=int,

View File

@@ -19,6 +19,7 @@ from pathlib import Path
from types import SimpleNamespace
from typing import Awaitable, Callable, List, Optional, Tuple
import aiohttp
import numpy as np
import requests
import torch
@@ -1303,6 +1304,58 @@ def run_logprob_check(self: unittest.TestCase, arg: Tuple):
raise
def send_generate_requests(base_url: str, num_requests: int) -> List[int]:
    """Send generate requests serially and return their HTTP status codes.

    Max concurrency is 1: each request completes before the next is sent.

    Args:
        base_url: Server root, e.g. ``http://127.0.0.1:30000``.
        num_requests: Number of requests to issue.

    Returns:
        One HTTP status code (int) per request, in send order.
        (Annotation fixed from ``List[str]``: ``response.status_code`` is an int.)
    """

    def generate() -> int:
        prompt = """
System: You are a helpful assistant.
User: What is the capital of France?
Assistant: The capital of France is
"""
        response = requests.post(
            f"{base_url}/generate",
            json={
                "text": prompt,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 50,
                },
            },
        )
        return response.status_code

    return [generate() for _ in range(num_requests)]
async def send_concurrent_generate_requests(
    base_url: str, num_requests: int
) -> List[int]:
    """Send generate requests concurrently and return their HTTP status codes.

    Max concurrency is num_requests: all requests are launched as tasks at
    once, so this is suitable for exercising queue-full (503) behavior.

    Args:
        base_url: Server root, e.g. ``http://127.0.0.1:30000``.
        num_requests: Number of concurrent requests to issue.

    Returns:
        One HTTP status code (int) per request, in task-creation order
        (``asyncio.gather`` preserves input order).
        (Annotation fixed from ``List[str]``: ``response.status`` is an int.)
    """

    async def async_generate() -> int:
        # One session per request keeps each task fully independent.
        async with aiohttp.ClientSession() as session:
            prompt = """
System: You are a helpful assistant.
User: What is the capital of France?
Assistant: The capital of France is
"""
            async with session.post(
                f"{base_url}/generate",
                json={
                    "text": prompt,
                    "sampling_params": {
                        "temperature": 0,
                        "max_new_tokens": 50,
                    },
                },
            ) as response:
                return response.status

    tasks = [asyncio.create_task(async_generate()) for _ in range(num_requests)]
    return await asyncio.gather(*tasks)
class CustomTestCase(unittest.TestCase):
def _callTestMethod(self, method):
max_retry = int(