[Lint]Style: Convert example to ruff format (#5863)
### What this PR does / why we need it?
This PR fixes linting issues in the `examples/` directory to align with the
project's Ruff configuration.
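For context, the change is purely mechanical. A hedged before/after illustration of the call style `ruff format` enforces (the flag below is invented for illustration, not taken from this diff):

```python
import argparse

parser = argparse.ArgumentParser()

# Before: yapf-style hanging indent, aligned to the opening parenthesis.
# parser.add_argument("--example-flag",
#                     type=int,
#                     default=1)

# After: ruff format collapses short calls onto a single line.
parser.add_argument("--example-flag", type=int, default=1)
```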
- vLLM version: v0.13.0
- vLLM main: bde38c11df
Signed-off-by: root <root@LAPTOP-VQKDDVMG.localdomain>
Co-authored-by: root <root@LAPTOP-VQKDDVMG.localdomain>
@@ -18,6 +18,7 @@ All the arguments that can be set by users are:
 7. `--vllm-start-port`: Starting port of vLLM serving instances, default 9000

 An example of running external DP on a single node:

 ```bash
 cd examples/external_online_dp
 # running DP4 TP4 in a node with 16 NPUs
@@ -25,6 +26,7 @@ python launch_online_dp.py --dp-size 4 --tp-size 4 --dp-size-local 4 --dp-rank-s
 ```

 An example of running external DP across two nodes:

 ```bash
 cd examples/external_online_dp
 # running DP4 TP4 in two nodes with 8 NPUs each

@@ -23,7 +23,7 @@
 # ----------------------------------
 # You need to have at least two vLLM servers running in data parallel.
 # These can be mock servers or actual vLLM servers.
 # Note that this proxy also works with only one vLLM server running, but
 # will fall back to direct request forwarding which is meaningless.
 #
 # For testing, you can use the provided mock server:
@@ -84,13 +84,12 @@ import argparse
 import asyncio
 import functools
 import heapq
 import json
 import os
 import sys
 import uuid
 from contextlib import asynccontextmanager
 from dataclasses import dataclass
-from typing import Any, List
+from typing import Any

 import httpx
 from fastapi import FastAPI, Request
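One substantive-looking line in this hunk is `from typing import Any, List` shrinking to `from typing import Any`; that works because built-in generics (PEP 585, Python 3.9+) replace `typing.List`. A minimal illustration (the variable is invented):

```python
# Before (typing-module generics; this is what the old import pulled in):
# from typing import List
# servers: List[str] = ["host-a", "host-b"]

# After (PEP 585 built-in generics), which is why only Any still needs
# importing from typing:
servers: list[str] = ["host-a", "host-b"]
print(servers)
```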
@@ -109,34 +108,29 @@ except ImportError:


 class ServerState:
-
     def __init__(self, host, port):
         self.host = host
         self.port = port
-        self.url = f'http://{host}:{port}/v1'
-        self.client = httpx.AsyncClient(timeout=None,
-                                        base_url=self.url,
-                                        limits=httpx.Limits(
-                                            max_connections=100000,
-                                            max_keepalive_connections=100000))
+        self.url = f"http://{host}:{port}/v1"
+        self.client = httpx.AsyncClient(
+            timeout=None,
+            base_url=self.url,
+            limits=httpx.Limits(max_connections=100000, max_keepalive_connections=100000),
+        )
         self.active_tokens = 0
         self.aborted_requests = set()  # Track aborted requests


 class ProxyState:
-
     def __init__(self, server_instances):
-        self.dp_servers: List[ServerState] = [
-            ServerState(h, p) for h, p in server_instances
-        ]
+        self.dp_servers: list[ServerState] = [ServerState(h, p) for h, p in server_instances]
         self.req_id_lock = asyncio.Lock()
         # Removed selection locks - no longer needed for synchronous methods

         # Initialize priority queues for efficient server selection
         # Each entry is (priority_score, server_index, server_reference)
         # Lower priority score = higher priority (less loaded)
-        self.lb_heap = [(0, i, server)
-                        for i, server in enumerate(self.dp_servers)]
+        self.lb_heap = [(0, i, server) for i, server in enumerate(self.dp_servers)]
         heapq.heapify(self.lb_heap)

     def _update_server_priority(self, server_idx: int):
@@ -144,10 +138,8 @@ class ProxyState:
         server = self.dp_servers[server_idx]
         priority = server.active_tokens
         # Remove old entry and add new one
-        self.lb_heap = [(p, i, s) for p, i, s in self.lb_heap
-                        if i != server_idx]
-        heapq.heappush(self.lb_heap,
-                       (priority, server_idx, server))  # type: ignore
+        self.lb_heap = [(p, i, s) for p, i, s in self.lb_heap if i != server_idx]
+        heapq.heappush(self.lb_heap, (priority, server_idx, server))  # type: ignore

     async def next_req_id(self):
         async with self.req_id_lock:
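The `lb_heap` touched by these two hunks implements least-loaded server selection with a min-heap of `(load, index, server)` tuples. A self-contained sketch of that idea, using a hypothetical `FakeServer` in place of the proxy's real `ServerState`, and a simpler pop/push update than the list-rebuild the proxy uses to replace a specific server's entry:

```python
import heapq
from dataclasses import dataclass


@dataclass
class FakeServer:
    # Hypothetical stand-in for the proxy's ServerState.
    name: str
    active_tokens: int = 0


servers = [FakeServer("a"), FakeServer("b"), FakeServer("c")]
# Heap entries are (load, index, server); the index breaks ties so the
# tuple comparison never falls through to the server objects themselves.
heap = [(0, i, s) for i, s in enumerate(servers)]
heapq.heapify(heap)


def pick_least_loaded(cost: int) -> FakeServer:
    # Pop the least-loaded server, charge it this request's cost, and
    # push it back with its updated priority.
    load, i, server = heapq.heappop(heap)
    server.active_tokens += cost
    heapq.heappush(heap, (server.active_tokens, i, server))
    return server


print(pick_least_loaded(10).name)  # "a"
print(pick_least_loaded(5).name)   # "b" (now the least loaded)
```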
@@ -190,27 +182,15 @@ def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--port", type=int, default=8000)
     parser.add_argument("--host", type=str, default="localhost")
-    parser.add_argument("--dp-hosts",
-                        type=str,
-                        nargs="+",
-                        default=["localhost"])
-    parser.add_argument("--dp-ports",
-                        type=int,
-                        nargs="+",
-                        default=[8001])
-    parser.add_argument("--max-retries",
-                        type=int,
-                        default=3,
-                        help="Maximum number of retries for HTTP requests")
+    parser.add_argument("--dp-hosts", type=str, nargs="+", default=["localhost"])
+    parser.add_argument("--dp-ports", type=int, nargs="+", default=[8001])
+    parser.add_argument("--max-retries", type=int, default=3, help="Maximum number of retries for HTTP requests")
     parser.add_argument(
-        "--retry-delay",
-        type=float,
-        default=0.001,
-        help="Base delay (seconds) for exponential backoff retries")
+        "--retry-delay", type=float, default=0.001, help="Base delay (seconds) for exponential backoff retries"
+    )
     args = parser.parse_args()
     if len(args.dp_hosts) != len(args.dp_ports):
-        raise ValueError(
-            "Number of dp hosts must match number of dp ports")
+        raise ValueError("Number of dp hosts must match number of dp ports")
     args.server_instances = list(zip(args.dp_hosts, args.dp_ports))
     return args

@@ -219,9 +199,7 @@ def parse_args():
 async def lifespan(app: FastAPI):
     global proxy_state
     proxy_state = ProxyState(global_args.server_instances)
-    print(
-        f"Initialized {len(proxy_state.dp_servers)} dp server clients."
-    )
+    print(f"Initialized {len(proxy_state.dp_servers)} dp server clients.")
     yield
     for p in proxy_state.dp_servers:
         await p.client.aclose()
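The `lifespan` function in this hunk follows FastAPI's standard lifespan pattern: everything before the `yield` runs at startup, everything after it at shutdown. A minimal generic sketch of that pattern (the print statements stand in for the proxy's real client setup and `aclose()` teardown):

```python
import contextlib

from fastapi import FastAPI


@contextlib.asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: build shared state (the proxy creates its httpx clients here).
    print("starting up")
    yield
    # Shutdown: release it (the proxy closes each client here).
    print("shutting down")


app = FastAPI(lifespan=lifespan)
```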
@@ -236,14 +214,12 @@ async def listen_for_disconnect(request: Request) -> None:


 def with_cancellation(handler_func):
-
     @functools.wraps(handler_func)
     async def wrapper(*args, **kwargs):
         request = kwargs["request"]
         handler_task = asyncio.create_task(handler_func(*args, **kwargs))
         cancellation_task = asyncio.create_task(listen_for_disconnect(request))
-        done, pending = await asyncio.wait([handler_task, cancellation_task],
-                                           return_when=asyncio.FIRST_COMPLETED)
+        done, pending = await asyncio.wait([handler_task, cancellation_task], return_when=asyncio.FIRST_COMPLETED)
         for task in pending:
             task.cancel()
         if handler_task in done:
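The decorator reformatted here races the request handler against a disconnect watcher via `asyncio.wait(..., return_when=asyncio.FIRST_COMPLETED)` and cancels whichever task loses. A runnable toy version of that race, with sleeps standing in for real work:

```python
import asyncio


async def handler():
    await asyncio.sleep(2)
    return "response"


async def watch_disconnect():
    # Hypothetical stand-in for listen_for_disconnect(request).
    await asyncio.sleep(1)


async def main():
    handler_task = asyncio.create_task(handler())
    cancellation_task = asyncio.create_task(watch_disconnect())
    done, pending = await asyncio.wait([handler_task, cancellation_task], return_when=asyncio.FIRST_COMPLETED)
    for task in pending:
        task.cancel()  # whichever task lost the race gets cancelled
    if handler_task in done:
        print(handler_task.result())
    else:
        print("client went away first; handler cancelled")


asyncio.run(main())
```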
@@ -256,22 +232,18 @@ def with_cancellation(handler_func):
 app = FastAPI(lifespan=lifespan)


-async def stream_service_response_with_retry(client: httpx.AsyncClient,
-                                             endpoint: str,
-                                             req_data: dict,
-                                             request_id: str,
-                                             max_retries: int = 3,
-                                             base_delay: float = 0.2):
-    headers = {
-        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
-        "X-Request-Id": request_id
-    }
+async def stream_service_response_with_retry(
+    client: httpx.AsyncClient,
+    endpoint: str,
+    req_data: dict,
+    request_id: str,
+    max_retries: int = 3,
+    base_delay: float = 0.2,
+):
+    headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", "X-Request-Id": request_id}
     for attempt in range(1, max_retries + 1):
         try:
-            async with client.stream("POST",
-                                     endpoint,
-                                     json=req_data,
-                                     headers=headers) as response:
+            async with client.stream("POST", endpoint, json=req_data, headers=headers) as response:
                 response.raise_for_status()
                 first_chunk_sent = False
                 async for chunk in response.aiter_bytes():
@@ -280,53 +252,42 @@ async def stream_service_response_with_retry(client: httpx.AsyncClient,
                 return  # Success, exit after streaming
         except (httpx.RequestError, httpx.HTTPStatusError) as e:
             if attempt < max_retries:
-                logger.warning(
-                    f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}"
-                )
-                await asyncio.sleep(base_delay * (2**(attempt - 1)))
+                logger.warning(f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}")
+                await asyncio.sleep(base_delay * (2 ** (attempt - 1)))
             else:
-                logger.error(
-                    f"All {max_retries} attempts failed for streaming {endpoint}."
-                )
+                logger.error(f"All {max_retries} attempts failed for streaming {endpoint}.")
                 raise e
         except Exception as e:
             # If any chunk has been sent, do not retry, just log and drop
-            if 'first_chunk_sent' in locals() and first_chunk_sent:
-                logger.error(
-                    f"Streaming to client interrupted after response started: {str(e)}"
-                )
+            if "first_chunk_sent" in locals() and first_chunk_sent:
+                logger.error(f"Streaming to client interrupted after response started: {str(e)}")
                 return
             else:
                 if attempt < max_retries:
-                    logger.warning(
-                        f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}"
-                    )
-                    await asyncio.sleep(base_delay * (2**(attempt - 1)))
+                    logger.warning(f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}")
+                    await asyncio.sleep(base_delay * (2 ** (attempt - 1)))
                 else:
-                    logger.error(
-                        f"All {max_retries} attempts failed for streaming {endpoint}."
-                    )
+                    logger.error(f"All {max_retries} attempts failed for streaming {endpoint}.")
                     raise e


-async def _select_instance(api: str, req_data: Any,
-                           request_length: int):
+async def _select_instance(api: str, req_data: Any, request_length: int):
     # refer to vLLM sampling_params: max_token default value
     max_tokens = req_data.get("max_tokens", 16)
     ignore_eos = req_data.get("ignore_eos", False)
-    priority_score = proxy_state.calculate_request_score(request_length,max_tokens=max_tokens, ignore_eos=ignore_eos)
+    priority_score = proxy_state.calculate_request_score(request_length, max_tokens=max_tokens, ignore_eos=ignore_eos)
     logger.debug(
-        f"Request length: {request_length}, max tokens: {max_tokens}, ignore_eos: {ignore_eos}, Priority score: {priority_score}"
+        f"Request length: {request_length}, max tokens: {max_tokens}, "
+        f"ignore_eos: {ignore_eos}, Priority score: {priority_score}"
     )
     request_id = await proxy_state.next_req_id()
     # Select dp server based on priority score
     server_idx = proxy_state.select_server(priority_score)
     choosen_server = proxy_state.dp_servers[server_idx]
     logger.debug(f"Choose server {choosen_server.url} to process request {request_id}")
-    return InstanceInfo(request_id=request_id,
-                        server_idx=server_idx,
-                        priority_score=priority_score,
-                        server_state=choosen_server)
+    return InstanceInfo(
+        request_id=request_id, server_idx=server_idx, priority_score=priority_score, server_state=choosen_server
+    )


 @dataclass
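The one behavioural subtlety worth spelling out in the retry loop above is the exponential backoff `base_delay * (2 ** (attempt - 1))`. Worked out with the function's default `base_delay=0.2`:

```python
# Delays produced by base_delay * 2**(attempt - 1) for the default base_delay.
base_delay = 0.2
for attempt in range(1, 4):
    print(f"attempt {attempt}: sleep {base_delay * (2 ** (attempt - 1))}s")
# attempt 1: sleep 0.2s
# attempt 2: sleep 0.4s
# attempt 3: sleep 0.8s
```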
@@ -342,36 +303,36 @@ async def _handle_completions(api: str, request: Request):
         req_data = await request.json()
         req_body = await request.body()
         request_length = len(req_body)
-        instance_info = await _select_instance(api, req_data,
-                                               request_length)
+        instance_info = await _select_instance(api, req_data, request_length)

         async def generate_stream():
             nonlocal instance_info
             # Only one await per chunk, minimal logic in loop
             try:
                 async for chunk in stream_service_response_with_retry(
-                        instance_info.server_state.client,
-                        api,
-                        req_data,
-                        request_id=instance_info.request_id,
-                        max_retries=global_args.max_retries,
-                        base_delay=global_args.retry_delay):
+                    instance_info.server_state.client,
+                    api,
+                    req_data,
+                    request_id=instance_info.request_id,
+                    max_retries=global_args.max_retries,
+                    base_delay=global_args.retry_delay,
+                ):
                     yield chunk
             except Exception as e:
                 logger.error(
-                    f"Error during streaming from server {instance_info.server_state.url}: {str(e)}, the aborted request is: {instance_info.request_id}."
+                    f"Error during streaming from server {instance_info.server_state.url}: {str(e)}, "
+                    f"the aborted request is: {instance_info.request_id}."
                 )

             # After streaming done, release tokens
-            proxy_state.release_server(instance_info.server_idx,
-                                       instance_info.priority_score)
+            proxy_state.release_server(instance_info.server_idx, instance_info.priority_score)

-        return StreamingResponse(generate_stream(),
-                                 media_type="application/json")
+        return StreamingResponse(generate_stream(), media_type="application/json")
     except Exception as e:
         import traceback

         exc_info = sys.exc_info()
-        print("Error occurred in external dp proxy server"
-              f" - {api} endpoint")
+        print(f"Error occurred in external dp proxy server - {api} endpoint")
         print(e)
         print("".join(traceback.format_exception(*exc_info)))
         raise
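`_handle_completions` wraps an async generator in `StreamingResponse` so chunks flow to the client as they arrive from the chosen server. A stripped-down sketch of just that pattern (the `/demo` route and fake chunk source are invented; the proxy instead relays bytes from the upstream vLLM server and releases heap tokens when the stream ends):

```python
import asyncio

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


@app.get("/demo")
async def demo():
    async def generate_stream():
        # Fake chunk source; each yield is flushed to the client immediately.
        for i in range(3):
            await asyncio.sleep(0.1)
            yield f"chunk-{i}\n".encode()

    return StreamingResponse(generate_stream(), media_type="text/plain")
```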
@@ -397,7 +358,7 @@ async def healthcheck():
     }


-if __name__ == '__main__':
+if __name__ == "__main__":
     global global_args
     global_args = parse_args()
     import uvicorn

@@ -4,52 +4,19 @@ import os
 import subprocess
 import sys


 def parse_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--dp-size",
-        type=int,
-        required=True,
-        help="Data parallel size."
-    )
-    parser.add_argument(
-        "--tp-size",
-        type=int,
-        default=1,
-        help="Tensor parallel size."
-    )
-    parser.add_argument(
-        "--dp-size-local",
-        type=int,
-        default=-1,
-        help="Local data parallel size."
-    )
-    parser.add_argument(
-        "--dp-rank-start",
-        type=int,
-        default=0,
-        help="Starting rank for data parallel."
-    )
-    parser.add_argument(
-        "--dp-address",
-        type=str,
-        required=True,
-        help="IP address for data parallel master node."
-    )
-    parser.add_argument(
-        "--dp-rpc-port",
-        type=str,
-        default=12345,
-        help="Port for data parallel master node."
-    )
-    parser.add_argument(
-        "--vllm-start-port",
-        type=int,
-        default=9000,
-        help="Starting port for the engine."
-    )
+    parser.add_argument("--dp-size", type=int, required=True, help="Data parallel size.")
+    parser.add_argument("--tp-size", type=int, default=1, help="Tensor parallel size.")
+    parser.add_argument("--dp-size-local", type=int, default=-1, help="Local data parallel size.")
+    parser.add_argument("--dp-rank-start", type=int, default=0, help="Starting rank for data parallel.")
+    parser.add_argument("--dp-address", type=str, required=True, help="IP address for data parallel master node.")
+    parser.add_argument("--dp-rpc-port", type=str, default=12345, help="Port for data parallel master node.")
+    parser.add_argument("--vllm-start-port", type=int, default=9000, help="Starting port for the engine.")
     return parser.parse_args()


 args = parse_args()
 dp_size = args.dp_size
 tp_size = args.tp_size
@@ -61,6 +28,7 @@ dp_address = args.dp_address
 dp_rpc_port = args.dp_rpc_port
 vllm_start_port = args.vllm_start_port


 def run_command(visiable_devices, dp_rank, vllm_engine_port):
     command = [
         "bash",
@@ -75,6 +43,7 @@ def run_command(visiable_devices, dp_rank, vllm_engine_port):
     ]
     subprocess.run(command, check=True)


 if __name__ == "__main__":
     template_path = "./run_dp_template.sh"
     if not os.path.exists(template_path):
@@ -87,11 +56,9 @@ if __name__ == "__main__":
         dp_rank = dp_rank_start + i
         vllm_engine_port = vllm_start_port + i
         visiable_devices = ",".join(str(x) for x in range(i * tp_size, (i + 1) * tp_size))
-        process = multiprocessing.Process(target=run_command,
-                                          args=(visiable_devices, dp_rank,
-                                                vllm_engine_port))
+        process = multiprocessing.Process(target=run_command, args=(visiable_devices, dp_rank, vllm_engine_port))
         processes.append(process)
         process.start()

     for process in processes:
         process.join()
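The `visiable_devices` line in this last hunk assigns each local DP replica a contiguous slice of NPUs. A quick worked example with `tp_size=4`, matching the README's single-node DP4 TP4 case:

```python
# Each local DP replica i gets devices [i * tp_size, (i + 1) * tp_size).
tp_size = 4
for i in range(2):  # first two local DP replicas
    visiable_devices = ",".join(str(x) for x in range(i * tp_size, (i + 1) * tp_size))
    print(i, visiable_devices)
# 0 -> "0,1,2,3"
# 1 -> "4,5,6,7"
```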