feat: throttle requests at scheduler based on --max_queued_requests (#7565)
@@ -38,7 +38,7 @@ import orjson
 import requests
 import uvicorn
 import uvloop
-from fastapi import Depends, FastAPI, Request, UploadFile
+from fastapi import Depends, FastAPI, HTTPException, Request, UploadFile
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import ORJSONResponse, Response, StreamingResponse
@@ -174,6 +174,18 @@ app.add_middleware(
 )


+@app.exception_handler(HTTPException)
+async def http_exception_handler(request: Request, exc: HTTPException):
+    """Enrich the HTTP exception with its status code and other details."""
+    error = ErrorResponse(
+        object="error",
+        message=exc.detail,
+        type=str(exc.status_code),
+        code=exc.status_code,
+    )
+    return ORJSONResponse(content=error.model_dump(), status_code=exc.status_code)
+
+
 # Custom exception handlers to change validation error status codes
 @app.exception_handler(RequestValidationError)
 async def validation_exception_handler(request: Request, exc: RequestValidationError):
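
Note: with the handler above registered, a scheduler-side rejection reaches the
caller as a structured JSON error instead of a bare 500. A minimal client-side
sketch of what a throttled caller might observe; the URL and payload are
illustrative, not part of this diff:

import requests

resp = requests.post(
    "http://127.0.0.1:30000/generate",  # assumed local endpoint
    json={"text": "Hello", "sampling_params": {"max_new_tokens": 8}},
)
if resp.status_code == 503:
    body = resp.json()
    # Field names mirror the ErrorResponse built in the handler above.
    print(body["message"], body["code"])  # "The request queue is full." 503
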
@@ -4,7 +4,7 @@ import uuid
 from abc import ABC, abstractmethod
 from typing import Any, Optional, Union

-from fastapi import Request
+from fastapi import HTTPException, Request
 from fastapi.responses import ORJSONResponse, StreamingResponse

 from sglang.srt.entrypoints.openai.protocol import ErrorResponse, OpenAIServingRequest
@@ -45,7 +45,10 @@ class OpenAIServingBase(ABC):
                 return await self._handle_non_streaming_request(
                     adapted_request, processed_request, raw_request
                 )
-
+        except HTTPException as e:
+            return self.create_error_response(
+                message=e.detail, err_type=str(e.status_code), status_code=e.status_code
+            )
         except Exception as e:
             logger.exception(f"Error in request: {e}")
             return self.create_error_response(
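
Note: on the OpenAI-compatible endpoints the HTTPException is caught here and
converted inline, preserving the original status code rather than falling
through to the generic handler below it. A self-contained sketch of that
catch-and-convert pattern; every name in it is an illustrative stand-in, not
sglang's API:

import asyncio

from fastapi import HTTPException

def make_error(message: str, err_type: str, status_code: int) -> dict:
    return {"object": "error", "message": message, "type": err_type, "code": status_code}

async def do_work() -> dict:
    # Stand-in for a request handler that gets rejected upstream with a 503.
    raise HTTPException(status_code=503, detail="The request queue is full.")

async def handle() -> dict:
    try:
        return await do_work()
    except HTTPException as e:
        # Keep the original status instead of collapsing it into a 500.
        return make_error(e.detail, str(e.status_code), e.status_code)

print(asyncio.run(handle()))  # {'object': 'error', 'message': 'The request queue is full.', ...}
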
@@ -911,6 +911,8 @@ class AbortReq:
     rid: str = ""
     # Whether to abort all requests
     abort_all: bool = False
+    # The finished reason data
+    finished_reason: Optional[Dict[str, Any]] = None


 @dataclass
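
Note: the new field lets an abort carry a structured reason back to the
tokenizer manager. A small runnable sketch of the payload the scheduler
attaches, using a simplified stand-in for the real dataclass:

from dataclasses import dataclass
from http import HTTPStatus
from typing import Any, Dict, Optional

@dataclass
class AbortReq:  # simplified stand-in mirroring the fields above
    rid: str = ""
    abort_all: bool = False
    finished_reason: Optional[Dict[str, Any]] = None

req = AbortReq(
    rid="example-rid",
    finished_reason={
        "type": "abort",
        "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
        "message": "The request queue is full.",
    },
)
print(int(req.finished_reason["status_code"]))  # 503
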
@@ -24,6 +24,7 @@ import time
 from collections import defaultdict, deque
 from concurrent import futures
 from dataclasses import dataclass
+from http import HTTPStatus
 from pathlib import Path
 from types import SimpleNamespace
 from typing import Dict, List, Optional, Tuple, Union
@@ -370,6 +371,7 @@ class Scheduler(
             self.max_total_num_tokens,
             self.max_prefill_tokens,
             self.max_running_requests,
+            self.max_queued_requests,
             self.max_req_len,
             self.max_req_input_len,
             self.random_seed,

@@ -1086,6 +1088,19 @@ class Scheduler(
                     self.return_health_check_ct += 1
                 continue

+            # If it is a work request, accept or reject the request based on the request queue size.
+            if is_work_request(recv_req):
+                if len(self.waiting_queue) + 1 > self.max_queued_requests:
+                    abort_req = AbortReq(
+                        recv_req.rid,
+                        finished_reason={
+                            "type": "abort",
+                            "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
+                            "message": "The request queue is full.",
+                        },
+                    )
+                    self.send_to_tokenizer.send_pyobj(abort_req)
+                    continue
             output = self._request_dispatcher(recv_req)
             if output is not None:
                 if isinstance(output, RpcReqOutput):
@@ -2902,6 +2917,10 @@ def is_health_check_generate_req(recv_req):
     return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")


+def is_work_request(recv_req):
+    return isinstance(recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput))
+
+
 def _export_static_state(model):
     return dict(
         buffers=[
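
Note: the admission rule is a plain bounded-queue check, and only work
requests (tokenized generate/embedding inputs) are subject to it; health
checks and RPC traffic bypass the limit, while running requests are capped
separately by max_running_requests. A self-contained sketch of the same
logic with illustrative names:

from collections import deque
from http import HTTPStatus

MAX_QUEUED_REQUESTS = 2  # stand-in for --max-queued-requests
waiting_queue = deque()

def try_admit(rid: str) -> int:
    """Return the HTTP status the caller would eventually observe."""
    if len(waiting_queue) + 1 > MAX_QUEUED_REQUESTS:
        # Mirrors the scheduler: reject up front with 503 instead of queueing.
        return HTTPStatus.SERVICE_UNAVAILABLE.value
    waiting_queue.append(rid)
    return HTTPStatus.OK.value

print([try_admit(f"req-{i}") for i in range(4)])  # [200, 200, 503, 503]
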
@@ -766,6 +766,19 @@ class TokenizerManager:
                 ):
                     raise ValueError(finish_reason["message"])

+                if (
+                    finish_reason.get("type") == "abort"
+                    and finish_reason.get("status_code")
+                    == HTTPStatus.SERVICE_UNAVAILABLE
+                ):
+                    # This abort was initiated by the scheduler.
+                    # Delete the key to prevent resending the abort to the scheduler
+                    # and to ensure the aborted request's state is cleaned up.
+                    del self.rid_to_state[state.obj.rid]
+                    raise fastapi.HTTPException(
+                        status_code=finish_reason["status_code"],
+                        detail=finish_reason["message"],
+                    )
                 yield out
                 break

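
Note: the ordering above matters: the rid_to_state entry is deleted before the
exception is raised, so later client-side cleanup cannot send a second abort
to the scheduler for a request it never admitted. A self-contained sketch of
the cleanup-then-raise pattern; RuntimeError stands in for
fastapi.HTTPException:

from http import HTTPStatus

rid_to_state = {"req-1": object()}  # stand-in for TokenizerManager.rid_to_state

def on_finished(rid: str, finish_reason: dict) -> None:
    if (
        finish_reason.get("type") == "abort"
        and finish_reason.get("status_code") == HTTPStatus.SERVICE_UNAVAILABLE
    ):
        # Drop the bookkeeping entry first so no second abort is ever sent.
        del rid_to_state[rid]
        raise RuntimeError(finish_reason["message"])

try:
    on_finished(
        "req-1",
        {
            "type": "abort",
            "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
            "message": "The request queue is full.",
        },
    )
except RuntimeError as e:
    print(e, "req-1" in rid_to_state)  # The request queue is full. False
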
@@ -1705,8 +1718,15 @@ class TokenizerManager:
     def _handle_abort_req(self, recv_obj):
         state = self.rid_to_state[recv_obj.rid]
         state.finished = True
-        state.out_list.append(
-            {
+        if recv_obj.finished_reason:
+            out = {
+                "meta_info": {
+                    "id": recv_obj.rid,
+                    "finish_reason": recv_obj.finished_reason,
+                },
+            }
+        else:
+            out = {
                 "text": "",
                 "meta_info": {
                     "id": recv_obj.rid,
@@ -1718,7 +1738,7 @@ class TokenizerManager:
                     "completion_tokens": 0,
                 },
             }
-        )
+        state.out_list.append(out)
         state.event.set()

     def _handle_open_session_req_output(self, recv_obj):

@@ -1910,8 +1930,10 @@ class _Communicator(Generic[T]):
 #
 # | entrypoint | is_streaming | status        | abort engine    | cancel asyncio task | rid_to_state                |
 # | ---------- | ------------ | ------------- | --------------- | ------------------- | --------------------------- |
 # | http       | yes          | validation    | background task | fast api            | del in _handle_abort_req    |
+# | http       | yes          | waiting queue | background task | fast api            | del in _handle_abort_req    |
 # | http       | yes          | running       | background task | fast api            | del in _handle_batch_output |
 # | http       | no           | validation    | http exception  | http exception      | del in _handle_abort_req    |
+# | http       | no           | waiting queue | type 1          | type 1 exception    | del in _handle_abort_req    |
 # | http       | no           | running       | type 3          | type 3 exception    | del in _handle_batch_output |
 #

@@ -130,6 +130,10 @@ class TpModelWorker:
             self.model_runner.req_to_token_pool.size,
         )
         assert self.max_running_requests > 0, "max_running_request is zero"
+        self.max_queued_requests = server_args.max_queued_requests
+        assert (
+            self.max_queued_requests > 0
+        ), "max_queued_requests is zero. It needs to be at least 1 to schedule a request."
         self.max_req_len = min(
             self.model_config.context_len - 1,
             self.max_total_num_tokens - 1,
@@ -165,6 +169,7 @@ class TpModelWorker:
             self.max_total_num_tokens,
             self.max_prefill_tokens,
             self.max_running_requests,
+            self.max_queued_requests,
             self.max_req_len,
             self.max_req_input_len,
             self.random_seed,

@@ -19,6 +19,7 @@ import json
 import logging
 import os
 import random
+import sys
 import tempfile
 from typing import List, Literal, Optional, Union

@@ -74,6 +75,7 @@ class ServerArgs:
     # Memory and scheduling
     mem_fraction_static: Optional[float] = None
     max_running_requests: Optional[int] = None
+    max_queued_requests: Optional[int] = sys.maxsize
     max_total_tokens: Optional[int] = None
     chunked_prefill_size: Optional[int] = None
     max_prefill_tokens: int = 16384
@@ -805,6 +807,12 @@ class ServerArgs:
             default=ServerArgs.max_running_requests,
             help="The maximum number of running requests.",
         )
+        parser.add_argument(
+            "--max-queued-requests",
+            type=int,
+            default=ServerArgs.max_queued_requests,
+            help="The maximum number of queued requests. This option is ignored in disaggregation mode.",
+        )
         parser.add_argument(
             "--max-total-tokens",
             type=int,
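
Note: the sys.maxsize default makes the limit effectively unbounded, so
throttling is strictly opt-in, e.g. pass --max-queued-requests 128 at launch.
A tiny sketch of why the default never trips the admission check:

import sys

# Default: max_queued_requests = sys.maxsize, so the scheduler's
# len(waiting_queue) + 1 > max_queued_requests test is false for any
# queue that fits in memory.
print(sys.maxsize)  # 9223372036854775807 on 64-bit CPython
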
@@ -19,6 +19,7 @@ from pathlib import Path
 from types import SimpleNamespace
 from typing import Awaitable, Callable, List, Optional, Tuple

+import aiohttp
 import numpy as np
 import requests
 import torch
@@ -1303,6 +1304,58 @@ def run_logprob_check(self: unittest.TestCase, arg: Tuple):
         raise


+def send_generate_requests(base_url: str, num_requests: int) -> List[int]:
+    """Send generate requests serially and return status codes. Max concurrency is 1."""
+
+    def generate():
+        prompt = """
+        System: You are a helpful assistant.
+        User: What is the capital of France?
+        Assistant: The capital of France is
+        """
+        response = requests.post(
+            f"{base_url}/generate",
+            json={
+                "text": prompt,
+                "sampling_params": {
+                    "temperature": 0,
+                    "max_new_tokens": 50,
+                },
+            },
+        )
+        return response.status_code
+
+    return [generate() for _ in range(num_requests)]
+
+
+async def send_concurrent_generate_requests(
+    base_url: str, num_requests: int
+) -> List[int]:
+    """Send generate requests concurrently and return status codes. Max concurrency is num_requests."""
+
+    async def async_generate():
+        async with aiohttp.ClientSession() as session:
+            prompt = """
+            System: You are a helpful assistant.
+            User: What is the capital of France?
+            Assistant: The capital of France is
+            """
+            async with session.post(
+                f"{base_url}/generate",
+                json={
+                    "text": prompt,
+                    "sampling_params": {
+                        "temperature": 0,
+                        "max_new_tokens": 50,
+                    },
+                },
+            ) as response:
+                return response.status
+
+    tasks = [asyncio.create_task(async_generate()) for _ in range(num_requests)]
+    return await asyncio.gather(*tasks)
+
+
 class CustomTestCase(unittest.TestCase):
     def _callTestMethod(self, method):
         max_retry = int(
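
Note: a hedged sketch of how these helpers might drive a throttling test:
start the server under test with a small --max-queued-requests, expect serial
traffic to pass, and expect concurrent traffic to see a mix of 200s and 503s.
The import path and base URL are assumptions, and launching the server is
left out:

import asyncio
import unittest

# Assumed import path for the helpers defined above.
from sglang.test.test_utils import (
    send_concurrent_generate_requests,
    send_generate_requests,
)

# Assumes a server launched elsewhere with --max-queued-requests 1.
BASE_URL = "http://127.0.0.1:30000"

class TestMaxQueuedRequests(unittest.TestCase):
    def test_serial_requests_are_admitted(self):
        # Concurrency 1 never overfills the queue, so all requests succeed.
        statuses = send_generate_requests(BASE_URL, num_requests=4)
        self.assertTrue(all(code == 200 for code in statuses))

    def test_concurrent_requests_throttle(self):
        statuses = asyncio.run(
            send_concurrent_generate_requests(BASE_URL, num_requests=8)
        )
        self.assertTrue(any(code == 200 for code in statuses))
        self.assertTrue(any(code == 503 for code in statuses))  # scheduler rejections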