[Lint]Style: Convert example to ruff format (#5863)

### What this PR does / why we need it?
This PR fixes linting issues in the `example/` to align with the
project's Ruff configuration.

- vLLM version: v0.13.0
- vLLM main:
bde38c11df

Signed-off-by: root <root@LAPTOP-VQKDDVMG.localdomain>
Co-authored-by: root <root@LAPTOP-VQKDDVMG.localdomain>
This commit is contained in:
SILONG ZENG
2026-01-13 20:46:50 +08:00
committed by GitHub
parent f7b904641e
commit 78d5ce3e01
23 changed files with 678 additions and 1037 deletions

View File

@@ -91,10 +91,8 @@ import heapq
import ipaddress
import os
import sys
import threading
import uuid
from contextlib import asynccontextmanager
from typing import List
import httpx
from fastapi import FastAPI, Request
@@ -106,28 +104,28 @@ logger = init_logger(__name__)
# Add uvloop for faster event loop if available
try:
import uvloop
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
except ImportError:
pass
class ServerState:
def __init__(self, host, port):
self.host = host
self.port = port
self.url = f'http://{host}:{port}/v1'
self.url = f"http://{host}:{port}/v1"
try:
ip = ipaddress.ip_address(self.host)
if isinstance(ip, ipaddress.IPv6Address):
self.url = f'http://[{host}]:{port}/v1'
self.url = f"http://[{host}]:{port}/v1"
except Exception:
pass
self.client = httpx.AsyncClient(timeout=None,
base_url=self.url,
limits=httpx.Limits(
max_connections=100000,
max_keepalive_connections=100000))
self.client = httpx.AsyncClient(
timeout=None,
base_url=self.url,
limits=httpx.Limits(max_connections=100000, max_keepalive_connections=100000),
)
self.active_tokens = 0
self.active_kv_cache = 0 # Only for prefiller
self.active_requests = 0 # Number of active requests
@@ -136,14 +134,9 @@ class ServerState:
class ProxyState:
def __init__(self, prefiller_instances, decoder_instances):
self.prefillers: List[ServerState] = [
ServerState(h, p) for h, p in prefiller_instances
]
self.decoders: List[ServerState] = [
ServerState(h, p) for h, p in decoder_instances
]
self.prefillers: list[ServerState] = [ServerState(h, p) for h, p in prefiller_instances]
self.decoders: list[ServerState] = [ServerState(h, p) for h, p in decoder_instances]
self.req_to_prefiller = {}
self.req_id_lock = asyncio.Lock()
# Removed selection locks - no longer needed for synchronous methods
@@ -151,10 +144,8 @@ class ProxyState:
# Initialize priority queues for efficient server selection
# Each entry is (priority_score, server_index, server_reference)
# Lower priority score = higher priority (less loaded)
self.prefiller_heap = [(0, i, server)
for i, server in enumerate(self.prefillers)]
self.decoder_heap = [(0, i, server)
for i, server in enumerate(self.decoders)]
self.prefiller_heap = [(0, i, server) for i, server in enumerate(self.prefillers)]
self.decoder_heap = [(0, i, server) for i, server in enumerate(self.decoders)]
heapq.heapify(self.prefiller_heap)
heapq.heapify(self.decoder_heap)
self.req_id_future = {}
@@ -166,23 +157,18 @@ class ProxyState:
# Priority based on active_tokens and active_kv_cache
priority = server.active_tokens + server.active_kv_cache * 0.3
# Remove old entry and add new one
self.prefiller_heap = [(p, i, s) for p, i, s in self.prefiller_heap
if i != server_idx]
heapq.heappush(self.prefiller_heap,
(priority, server_idx, server)) # type: ignore
self.prefiller_heap = [(p, i, s) for p, i, s in self.prefiller_heap if i != server_idx]
heapq.heappush(self.prefiller_heap, (priority, server_idx, server)) # type: ignore
def _update_decoder_priority(self, server_idx: int):
"""Update the priority of a decoder server in the heap."""
server = self.decoders[server_idx]
priority = server.active_tokens
# Remove old entry and add new one
self.decoder_heap = [(p, i, s) for p, i, s in self.decoder_heap
if i != server_idx]
heapq.heappush(self.decoder_heap,
(priority, server_idx, server)) # type: ignore
self.decoder_heap = [(p, i, s) for p, i, s in self.decoder_heap if i != server_idx]
heapq.heappush(self.decoder_heap, (priority, server_idx, server)) # type: ignore
def abort_prefiller_request(self, server_idx: int,
request_id): # Changed to synchronous
def abort_prefiller_request(self, server_idx: int, request_id): # Changed to synchronous
"""
Mark a request as aborted. This will helps to release kv cache in
prefiller node.
@@ -190,8 +176,7 @@ class ProxyState:
# No lock needed - atomic operation
self.prefillers[server_idx].aborted_requests.add(request_id)
def aquire_aborted_prefiller_requests(
self, server_idx: int): # Changed to synchronous
def aquire_aborted_prefiller_requests(self, server_idx: int): # Changed to synchronous
"""
Get the set of aborted requests and clear it.
This is used to release kv cache in prefiller node.
@@ -272,37 +257,20 @@ def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--prefiller-hosts",
type=str,
nargs="+",
default=["localhost"])
parser.add_argument("--prefiller-ports",
type=int,
nargs="+",
default=[8001])
parser.add_argument("--decoder-hosts",
type=str,
nargs="+",
default=["localhost"])
parser.add_argument("--prefiller-hosts", type=str, nargs="+", default=["localhost"])
parser.add_argument("--prefiller-ports", type=int, nargs="+", default=[8001])
parser.add_argument("--decoder-hosts", type=str, nargs="+", default=["localhost"])
parser.add_argument("--decoder-ports", type=int, nargs="+", default=[8002])
parser.add_argument("--max-retries",
type=int,
default=3,
help="Maximum number of retries for HTTP requests")
parser.add_argument("--max-retries", type=int, default=3, help="Maximum number of retries for HTTP requests")
parser.add_argument(
"--retry-delay",
type=float,
default=0.001,
help="Base delay (seconds) for exponential backoff retries")
"--retry-delay", type=float, default=0.001, help="Base delay (seconds) for exponential backoff retries"
)
args = parser.parse_args()
if len(args.prefiller_hosts) != len(args.prefiller_ports):
raise ValueError(
"Number of prefiller hosts must match number of prefiller ports")
raise ValueError("Number of prefiller hosts must match number of prefiller ports")
if len(args.decoder_hosts) != len(args.decoder_ports):
raise ValueError(
"Number of decoder hosts must match number of decoder ports")
args.prefiller_instances = list(
zip(args.prefiller_hosts, args.prefiller_ports))
raise ValueError("Number of decoder hosts must match number of decoder ports")
args.prefiller_instances = list(zip(args.prefiller_hosts, args.prefiller_ports))
args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports))
return args
@@ -310,11 +278,8 @@ def parse_args():
@asynccontextmanager
async def lifespan(app: FastAPI):
global proxy_state
proxy_state = ProxyState(global_args.prefiller_instances,
global_args.decoder_instances)
print(
f"Initialized {len(proxy_state.prefillers)} prefill clients and {len(proxy_state.decoders)} decode clients."
)
proxy_state = ProxyState(global_args.prefiller_instances, global_args.decoder_instances)
print(f"Initialized {len(proxy_state.prefillers)} prefill clients and {len(proxy_state.decoders)} decode clients.")
yield
for p in proxy_state.prefillers:
await p.client.aclose()
@@ -331,14 +296,12 @@ async def listen_for_disconnect(request: Request) -> None:
def with_cancellation(handler_func):
@functools.wraps(handler_func)
async def wrapper(*args, **kwargs):
request = kwargs["request"]
handler_task = asyncio.create_task(handler_func(*args, **kwargs))
cancellation_task = asyncio.create_task(listen_for_disconnect(request))
done, pending = await asyncio.wait([handler_task, cancellation_task],
return_when=asyncio.FIRST_COMPLETED)
done, pending = await asyncio.wait([handler_task, cancellation_task], return_when=asyncio.FIRST_COMPLETED)
for task in pending:
task.cancel()
if handler_task in done:
@@ -351,15 +314,16 @@ def with_cancellation(handler_func):
app = FastAPI(lifespan=lifespan)
async def send_request_to_service(client: httpx.AsyncClient,
prefiller_id: int,
endpoint: str,
req_data: dict,
request_id: str,
max_retries: int = 3,
base_delay: float = 0.2):
aborted_requests = proxy_state.aquire_aborted_prefiller_requests(
prefiller_id)
async def send_request_to_service(
client: httpx.AsyncClient,
prefiller_id: int,
endpoint: str,
req_data: dict,
request_id: str,
max_retries: int = 3,
base_delay: float = 0.2,
):
proxy_state.aquire_aborted_prefiller_requests(prefiller_id)
req_data = req_data.copy()
req_data["stream"] = False
req_data["max_tokens"] = 1
@@ -368,49 +332,38 @@ async def send_request_to_service(client: httpx.AsyncClient,
req_data["max_completion_tokens"] = 1
if "stream_options" in req_data:
del req_data["stream_options"]
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"X-Request-Id": request_id
}
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", "X-Request-Id": request_id}
last_exc = None
for attempt in range(1, max_retries + 1):
try:
response = await client.post(endpoint,
json=req_data,
headers=headers)
response = await client.post(endpoint, json=req_data, headers=headers)
response.raise_for_status()
if request_id in proxy_state.req_id_future:
result_future = proxy_state.req_id_future[request_id]
result_future.set_result(response.json()["kv_transfer_params"])
return
except (httpx.RequestError, httpx.HTTPStatusError) as e:
logger.warning(
f"Attempt {attempt} failed for {endpoint}: {str(e)}")
logger.warning(f"Attempt {attempt} failed for {endpoint}: {str(e)}")
last_exc = e
if attempt < max_retries:
await asyncio.sleep(base_delay * (2**(attempt - 1)))
await asyncio.sleep(base_delay * (2 ** (attempt - 1)))
else:
logger.error(
f"All {max_retries} attempts failed for {endpoint}.")
logger.error(f"All {max_retries} attempts failed for {endpoint}.")
raise last_exc
async def stream_service_response_with_retry(client: httpx.AsyncClient,
endpoint: str,
req_data: dict,
request_id: str,
max_retries: int = 3,
base_delay: float = 0.2):
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"X-Request-Id": request_id
}
async def stream_service_response_with_retry(
client: httpx.AsyncClient,
endpoint: str,
req_data: dict,
request_id: str,
max_retries: int = 3,
base_delay: float = 0.2,
):
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", "X-Request-Id": request_id}
for attempt in range(1, max_retries + 1):
try:
async with client.stream("POST",
endpoint,
json=req_data,
headers=headers) as response:
async with client.stream("POST", endpoint, json=req_data, headers=headers) as response:
response.raise_for_status()
first_chunk_sent = False
async for chunk in response.aiter_bytes():
@@ -419,32 +372,22 @@ async def stream_service_response_with_retry(client: httpx.AsyncClient,
return # Success, exit after streaming
except (httpx.RequestError, httpx.HTTPStatusError) as e:
if attempt < max_retries:
logger.warning(
f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}"
)
await asyncio.sleep(base_delay * (2**(attempt - 1)))
logger.warning(f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}")
await asyncio.sleep(base_delay * (2 ** (attempt - 1)))
else:
logger.error(
f"All {max_retries} attempts failed for streaming {endpoint}."
)
logger.error(f"All {max_retries} attempts failed for streaming {endpoint}.")
raise e
except Exception as e:
# If any chunk has been sent, do not retry, just log and drop
if 'first_chunk_sent' in locals() and first_chunk_sent:
logger.error(
f"Streaming to client interrupted after response started: {str(e)}"
)
if "first_chunk_sent" in locals() and first_chunk_sent:
logger.error(f"Streaming to client interrupted after response started: {str(e)}")
return
else:
if attempt < max_retries:
logger.warning(
f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}"
)
await asyncio.sleep(base_delay * (2**(attempt - 1)))
logger.warning(f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}")
await asyncio.sleep(base_delay * (2 ** (attempt - 1)))
else:
logger.error(
f"All {max_retries} attempts failed for streaming {endpoint}."
)
logger.error(f"All {max_retries} attempts failed for streaming {endpoint}.")
raise e
@@ -469,15 +412,11 @@ async def _handle_completions(api: str, request: Request):
request_length = len(req_body)
request_id = await proxy_state.next_req_id()
request_id_api = get_api_request_id(api, request_id)
proxy_state.req_data_dict[request_id_api] = (req_data, request_length,
api)
req_data['kv_transfer_params'] = {
"do_remote_decode":
False,
"do_remote_prefill":
True,
"metaserver":
f"http://{global_args.host}:{global_args.port}/v1/metaserver"
proxy_state.req_data_dict[request_id_api] = (req_data, request_length, api)
req_data["kv_transfer_params"] = {
"do_remote_decode": False,
"do_remote_prefill": True,
"metaserver": f"http://{global_args.host}:{global_args.port}/v1/metaserver",
}
# Select decoder
decoder_score = proxy_state.calculate_decode_scores(request_length)
@@ -494,28 +433,30 @@ async def _handle_completions(api: str, request: Request):
# Only one await per chunk, minimal logic in loop
try:
async for chunk in stream_service_response_with_retry(
decoder.client,
api,
req_data,
request_id=request_id,
max_retries=global_args.max_retries,
base_delay=global_args.retry_delay):
decoder.client,
api,
req_data,
request_id=request_id,
max_retries=global_args.max_retries,
base_delay=global_args.retry_delay,
):
yield chunk
except Exception as e:
logger.error(
f"Error during streaming from decoder {decoder.url}: {str(e)} the aborted request {request_id} will be routing to the target prefiller when new request is ready to dispatch to it"
f"Error during streaming from decoder {decoder.url}: {str(e)} "
f"the aborted request {request_id} will be routing to the target "
"prefiller when new request is ready to dispatch to it"
)
# After streaming done, release tokens
proxy_state.release_decoder(decoder_idx, decoder_score)
return StreamingResponse(generate_stream(),
media_type="application/json")
return StreamingResponse(generate_stream(), media_type="application/json")
except Exception as e:
import traceback
exc_info = sys.exc_info()
print("Error occurred in disagg prefill proxy server"
f" - {api} endpoint")
print(f"Error occurred in disagg prefill proxy server - {api} endpoint")
print(e)
print("".join(traceback.format_exception(*exc_info)))
raise
@@ -538,7 +479,7 @@ async def healthcheck():
return {
"status": "ok",
"prefill_instances": len(proxy_state.prefillers),
"decode_instances": len(proxy_state.decoders)
"decode_instances": len(proxy_state.decoders),
}
@@ -553,25 +494,24 @@ async def metaserver(request: Request):
request_id = get_origin_request_id(api, request_id)
req_data["kv_transfer_params"] = kv_transfer_params
prefiller_score = proxy_state.calculate_prefill_scores(request_length)
logger.debug(
f"Request length: {request_length}, Prefiller score: {prefiller_score}"
)
logger.debug(f"Request length: {request_length}, Prefiller score: {prefiller_score}")
# Select prefiller
prefiller_idx = proxy_state.select_prefiller(prefiller_score)
prefiller = proxy_state.prefillers[prefiller_idx]
logger.debug(f"Using prefill {prefiller.url=} {req_data=}")
# Send request to prefiller
response = await send_request_to_service(
await send_request_to_service(
prefiller.client,
prefiller_idx,
api,
req_data,
request_id,
max_retries=global_args.max_retries,
base_delay=global_args.retry_delay)
base_delay=global_args.retry_delay,
)
proxy_state.release_prefiller(prefiller_idx, prefiller_score)
proxy_state.release_prefiller_kv(prefiller_idx,prefiller_score)
proxy_state.release_prefiller_kv(prefiller_idx, prefiller_score)
except Exception as e:
logger.error(f"Post metaserver failed with: {str(e)}")
@@ -579,8 +519,9 @@ async def metaserver(request: Request):
proxy_state.release_prefiller_kv(prefiller_idx, prefiller_score)
if __name__ == '__main__':
if __name__ == "__main__":
global global_args
global_args = parse_args()
import uvicorn
uvicorn.run(app, host=global_args.host, port=global_args.port)

View File

@@ -125,7 +125,7 @@ import time
import uuid
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, List, Tuple, Dict
from typing import Any
import httpx
from fastapi import FastAPI, Request
@@ -150,22 +150,21 @@ class InstanceType:
class ServerState:
def __init__(self, host, port):
self.host = host
self.port = port
self.url = f'http://{host}:{port}/v1'
self.url = f"http://{host}:{port}/v1"
try:
ip = ipaddress.ip_address(self.host)
if isinstance(ip, ipaddress.IPv6Address):
self.url = f'http://[{host}]:{port}/v1'
self.url = f"http://[{host}]:{port}/v1"
except Exception:
pass
self.client = httpx.AsyncClient(timeout=None,
base_url=self.url,
limits=httpx.Limits(
max_connections=100000,
max_keepalive_connections=100000))
self.client = httpx.AsyncClient(
timeout=None,
base_url=self.url,
limits=httpx.Limits(max_connections=100000, max_keepalive_connections=100000),
)
self.active_tokens = 0
self.active_kv_cache = 0 # Only for prefiller
self.active_requests = 0 # Number of active requests
@@ -186,16 +185,11 @@ class ServerState:
class ProxyState:
def __init__(self, prefiller_instances, decoder_instances):
self.node_listener = NodeListener(self)
self.prefillers: List[ServerState] = [
ServerState(h, p) for h, p in prefiller_instances
]
self.decoders: List[ServerState] = [
ServerState(h, p) for h, p in decoder_instances
]
self.prefillers: list[ServerState] = [ServerState(h, p) for h, p in prefiller_instances]
self.decoders: list[ServerState] = [ServerState(h, p) for h, p in decoder_instances]
self.req_to_prefiller = {}
self.req_id_lock = asyncio.Lock()
# Removed selection locks - no longer needed for synchronous methods
@@ -203,10 +197,8 @@ class ProxyState:
# Initialize priority queues for efficient server selection
# Each entry is (priority_score, server_index, server_reference)
# Lower priority score = higher priority (less loaded)
self.prefiller_heap = [(0, i, server)
for i, server in enumerate(self.prefillers)]
self.decoder_heap = [(0, i, server)
for i, server in enumerate(self.decoders)]
self.prefiller_heap = [(0, i, server) for i, server in enumerate(self.prefillers)]
self.decoder_heap = [(0, i, server) for i, server in enumerate(self.decoders)]
heapq.heapify(self.prefiller_heap)
heapq.heapify(self.decoder_heap)
@@ -216,23 +208,18 @@ class ProxyState:
# Priority based on active_tokens and active_kv_cache
priority = server.active_tokens + server.active_kv_cache * 0.3
# Remove old entry and add new one
self.prefiller_heap = [(p, i, s) for p, i, s in self.prefiller_heap
if i != server_idx]
heapq.heappush(self.prefiller_heap,
(priority, server_idx, server)) # type: ignore
self.prefiller_heap = [(p, i, s) for p, i, s in self.prefiller_heap if i != server_idx]
heapq.heappush(self.prefiller_heap, (priority, server_idx, server)) # type: ignore
def _update_decoder_priority(self, server_idx: int):
"""Update the priority of a decoder server in the heap."""
server = self.decoders[server_idx]
priority = server.active_tokens
# Remove old entry and add new one
self.decoder_heap = [(p, i, s) for p, i, s in self.decoder_heap
if i != server_idx]
heapq.heappush(self.decoder_heap,
(priority, server_idx, server)) # type: ignore
self.decoder_heap = [(p, i, s) for p, i, s in self.decoder_heap if i != server_idx]
heapq.heappush(self.decoder_heap, (priority, server_idx, server)) # type: ignore
def abort_prefiller_request(self, server_idx: int,
request_id): # Changed to synchronous
def abort_prefiller_request(self, server_idx: int, request_id): # Changed to synchronous
"""
Mark a request as aborted. This will helps to release kv cache in
prefiller node.
@@ -240,8 +227,7 @@ class ProxyState:
# No lock needed - atomic operation
self.prefillers[server_idx].aborted_requests.add(request_id)
def aquire_aborted_prefiller_requests(
self, server_idx: int): # Changed to synchronous
def aquire_aborted_prefiller_requests(self, server_idx: int): # Changed to synchronous
"""
Get the set of aborted requests and clear it.
This is used to release kv cache in prefiller node.
@@ -314,9 +300,7 @@ class ProxyState:
def calculate_decode_scores(self, request_length: int) -> float:
return request_length
async def add_instances(
self, instance_type: str, instances: List[ServerState]
) -> Tuple[List[str], List[str]]:
async def add_instances(self, instance_type: str, instances: list[ServerState]) -> tuple[list[str], list[str]]:
added_nodes, waiting_nodes = [], []
for server in instances:
is_valid = await self.node_listener.check_instance_status(server.client)
@@ -332,7 +316,7 @@ class ProxyState:
waiting_nodes.append(node)
return added_nodes, waiting_nodes
def add_prefillers(self, instances: List[ServerState]) -> None:
def add_prefillers(self, instances: list[ServerState]) -> None:
num_prefillers = len(self.prefillers)
for idx, server in enumerate(instances):
if server not in self.prefillers:
@@ -341,7 +325,7 @@ class ProxyState:
heapq.heappush(self.prefiller_heap, (0, num_prefillers + idx, server))
self.print_status(f"Add prefiller instances: {instances}.")
def add_decoders(self, instances: List[ServerState]) -> None:
def add_decoders(self, instances: list[ServerState]) -> None:
num_decoders = len(self.decoders)
for idx, server in enumerate(instances):
if server not in self.decoders:
@@ -350,7 +334,7 @@ class ProxyState:
heapq.heappush(self.decoder_heap, (0, num_decoders + idx, server))
self.print_status(f"Add decoder instances: {instances}.")
def remove_prefillers(self, instances: List[ServerState]) -> None:
def remove_prefillers(self, instances: list[ServerState]) -> None:
instances_to_remove = set(instances)
self.prefillers = [server for server in self.prefillers if server not in instances_to_remove]
prefiller_heap_copy = self.prefiller_heap.copy()
@@ -367,7 +351,7 @@ class ProxyState:
heapq.heapify(self.prefiller_heap)
self.print_status(f"Remove prefiller instances: {instances}.")
def remove_decoders(self, instances: List[ServerState]) -> None:
def remove_decoders(self, instances: list[ServerState]) -> None:
instances_to_remove = set(instances)
self.decoders = [server for server in self.decoders if server not in instances_to_remove]
decoder_heap_copy = self.decoder_heap.copy()
@@ -387,7 +371,7 @@ class ProxyState:
def print_status(self, msg: str) -> None:
status = {
"prefill_instances": [str(server) for server in self.prefillers],
"decode_instances": [str(server) for server in self.decoders]
"decode_instances": [str(server) for server in self.decoders],
}
print(f"{msg} Status: {status}")
@@ -398,7 +382,7 @@ proxy_state = None
class NodeListener:
def __init__(self, proxy):
self.proxy_state = proxy
self.waiting_nodes: Dict[str, Tuple[str, Any, int]] = {}
self.waiting_nodes: dict[str, tuple[str, Any, int]] = {}
self.listening_thread = threading.Thread(target=self._node_listener, daemon=True)
self.listening_thread.start()
@@ -424,9 +408,7 @@ class NodeListener:
@staticmethod
async def check_instance_status(client: httpx.AsyncClient) -> bool:
endpoint = "/models"
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
}
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
try:
response = await client.get(endpoint, headers=headers)
response.raise_for_status()
@@ -439,46 +421,29 @@ def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--prefiller-hosts",
type=str,
nargs="+",
default=["localhost"])
parser.add_argument("--prefiller-ports",
type=int,
nargs="+",
default=[8001])
parser.add_argument("--decoder-hosts",
type=str,
nargs="+",
default=["localhost"])
parser.add_argument("--prefiller-hosts", type=str, nargs="+", default=["localhost"])
parser.add_argument("--prefiller-ports", type=int, nargs="+", default=[8001])
parser.add_argument("--decoder-hosts", type=str, nargs="+", default=["localhost"])
parser.add_argument("--decoder-ports", type=int, nargs="+", default=[8002])
parser.add_argument("--max-retries",
type=int,
default=3,
help="Maximum number of retries for HTTP requests")
parser.add_argument("--max-retries", type=int, default=3, help="Maximum number of retries for HTTP requests")
parser.add_argument(
"--retry-delay",
type=float,
default=0.001,
help="Base delay (seconds) for exponential backoff retries")
parser.add_argument("--max-waiting-retries",
type=int,
default=3,
help="Maximum number of retries for waiting nodes to be started")
"--retry-delay", type=float, default=0.001, help="Base delay (seconds) for exponential backoff retries"
)
parser.add_argument(
"--max-waiting-retries", type=int, default=3, help="Maximum number of retries for waiting nodes to be started"
)
parser.add_argument(
"--waiting-retry-interval",
type=float,
default=10,
help="Check interval (seconds) for waiting nodes to be started")
help="Check interval (seconds) for waiting nodes to be started",
)
args = parser.parse_args()
if len(args.prefiller_hosts) != len(args.prefiller_ports):
raise ValueError(
"Number of prefiller hosts must match number of prefiller ports")
raise ValueError("Number of prefiller hosts must match number of prefiller ports")
if len(args.decoder_hosts) != len(args.decoder_ports):
raise ValueError(
"Number of decoder hosts must match number of decoder ports")
args.prefiller_instances = list(
zip(args.prefiller_hosts, args.prefiller_ports))
raise ValueError("Number of decoder hosts must match number of decoder ports")
args.prefiller_instances = list(zip(args.prefiller_hosts, args.prefiller_ports))
args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports))
return args
@@ -486,11 +451,8 @@ def parse_args():
@asynccontextmanager
async def lifespan(app: FastAPI):
global proxy_state
proxy_state = ProxyState(global_args.prefiller_instances,
global_args.decoder_instances)
print(
f"Initialized {len(proxy_state.prefillers)} prefill clients and {len(proxy_state.decoders)} decode clients."
)
proxy_state = ProxyState(global_args.prefiller_instances, global_args.decoder_instances)
print(f"Initialized {len(proxy_state.prefillers)} prefill clients and {len(proxy_state.decoders)} decode clients.")
yield
for p in proxy_state.prefillers:
await p.client.aclose()
@@ -507,14 +469,12 @@ async def listen_for_disconnect(request: Request) -> None:
def with_cancellation(handler_func):
@functools.wraps(handler_func)
async def wrapper(*args, **kwargs):
request = kwargs["request"]
handler_task = asyncio.create_task(handler_func(*args, **kwargs))
cancellation_task = asyncio.create_task(listen_for_disconnect(request))
done, pending = await asyncio.wait([handler_task, cancellation_task],
return_when=asyncio.FIRST_COMPLETED)
done, pending = await asyncio.wait([handler_task, cancellation_task], return_when=asyncio.FIRST_COMPLETED)
for task in pending:
task.cancel()
if handler_task in done:
@@ -527,17 +487,18 @@ def with_cancellation(handler_func):
app = FastAPI(lifespan=lifespan)
async def send_request_to_service(client: httpx.AsyncClient,
prefiller_id: int,
endpoint: str,
req_data: dict,
request_id: str,
max_retries: int = 3,
base_delay: float = 0.2):
aborted_requests = proxy_state.aquire_aborted_prefiller_requests(
prefiller_id)
async def send_request_to_service(
client: httpx.AsyncClient,
prefiller_id: int,
endpoint: str,
req_data: dict,
request_id: str,
max_retries: int = 3,
base_delay: float = 0.2,
):
aborted_requests = proxy_state.aquire_aborted_prefiller_requests(prefiller_id)
req_data = req_data.copy()
req_data['kv_transfer_params'] = {
req_data["kv_transfer_params"] = {
"do_remote_decode": True,
"do_remote_prefill": False,
"remote_engine_id": None,
@@ -553,46 +514,35 @@ async def send_request_to_service(client: httpx.AsyncClient,
req_data["max_completion_tokens"] = 1
if "stream_options" in req_data:
del req_data["stream_options"]
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"X-Request-Id": request_id
}
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", "X-Request-Id": request_id}
last_exc = None
for attempt in range(1, max_retries + 1):
try:
response = await client.post(endpoint,
json=req_data,
headers=headers)
response = await client.post(endpoint, json=req_data, headers=headers)
response.raise_for_status()
return response
except (httpx.RequestError, httpx.HTTPStatusError) as e:
logger.warning(
f"Attempt {attempt} failed for {endpoint}: {str(e)}")
logger.warning(f"Attempt {attempt} failed for {endpoint}: {str(e)}")
last_exc = e
if attempt < max_retries:
await asyncio.sleep(base_delay * (2**(attempt - 1)))
await asyncio.sleep(base_delay * (2 ** (attempt - 1)))
else:
logger.error(
f"All {max_retries} attempts failed for {endpoint}.")
logger.error(f"All {max_retries} attempts failed for {endpoint}.")
raise last_exc
async def stream_service_response_with_retry(client: httpx.AsyncClient,
endpoint: str,
req_data: dict,
request_id: str,
max_retries: int = 3,
base_delay: float = 0.2):
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"X-Request-Id": request_id
}
async def stream_service_response_with_retry(
client: httpx.AsyncClient,
endpoint: str,
req_data: dict,
request_id: str,
max_retries: int = 3,
base_delay: float = 0.2,
):
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", "X-Request-Id": request_id}
for attempt in range(1, max_retries + 1):
try:
async with client.stream("POST",
endpoint,
json=req_data,
headers=headers) as response:
async with client.stream("POST", endpoint, json=req_data, headers=headers) as response:
response.raise_for_status()
first_chunk_sent = False
async for chunk in response.aiter_bytes():
@@ -601,41 +551,28 @@ async def stream_service_response_with_retry(client: httpx.AsyncClient,
return # Success, exit after streaming
except (httpx.RequestError, httpx.HTTPStatusError) as e:
if attempt < max_retries:
logger.warning(
f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}"
)
await asyncio.sleep(base_delay * (2**(attempt - 1)))
logger.warning(f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}")
await asyncio.sleep(base_delay * (2 ** (attempt - 1)))
else:
logger.error(
f"All {max_retries} attempts failed for streaming {endpoint}."
)
logger.error(f"All {max_retries} attempts failed for streaming {endpoint}.")
raise e
except Exception as e:
# If any chunk has been sent, do not retry, just log and drop
if 'first_chunk_sent' in locals() and first_chunk_sent:
logger.error(
f"Streaming to client interrupted after response started: {str(e)}"
)
if "first_chunk_sent" in locals() and first_chunk_sent:
logger.error(f"Streaming to client interrupted after response started: {str(e)}")
return
else:
if attempt < max_retries:
logger.warning(
f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}"
)
await asyncio.sleep(base_delay * (2**(attempt - 1)))
logger.warning(f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}")
await asyncio.sleep(base_delay * (2 ** (attempt - 1)))
else:
logger.error(
f"All {max_retries} attempts failed for streaming {endpoint}."
)
logger.error(f"All {max_retries} attempts failed for streaming {endpoint}.")
raise e
async def _handle_select_instance(api: str, req_data: Any,
request_length: int):
async def _handle_select_instance(api: str, req_data: Any, request_length: int):
prefiller_score = proxy_state.calculate_prefill_scores(request_length)
logger.debug(
f"Request length: {request_length}, Prefiller score: {prefiller_score}"
)
logger.debug(f"Request length: {request_length}, Prefiller score: {prefiller_score}")
request_id = await proxy_state.next_req_id()
# Select prefiller
prefiller_idx = proxy_state.select_prefiller(prefiller_score)
@@ -648,10 +585,11 @@ async def _handle_select_instance(api: str, req_data: Any,
req_data,
request_id,
max_retries=global_args.max_retries,
base_delay=global_args.retry_delay)
base_delay=global_args.retry_delay,
)
proxy_state.release_prefiller(prefiller_idx, prefiller_score)
response_json = response.json()
kv_transfer_params = response_json.get('kv_transfer_params', {})
kv_transfer_params = response_json.get("kv_transfer_params", {})
if kv_transfer_params:
req_data["kv_transfer_params"] = kv_transfer_params
# Select decoder
@@ -661,13 +599,15 @@ async def _handle_select_instance(api: str, req_data: Any,
decoder_idx = proxy_state.select_decoder(decoder_score)
decoder = proxy_state.decoders[decoder_idx]
logger.debug("Using %s %s", prefiller.url, decoder.url)
return InstanceInfo(request_id=request_id,
prefiller_idx=prefiller_idx,
prefiller_score=prefiller_score,
prefiller=prefiller,
decoder=decoder,
decoder_idx=decoder_idx,
decoder_score=decoder_score)
return InstanceInfo(
request_id=request_id,
prefiller_idx=prefiller_idx,
prefiller_score=prefiller_score,
prefiller=prefiller,
decoder=decoder,
decoder_idx=decoder_idx,
decoder_score=decoder_score,
)
@dataclass
@@ -686,8 +626,7 @@ async def _handle_completions(api: str, request: Request):
req_data = await request.json()
req_body = await request.body()
request_length = len(req_body)
instance_info = await _handle_select_instance(api, req_data,
request_length)
instance_info = await _handle_select_instance(api, req_data, request_length)
stream_flag = bool(req_data.get("stream", False))
chat_flag = "messages" in req_data
@@ -713,34 +652,31 @@ async def _handle_completions(api: str, request: Request):
while retry:
retry = False
async for chunk in stream_service_response_with_retry(
instance_info.decoder.client,
api,
req_data,
request_id=instance_info.request_id,
max_retries=global_args.max_retries,
base_delay=global_args.retry_delay):
instance_info.decoder.client,
api,
req_data,
request_id=instance_info.request_id,
max_retries=global_args.max_retries,
base_delay=global_args.retry_delay,
):
if not released_kv and chunk:
proxy_state.release_prefiller_kv(
instance_info.prefiller_idx,
instance_info.prefiller_score)
proxy_state.release_prefiller_kv(instance_info.prefiller_idx, instance_info.prefiller_score)
released_kv = True
try:
chunk_str = chunk.decode("utf-8").strip()
except UnicodeDecodeError:
logger.debug(
f"Skipping chunk: {chunk}")
logger.debug(f"Skipping chunk: {chunk}")
yield chunk
continue
if not chunk_str:
continue
if chunk_str.startswith("data: "):
chunk_str = chunk_str[len("data: "):]
chunk_str = chunk_str[len("data: ") :]
try:
chunk_json = json.loads(chunk_str)
except json.JSONDecodeError:
# if chunk is [done], skip it.
logger.debug(
f"Skipping chunk: {chunk_str}")
logger.debug(f"Skipping chunk: {chunk_str}")
yield chunk
continue
choices = chunk_json.get("choices", [])
@@ -751,63 +687,52 @@ async def _handle_completions(api: str, request: Request):
choice = choices[0]
delta = choice.get("delta") or {}
message = choice.get("message") or {}
content = (
delta.get("content")
or message.get("content")
or choice.get("text")
or ""
)
content = delta.get("content") or message.get("content") or choice.get("text") or ""
generated_token += content
stop_reason = choice.get(
"stop_reason")
stop_reason = choice.get("stop_reason")
usage = chunk_json.get("usage", {})
completion_tokens = (completion_tokens + 1) if stream_flag else \
(completion_tokens + usage.get("completion_tokens"))
completion_tokens = (
(completion_tokens + 1)
if stream_flag
else (completion_tokens + usage.get("completion_tokens"))
)
if stop_reason == "recomputed":
retry = True
retry_count += 1
if chat_flag:
messages[0][
"content"] = origin_prompt + generated_token
messages[0]["content"] = origin_prompt + generated_token
else:
req_data[
"prompt"] = origin_prompt + generated_token
req_data[
"max_tokens"] = origin_max_tokens - completion_tokens + retry_count
tmp_request_length = len(
json.dumps(req_data).encode("utf-8"))
instance_info = await _handle_select_instance(
api, req_data, tmp_request_length)
req_data["prompt"] = origin_prompt + generated_token
req_data["max_tokens"] = origin_max_tokens - completion_tokens + retry_count
tmp_request_length = len(json.dumps(req_data).encode("utf-8"))
instance_info = await _handle_select_instance(api, req_data, tmp_request_length)
break
if retry_count > 0 and not stream_flag:
if chat_flag:
choice["message"][
"content"] = generated_token
choice["message"]["content"] = generated_token
else:
choice["text"] = generated_token
chunk = json.dumps(chunk_json).encode("utf-8")
yield chunk
except Exception as e:
logger.error(
f"Error during streaming from decoder {instance_info.decoder.url}: {str(e)} the aborted request {instance_info.request_id} will be routing to the target prefiller when new request is ready to dispatch to it"
f"Error during streaming from decoder {instance_info.decoder.url}: {str(e)} "
f"the aborted request {instance_info.request_id} will be routing to the target "
"prefiller when new request is ready to dispatch to it"
)
proxy_state.abort_prefiller_request(
instance_info.prefiller_idx, instance_info.request_id)
proxy_state.release_prefiller_kv(instance_info.prefiller_idx,
instance_info.prefiller_score)
proxy_state.abort_prefiller_request(instance_info.prefiller_idx, instance_info.request_id)
proxy_state.release_prefiller_kv(instance_info.prefiller_idx, instance_info.prefiller_score)
# After streaming done, release tokens
proxy_state.release_decoder(instance_info.decoder_idx,
instance_info.decoder_score)
proxy_state.release_decoder(instance_info.decoder_idx, instance_info.decoder_score)
return StreamingResponse(generate_stream(),
media_type="application/json")
return StreamingResponse(generate_stream(), media_type="application/json")
except Exception as e:
import traceback
exc_info = sys.exc_info()
print("Error occurred in disagg prefill proxy server"
f" - {api} endpoint")
print(f"Error occurred in disagg prefill proxy server - {api} endpoint")
print(e)
print("".join(traceback.format_exception(*exc_info)))
raise
@@ -821,20 +746,21 @@ async def _handle_adjust_instances(adjust_mode: str, request: Request):
if isinstance(instances, str):
instances = [instances]
instances = trans_instances(instances)
all_msg = f"{adjust_mode} {instance_type} instances: " \
f"{[str(server) for server in instances]}."
all_msg = f"{adjust_mode} {instance_type} instances: {[str(server) for server in instances]}."
if instance_type not in [InstanceType.PREFILL, InstanceType.DECODE]:
return {"error": f"Instance type {instance_type} is not supported. "
f"Only support '{InstanceType.PREFILL}' and '{InstanceType.DECODE}'."}
return {
"error": f"Instance type {instance_type} is not supported. "
f"Only support '{InstanceType.PREFILL}' and '{InstanceType.DECODE}'."
}
if adjust_mode == "add":
added_nodes, waiting_nodes = await proxy_state.add_instances(
instance_type, instances
)
added_nodes, waiting_nodes = await proxy_state.add_instances(instance_type, instances)
if waiting_nodes:
all_msg = f"{adjust_mode} {instance_type} instances: {added_nodes}. " \
f"Instances {waiting_nodes} are waiting to be added."
all_msg = (
f"{adjust_mode} {instance_type} instances: {added_nodes}. "
f"Instances {waiting_nodes} are waiting to be added."
)
elif adjust_mode == "remove":
if instance_type == InstanceType.PREFILL:
proxy_state.remove_prefillers(instances)
@@ -843,14 +769,14 @@ async def _handle_adjust_instances(adjust_mode: str, request: Request):
return {
"message": all_msg,
"current_prefill_instances": [str(prefiller) for prefiller in proxy_state.prefillers],
"current_decode_instances": [str(decoder) for decoder in proxy_state.decoders]
"current_decode_instances": [str(decoder) for decoder in proxy_state.decoders],
}
except Exception as e:
logger.error(f"Failed to {adjust_mode} instances: {e}")
raise e
def trans_instances(instances: List[str]) -> List[ServerState]:
def trans_instances(instances: list[str]) -> list[ServerState]:
server_list = []
for instance in instances:
h, p = instance.split(":")
@@ -875,7 +801,7 @@ async def healthcheck():
return {
"status": "ok",
"prefill_instances": len(proxy_state.prefillers),
"decode_instances": len(proxy_state.decoders)
"decode_instances": len(proxy_state.decoders),
}
@@ -889,7 +815,7 @@ async def handle_remove_instances(request: Request):
return await _handle_adjust_instances("remove", request)
if __name__ == '__main__':
if __name__ == "__main__":
global global_args
global_args = parse_args()
import uvicorn

View File

@@ -2,17 +2,17 @@
## Environmental Dependencies
* Software:
* Python >= 3.10, < 3.12
* CANN == 8.3.rc2
* PyTorch == 2.8.0, torch-npu == 2.8.0
* vLLM (same version as vllm-ascend)
* mooncake-transfer-engine reference documentation: https://github.com/kvcache-ai/Mooncake/blob/main/doc/zh/ascend_transport.md
* Software:
* Python >= 3.10, < 3.12
* CANN == 8.3.rc2
* PyTorch == 2.8.0, torch-npu == 2.8.0
* vLLM (same version as vllm-ascend)
* mooncake-transfer-engine reference documentation: https://github.com/kvcache-ai/Mooncake/blob/main/doc/zh/ascend_transport.md
The vllm version must be the same as the main branch of vllm-ascend, for example, 2025/07/30. The version is
* vllm: v0.10.1
* vllm-ascend: v0.10.1rc1
* vllm: v0.10.1
* vllm-ascend: v0.10.1rc1
## run
@@ -84,7 +84,6 @@ Set `GLOO_SOCKET_IFNAME`, `TP_SOCKET_IFNAME`, and `HCCL_SOCKET_IFNAME` to the co
`--gpu-memory-utilization`: Percentage of video memory occupied by the card<br>
`--kv-transfer-config`: follow kv_connector, kv_connector_module_path: mooncakeconnect, kv_buffer_device, and run on the NPU card. For kv_role, set kv_producer to the p node, kv_consumer to the d node, kv_parallel_size to 1, and kv_port to the port used by the node. For the p node, set engine_id and kv_rank to 0 and for the d node to 1. Configure the distributed parallel policy for the p and d nodes in the kv_connector_extra_config file based on --tensor-parallel-size and --data-parallel-size.<br>
### 2. Run `decode` Node
```
@@ -151,7 +150,6 @@ python load_balance_proxy_server_example.py --host localhost --prefiller-hosts h
`--decoder-hosts`: Set this parameter to the IP addresses of all d nodes. In the xpyd scenario, add the IP addresses to the end of this configuration item and leave a blank space between the IP addresses.<br>
`--decoder-ports`: Set this parameter to the port number of all d nodes, which is the configuration of the port number for the vllm to start the service in step 4. Set port to the end of the configuration, and leave a blank space between port and port. The sequence must be one-to-one mapping to the IP address of --decoder-hosts.<br>
### 4. Run Inference
Set the IP address in the inference file to the actual IP address. Set the model variable to the path of the model. Ensure that the path is the same as that in the shell script.
@@ -162,4 +160,4 @@ curl -s http://localhost:8000/v1/completions -H "Content-Type: application/json"
"prompt": "Given the accelerating impacts of climate change—including rising sea levels, increasing frequency of extreme weather events, loss of biodiversity, and adverse effects on agriculture and human health—there is an urgent need for a robust, globally coordinated response. However, international efforts are complicated by a range of factors: economic disparities between high-income and low-income countries, differing levels of industrialization, varying access to clean energy technologies, and divergent political systems that influence climate policy implementation. In this context, how can global agreements like the Paris Accord be redesigned or strengthened to not only encourage but effectively enforce emission reduction targets? Furthermore, what mechanisms can be introduced to promote fair and transparent technology transfer, provide adequate financial support for climate adaptation in vulnerable regions, and hold nations accountable without exacerbating existing geopolitical tensions or disproportionately burdening those with historically lower emissions?",
"max_tokens": 256
}'
```
```