[Lint] Style: Convert examples to ruff format (#5863)

### What this PR does / why we need it?
This PR fixes lint and formatting issues in the `examples/` directory to align with the project's Ruff configuration.

- vLLM version: v0.13.0
- vLLM main: bde38c11df
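
The exact invocation isn't recorded in the PR, but changes like these are typically produced by Ruff's formatter plus its auto-fixes (a sketch, assuming the repository's existing Ruff settings in `pyproject.toml` are picked up automatically):

```bash
# Reformat code in place (quote normalization, line joining/wrapping, etc.)
ruff format examples/

# Apply auto-fixable lint rules, e.g. pyupgrade rewrites such as typing.List -> list
ruff check --fix examples/
```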

Signed-off-by: root <root@LAPTOP-VQKDDVMG.localdomain>
Co-authored-by: root <root@LAPTOP-VQKDDVMG.localdomain>
Author: SILONG ZENG
Date: 2026-01-13 20:46:50 +08:00
Committed by: GitHub
Parent: f7b904641e
Commit: 78d5ce3e01
23 changed files with 678 additions and 1037 deletions

View File

@@ -18,6 +18,7 @@ All the arguments that can be set by users are:
7. `--vllm-start-port`: Starting port of vLLM serving instances, default 9000
An example of running external DP in one single node:
```bash
cd examples/external_online_dp
# running DP4 TP4 in a node with 16 NPUs
@@ -25,6 +26,7 @@ python launch_online_dp.py --dp-size 4 --tp-size 4 --dp-size-local 4 --dp-rank-s
```
An example of running external DP in two nodes:
```bash
cd examples/external_online_dp
# running DP4 TP4 in two nodes with 8 NPUs each

View File

@@ -23,7 +23,7 @@
# ----------------------------------
# You need to have at least two vLLM servers running in data parallel.
# These can be mock servers or actual vLLM servers.
# Note that this proxy also works with only one vLLM server running, but
# will fall back to direct request forwarding which is meaningless.
#
# For testing, you can use the provided mock server:
@@ -84,13 +84,12 @@ import argparse
import asyncio
import functools
import heapq
import json
import os
import sys
import uuid
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, List
from typing import Any
import httpx
from fastapi import FastAPI, Request
@@ -109,34 +108,29 @@ except ImportError:
class ServerState:
def __init__(self, host, port):
self.host = host
self.port = port
self.url = f'http://{host}:{port}/v1'
self.client = httpx.AsyncClient(timeout=None,
base_url=self.url,
limits=httpx.Limits(
max_connections=100000,
max_keepalive_connections=100000))
self.url = f"http://{host}:{port}/v1"
self.client = httpx.AsyncClient(
timeout=None,
base_url=self.url,
limits=httpx.Limits(max_connections=100000, max_keepalive_connections=100000),
)
self.active_tokens = 0
self.aborted_requests = set() # Track aborted requests
class ProxyState:
def __init__(self, server_instances):
self.dp_servers: List[ServerState] = [
ServerState(h, p) for h, p in server_instances
]
self.dp_servers: list[ServerState] = [ServerState(h, p) for h, p in server_instances]
self.req_id_lock = asyncio.Lock()
# Removed selection locks - no longer needed for synchronous methods
# Initialize priority queues for efficient server selection
# Each entry is (priority_score, server_index, server_reference)
# Lower priority score = higher priority (less loaded)
self.lb_heap = [(0, i, server)
for i, server in enumerate(self.dp_servers)]
self.lb_heap = [(0, i, server) for i, server in enumerate(self.dp_servers)]
heapq.heapify(self.lb_heap)
def _update_server_priority(self, server_idx: int):
@@ -144,10 +138,8 @@ class ProxyState:
server = self.dp_servers[server_idx]
priority = server.active_tokens
# Remove old entry and add new one
self.lb_heap = [(p, i, s) for p, i, s in self.lb_heap
if i != server_idx]
heapq.heappush(self.lb_heap,
(priority, server_idx, server)) # type: ignore
self.lb_heap = [(p, i, s) for p, i, s in self.lb_heap if i != server_idx]
heapq.heappush(self.lb_heap, (priority, server_idx, server)) # type: ignore
async def next_req_id(self):
async with self.req_id_lock:
@@ -190,27 +182,15 @@ def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--dp-hosts",
type=str,
nargs="+",
default=["localhost"])
parser.add_argument("--dp-ports",
type=int,
nargs="+",
default=[8001])
parser.add_argument("--max-retries",
type=int,
default=3,
help="Maximum number of retries for HTTP requests")
parser.add_argument("--dp-hosts", type=str, nargs="+", default=["localhost"])
parser.add_argument("--dp-ports", type=int, nargs="+", default=[8001])
parser.add_argument("--max-retries", type=int, default=3, help="Maximum number of retries for HTTP requests")
parser.add_argument(
"--retry-delay",
type=float,
default=0.001,
help="Base delay (seconds) for exponential backoff retries")
"--retry-delay", type=float, default=0.001, help="Base delay (seconds) for exponential backoff retries"
)
args = parser.parse_args()
if len(args.dp_hosts) != len(args.dp_ports):
raise ValueError(
"Number of dp hosts must match number of dp ports")
raise ValueError("Number of dp hosts must match number of dp ports")
args.server_instances = list(zip(args.dp_hosts, args.dp_ports))
return args
@@ -219,9 +199,7 @@ def parse_args():
async def lifespan(app: FastAPI):
global proxy_state
proxy_state = ProxyState(global_args.server_instances)
print(
f"Initialized {len(proxy_state.dp_servers)} dp server clients."
)
print(f"Initialized {len(proxy_state.dp_servers)} dp server clients.")
yield
for p in proxy_state.dp_servers:
await p.client.aclose()
@@ -236,14 +214,12 @@ async def listen_for_disconnect(request: Request) -> None:
def with_cancellation(handler_func):
@functools.wraps(handler_func)
async def wrapper(*args, **kwargs):
request = kwargs["request"]
handler_task = asyncio.create_task(handler_func(*args, **kwargs))
cancellation_task = asyncio.create_task(listen_for_disconnect(request))
done, pending = await asyncio.wait([handler_task, cancellation_task],
return_when=asyncio.FIRST_COMPLETED)
done, pending = await asyncio.wait([handler_task, cancellation_task], return_when=asyncio.FIRST_COMPLETED)
for task in pending:
task.cancel()
if handler_task in done:
@@ -256,22 +232,18 @@ def with_cancellation(handler_func):
app = FastAPI(lifespan=lifespan)
async def stream_service_response_with_retry(client: httpx.AsyncClient,
endpoint: str,
req_data: dict,
request_id: str,
max_retries: int = 3,
base_delay: float = 0.2):
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"X-Request-Id": request_id
}
async def stream_service_response_with_retry(
client: httpx.AsyncClient,
endpoint: str,
req_data: dict,
request_id: str,
max_retries: int = 3,
base_delay: float = 0.2,
):
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", "X-Request-Id": request_id}
for attempt in range(1, max_retries + 1):
try:
async with client.stream("POST",
endpoint,
json=req_data,
headers=headers) as response:
async with client.stream("POST", endpoint, json=req_data, headers=headers) as response:
response.raise_for_status()
first_chunk_sent = False
async for chunk in response.aiter_bytes():
@@ -280,53 +252,42 @@ async def stream_service_response_with_retry(client: httpx.AsyncClient,
return # Success, exit after streaming
except (httpx.RequestError, httpx.HTTPStatusError) as e:
if attempt < max_retries:
logger.warning(
f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}"
)
await asyncio.sleep(base_delay * (2**(attempt - 1)))
logger.warning(f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}")
await asyncio.sleep(base_delay * (2 ** (attempt - 1)))
else:
logger.error(
f"All {max_retries} attempts failed for streaming {endpoint}."
)
logger.error(f"All {max_retries} attempts failed for streaming {endpoint}.")
raise e
except Exception as e:
# If any chunk has been sent, do not retry, just log and drop
if 'first_chunk_sent' in locals() and first_chunk_sent:
logger.error(
f"Streaming to client interrupted after response started: {str(e)}"
)
if "first_chunk_sent" in locals() and first_chunk_sent:
logger.error(f"Streaming to client interrupted after response started: {str(e)}")
return
else:
if attempt < max_retries:
logger.warning(
f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}"
)
await asyncio.sleep(base_delay * (2**(attempt - 1)))
logger.warning(f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}")
await asyncio.sleep(base_delay * (2 ** (attempt - 1)))
else:
logger.error(
f"All {max_retries} attempts failed for streaming {endpoint}."
)
logger.error(f"All {max_retries} attempts failed for streaming {endpoint}.")
raise e
async def _select_instance(api: str, req_data: Any,
request_length: int):
async def _select_instance(api: str, req_data: Any, request_length: int):
# refer to vLLM sampling_params: max_tokens default value
max_tokens = req_data.get("max_tokens", 16)
ignore_eos = req_data.get("ignore_eos", False)
priority_score = proxy_state.calculate_request_score(request_length,max_tokens=max_tokens, ignore_eos=ignore_eos)
priority_score = proxy_state.calculate_request_score(request_length, max_tokens=max_tokens, ignore_eos=ignore_eos)
logger.debug(
f"Request length: {request_length}, max tokens: {max_tokens}, ignore_eos: {ignore_eos}, Priority score: {priority_score}"
f"Request length: {request_length}, max tokens: {max_tokens}, "
f"ignore_eos: {ignore_eos}, Priority score: {priority_score}"
)
request_id = await proxy_state.next_req_id()
# Select dp server based on priority score
server_idx = proxy_state.select_server(priority_score)
chosen_server = proxy_state.dp_servers[server_idx]
logger.debug(f"Chose server {chosen_server.url} to process request {request_id}")
return InstanceInfo(request_id=request_id,
server_idx=server_idx,
priority_score=priority_score,
server_state=chosen_server)
return InstanceInfo(
request_id=request_id, server_idx=server_idx, priority_score=priority_score, server_state=chosen_server
)
@dataclass
@@ -342,36 +303,36 @@ async def _handle_completions(api: str, request: Request):
req_data = await request.json()
req_body = await request.body()
request_length = len(req_body)
instance_info = await _select_instance(api, req_data,
request_length)
instance_info = await _select_instance(api, req_data, request_length)
async def generate_stream():
nonlocal instance_info
# Only one await per chunk, minimal logic in loop
try:
async for chunk in stream_service_response_with_retry(
instance_info.server_state.client,
api,
req_data,
request_id=instance_info.request_id,
max_retries=global_args.max_retries,
base_delay=global_args.retry_delay):
instance_info.server_state.client,
api,
req_data,
request_id=instance_info.request_id,
max_retries=global_args.max_retries,
base_delay=global_args.retry_delay,
):
yield chunk
except Exception as e:
logger.error(
f"Error during streaming from server {instance_info.server_state.url}: {str(e)}, the aborted request is: {instance_info.request_id}."
f"Error during streaming from server {instance_info.server_state.url}: {str(e)}, "
f"the aborted request is: {instance_info.request_id}."
)
# After streaming done, release tokens
proxy_state.release_server(instance_info.server_idx,
instance_info.priority_score)
proxy_state.release_server(instance_info.server_idx, instance_info.priority_score)
return StreamingResponse(generate_stream(),
media_type="application/json")
return StreamingResponse(generate_stream(), media_type="application/json")
except Exception as e:
import traceback
exc_info = sys.exc_info()
print("Error occurred in external dp proxy server"
f" - {api} endpoint")
print(f"Error occurred in external dp proxy server - {api} endpoint")
print(e)
print("".join(traceback.format_exception(*exc_info)))
raise
@@ -397,7 +358,7 @@ async def healthcheck():
}
if __name__ == '__main__':
if __name__ == "__main__":
global global_args
global_args = parse_args()
import uvicorn

View File

@@ -4,52 +4,19 @@ import os
import subprocess
import sys
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--dp-size",
type=int,
required=True,
help="Data parallel size."
)
parser.add_argument(
"--tp-size",
type=int,
default=1,
help="Tensor parallel size."
)
parser.add_argument(
"--dp-size-local",
type=int,
default=-1,
help="Local data parallel size."
)
parser.add_argument(
"--dp-rank-start",
type=int,
default=0,
help="Starting rank for data parallel."
)
parser.add_argument(
"--dp-address",
type=str,
required=True,
help="IP address for data parallel master node."
)
parser.add_argument(
"--dp-rpc-port",
type=str,
default=12345,
help="Port for data parallel master node."
)
parser.add_argument(
"--vllm-start-port",
type=int,
default=9000,
help="Starting port for the engine."
)
parser.add_argument("--dp-size", type=int, required=True, help="Data parallel size.")
parser.add_argument("--tp-size", type=int, default=1, help="Tensor parallel size.")
parser.add_argument("--dp-size-local", type=int, default=-1, help="Local data parallel size.")
parser.add_argument("--dp-rank-start", type=int, default=0, help="Starting rank for data parallel.")
parser.add_argument("--dp-address", type=str, required=True, help="IP address for data parallel master node.")
parser.add_argument("--dp-rpc-port", type=str, default=12345, help="Port for data parallel master node.")
parser.add_argument("--vllm-start-port", type=int, default=9000, help="Starting port for the engine.")
return parser.parse_args()
args = parse_args()
dp_size = args.dp_size
tp_size = args.tp_size
@@ -61,6 +28,7 @@ dp_address = args.dp_address
dp_rpc_port = args.dp_rpc_port
vllm_start_port = args.vllm_start_port
def run_command(visible_devices, dp_rank, vllm_engine_port):
command = [
"bash",
@@ -75,6 +43,7 @@ def run_command(visiable_devices, dp_rank, vllm_engine_port):
]
subprocess.run(command, check=True)
if __name__ == "__main__":
template_path = "./run_dp_template.sh"
if not os.path.exists(template_path):
@@ -87,11 +56,9 @@ if __name__ == "__main__":
dp_rank = dp_rank_start + i
vllm_engine_port = vllm_start_port + i
visible_devices = ",".join(str(x) for x in range(i * tp_size, (i + 1) * tp_size))
process = multiprocessing.Process(target=run_command,
args=(visible_devices, dp_rank,
vllm_engine_port))
process = multiprocessing.Process(target=run_command, args=(visible_devices, dp_rank, vllm_engine_port))
processes.append(process)
process.start()
for process in processes:
process.join()