feat: throttle requests at scheduler based on --max_queued_requests (#7565)

This commit is contained in:
harrisonlimh
2025-07-28 07:32:33 -07:00
committed by GitHub
parent b582159246
commit 747dd45077
10 changed files with 218 additions and 6 deletions

View File

@@ -19,6 +19,7 @@ from pathlib import Path
from types import SimpleNamespace
from typing import Awaitable, Callable, List, Optional, Tuple
import aiohttp
import numpy as np
import requests
import torch
@@ -1303,6 +1304,58 @@ def run_logprob_check(self: unittest.TestCase, arg: Tuple):
raise
def send_generate_requests(base_url: str, num_requests: int) -> List[int]:
    """Send generate requests serially and return their HTTP status codes.

    Max concurrency is 1: each request completes before the next is sent.

    Args:
        base_url: Root URL of the running server (e.g. "http://127.0.0.1:30000").
        num_requests: Number of sequential requests to issue.

    Returns:
        The HTTP status code of each response, in request order.
        (Annotation fixed: status codes are ints, not strs.)
    """

    def generate() -> int:
        # Fixed prompt with temperature 0 so server behavior is deterministic
        # across repeated requests.
        prompt = """
        System: You are a helpful assistant.
        User: What is the capital of France?
        Assistant: The capital of France is
        """
        response = requests.post(
            f"{base_url}/generate",
            json={
                "text": prompt,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 50,
                },
            },
        )
        return response.status_code

    return [generate() for _ in range(num_requests)]
async def send_concurrent_generate_requests(
    base_url: str, num_requests: int
) -> List[int]:
    """Send generate requests concurrently and return their HTTP status codes.

    Max concurrency is `num_requests`: all requests are issued at once via
    `asyncio.gather`, which is how the scheduler's request-throttling path
    (--max_queued_requests) gets exercised.

    Args:
        base_url: Root URL of the running server (e.g. "http://127.0.0.1:30000").
        num_requests: Number of requests to issue concurrently.

    Returns:
        The HTTP status code of each response, ordered to match task creation
        order (`asyncio.gather` preserves input order).
        (Annotation fixed: status codes are ints, not strs.)
    """

    async def async_generate() -> int:
        # One session per request keeps each task fully independent; fine for
        # the small request counts used in tests.
        async with aiohttp.ClientSession() as session:
            # Same deterministic prompt as the serial helper.
            prompt = """
            System: You are a helpful assistant.
            User: What is the capital of France?
            Assistant: The capital of France is
            """
            async with session.post(
                f"{base_url}/generate",
                json={
                    "text": prompt,
                    "sampling_params": {
                        "temperature": 0,
                        "max_new_tokens": 50,
                    },
                },
            ) as response:
                return response.status

    tasks = [asyncio.create_task(async_generate()) for _ in range(num_requests)]
    return await asyncio.gather(*tasks)
class CustomTestCase(unittest.TestCase):
def _callTestMethod(self, method):
max_retry = int(