Set the proper `text/event-stream; charset=utf-8` media type for streaming
requests instead of the hardcoded `application/json`
### What this PR does / why we need it?
This PR fixes an issue in the disaggregated prefill proxy server where
streaming requests (`"stream": true`) were always returned with a
hardcoded `Content-Type: application/json`, even when the backend vLLM
servers correctly returned Server-Sent Events (SSE) with `Content-Type:
text/event-stream; charset=utf-8`.
Specifically, the proxy used `StreamingResponse` with a fixed
`media_type` of `application/json`, which caused FastAPI to override the
response headers and break proper SSE semantics. As a result, clients
(e.g. `curl -i`, EventSource, or OpenAI-compatible SDKs) could not
reliably receive token-by-token streaming output.
In addition, this incorrect response type caused compatibility issues
with benchmarking and load-testing tools such as **EvalScope**. When
streaming is enabled, these tools expect SSE-formatted responses in
order to parse token usage information correctly. With the incorrect
`application/json` content type, EvalScope fails to parse the response
and reports errors similar to:
`2025-12-15 09:27:56 - evalscope - ERROR: Failed to parse usage from response: list index out of range. Response: []`
This PR updates the proxy to:
- Detect whether the incoming request is a streaming request
(`stream=true`)
- Use `text/event-stream; charset=utf-8` for streaming responses
- Preserve `application/json` for non-streaming responses
This aligns the proxy behavior with native vLLM prefill/decoder servers
and the OpenAI-compatible streaming API contract.
Fixes incorrect streaming response headers that prevented proper
real-time token delivery.
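Concretely, the fix derives the media type from the request's `stream` flag before constructing the `StreamingResponse`. A minimal sketch of the relevant lines (names as in `_handle_completions` in the proxy example below):

```python
from fastapi.responses import StreamingResponse

# Inside the completions handler, after the request body is parsed into req_data:
stream_flag = bool(req_data.get("stream", False))

# SSE for streaming requests, plain JSON otherwise.
media_type = "text/event-stream; charset=utf-8" if stream_flag else "application/json"
return StreamingResponse(generate_stream(), media_type=media_type)
```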
### Does this PR introduce _any_ user-facing change?
None
### How was this patch tested?
This change was tested manually using a disaggregated prefill + decode
setup with the proxy server.
### Test Steps
1. Start prefiller and decoder vLLM servers:
```bash
vllm serve --host 0.0.0.0 --port 8001 ...
vllm serve --host 0.0.0.0 --port 8002 ...
```
2. Start the proxy server:
```bash
python load_balance_proxy_server_example.py \
--host 127.0.0.1 --port 8000 \
--prefiller-hosts 127.0.0.1 --prefiller-ports 8001 \
--decoder-hosts 127.0.0.1 --decoder-ports 8002
```
3. Send a streaming completion request through the proxy:
```bash
curl -i -X POST http://127.0.0.1:8000/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "test",
"prompt": "hello",
"max_tokens": 3,
"stream": true
}'
```
4. Verify the following (an optional scripted check is sketched after this list):
- The response header is `Content-Type: text/event-stream; charset=utf-8`
- Tokens are streamed incrementally as SSE `data:` events
- Non-streaming requests still return `application/json`
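For reference, the same checks can be scripted. This is a minimal sketch using `httpx`, assuming the proxy from the steps above listens on `127.0.0.1:8000` and the backends accept the model name `test`:

```python
import httpx

payload = {"model": "test", "prompt": "hello", "max_tokens": 3}

with httpx.Client(base_url="http://127.0.0.1:8000") as client:
    # Streaming request: expect Server-Sent Events.
    with client.stream("POST", "/v1/completions", json={**payload, "stream": True}) as resp:
        assert resp.headers["content-type"].startswith("text/event-stream")
        for line in resp.iter_lines():
            if line.startswith("data: "):
                print(line)  # incremental SSE chunks, terminated by "data: [DONE]"

    # Non-streaming request: expect plain JSON.
    resp = client.post("/v1/completions", json=payload)
    assert resp.headers["content-type"].startswith("application/json")
```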
No automated tests were added because this change affects an example
proxy
server and is limited to HTTP response headers. The behavior is directly
verifiable using standard SSE-compatible clients.
- vLLM version: v0.16.0
- vLLM main: 15d76f74e2
Signed-off-by: zrj026 <zhangrunjiang026@gmail.com>
Co-authored-by: zrj026 <zhangrunjiang026@gmail.com>
# Adapted from https://github.com/vllm-project/vllm/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py

# SPDX-License-Identifier: Apache-2.0
#
# Tutorial: Using the Load Balance Proxy Server Example
#
# This proxy server is designed to distribute requests between multiple
# "prefiller" and "decoder" backend servers for large language model inference.
# It is useful for scaling out inference workloads and balancing load across
# multiple backend instances.
#
# Features:
# - Load balances requests to multiple prefiller and decoder servers.
# - Supports OpenAI-compatible /v1/completions and /v1/chat/completions endpoints.
# - Streams responses from backend servers to clients.
#
# Prerequisites:
# - Python 3.10+
# - Install dependencies:
# pip install "fastapi<0.124.0" httpx uvicorn vllm
#
# Step 1: Start Your Backend Servers
# ----------------------------------
# You need to have at least one prefiller and one decoder backend running.
# These can be mock servers or actual vLLM servers.
#
# For testing, you can use the provided mock server:
#
# vllm serve --host 0.0.0.0 --port 8100 ... # Prefiller 1
# vllm serve --host 0.0.0.0 --port 8101 ... # Prefiller 2
# vllm serve --host 0.0.0.0 --port 8200 ... # Decoder 1
# vllm serve --host 0.0.0.0 --port 8201 ... # Decoder 2
#
# Step 2: Start the Proxy Server
# ------------------------------
# Run the proxy server, specifying the host/port for each prefiller and decoder:
#
# python load_balance_proxy_server_example.py \
# --host 0.0.0.0 --port 9000 \
# --prefiller-hosts 127.0.0.1 127.0.0.1 \
# --prefiller-ports 8100 8101 \
# --decoder-hosts 127.0.0.1 127.0.0.1 \
# --decoder-ports 8200 8201
#
# This will start the proxy on port 9000, load balancing between two prefiller
# and two decoder servers.
#
# Step 3: Send a Request to the Proxy
# -----------------------------------
# You can now send OpenAI-compatible requests to the proxy. For example:
#
# curl -X POST http://localhost:9000/v1/completions \
# -H "Content-Type: application/json" \
# -d '{
# "model": "your-model",
# "prompt": "The quick brown fox jumps over the lazy dog",
# "max_tokens": 16
# }'
#
# Or for chat completions:
#
# curl -X POST http://localhost:9000/v1/chat/completions \
# -H "Content-Type: application/json" \
# -d '{
# "model": "your-model",
# "messages": [{"role": "user", "content": "Hello!"}],
# "max_tokens": 16
# }'
#
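# To exercise streaming, add "stream": true to either request body; the proxy
# then responds with Content-Type: text/event-stream; charset=utf-8 and emits
# incremental "data:" chunks (illustrative example, mirroring the curl above):
#
# curl -i -X POST http://localhost:9000/v1/completions \
# -H "Content-Type: application/json" \
# -d '{
# "model": "your-model",
# "prompt": "Hello",
# "max_tokens": 16,
# "stream": true
# }'
#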
# Step 4: Health Check
# --------------------
# To check if the proxy is running and see how many backend instances are
# connected, use:
#
# curl http://localhost:9000/healthcheck
#
# This will return a JSON object with the status and the number of prefiller
# and decoder instances.
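#
# Example response (illustrative; the fields match the /healthcheck handler below):
# {"status": "ok", "prefill_instances": 2, "decode_instances": 2}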
#
# Step 5: Add or Remove Prefiller or Decoder Instances (Optional)
# ---------------------------------------------------------------
# You can add or remove prefiller or decoder instances after the proxy is started.
# For example, add 2 prefiller instances:
#
# curl -X POST http://localhost:9000/instances/add \
# -H "Content-Type: application/json" \
# -d '{
# "type": "prefill",
# "instances": ["127.0.0.1:8102", "127.0.0.1:8103"]
# }'
#
# or remove 1 decoder instance:
#
# curl -X POST http://localhost:9000/instances/remove \
# -H "Content-Type: application/json" \
# -d '{
# "type": "decode",
# "instances": "127.0.0.1:8201"
# }'
#
# This will return a JSON object describing the added or removed instances
# and the current prefiller and decoder instances.
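#
# Example response (illustrative; the fields match the /instances handlers below):
# {"message": "add prefill instances: ['127.0.0.1:8102', '127.0.0.1:8103'].",
#  "current_prefill_instances": ["127.0.0.1:8100", "127.0.0.1:8101", "127.0.0.1:8102", "127.0.0.1:8103"],
#  "current_decode_instances": ["127.0.0.1:8200", "127.0.0.1:8201"]}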
#
# When adding instances, if an instance is not yet started, the proxy
# will keep retrying until the instance comes up or the maximum number
# of retry attempts is exceeded.
#
# Notes:
# - You can scale the number of prefiller and decoder servers as needed.
# - The proxy routes each request to the least-loaded prefiller and decoder
#   (tracked with priority heaps; see ProxyState below).
# - For production, ensure your backend servers are robust and secure.
#
# For more details, see the code and comments in this file.

import argparse
import asyncio
import functools
import heapq
import ipaddress
import json
import os
import sys
import threading
import time
import uuid
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any

import httpx
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse

try:
    from vllm.logger import init_logger

    logger = init_logger(__name__)
except ImportError:
    import logging

    logger = logging.getLogger(__name__)

# Add uvloop for faster event loop if available
try:
    import uvloop

    asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
except ImportError:
    pass


@dataclass
class InstanceType:
    PREFILL: str = "prefill"
    DECODE: str = "decode"


TAINT_PRIORITY = 1e15


class ServerState:
    def __init__(self, host, port):
        self.host = host
        self.port = port
        self.url = f"http://{host}:{port}/v1"
        try:
            ip = ipaddress.ip_address(self.host)
            if isinstance(ip, ipaddress.IPv6Address):
                self.url = f"http://[{host}]:{port}/v1"
        except Exception:
            pass
        self.client = httpx.AsyncClient(
            timeout=None,
            base_url=self.url,
            limits=httpx.Limits(max_connections=100000, max_keepalive_connections=100000),
        )
        self.active_tokens = 0
        self.active_kv_cache = 0  # Only for prefiller
        self.active_requests = 0  # Number of active requests
        self.aborted_requests = set()  # Track aborted requests
        # Removed individual server lock - will use global locks instead

    def __eq__(self, other):
        self_host = self.host.replace("localhost", "0.0.0.0").replace("127.0.0.1", "0.0.0.0")
        other_host = other.host.replace("localhost", "0.0.0.0").replace("127.0.0.1", "0.0.0.0")
        return self_host == other_host and str(self.port) == str(other.port)

    def __hash__(self):
        self_host = self.host.replace("localhost", "0.0.0.0").replace("127.0.0.1", "0.0.0.0")
        return hash((self_host, str(self.port)))

    def __repr__(self):
        return f"{self.host}:{self.port}"


class ProxyState:
    def __init__(self, prefiller_instances, decoder_instances):
        self.request_num = 0
        self.tainted_prefillers: list[ServerState] = []
        self.tainted_decoders: list[ServerState] = []
        self.node_listener = NodeListener(self)

        self.prefillers: list[ServerState] = [ServerState(h, p) for h, p in prefiller_instances]
        self.decoders: list[ServerState] = [ServerState(h, p) for h, p in decoder_instances]
        self.req_to_prefiller = {}
        self.req_id_lock = asyncio.Lock()
        # Removed selection locks - no longer needed for synchronous methods

        # Initialize priority queues for efficient server selection
        # Each entry is (priority_score, server_index, server_reference)
        # Lower priority score = higher priority (less loaded)
        self.prefiller_heap = [(0.0, i, server) for i, server in enumerate(self.prefillers)]
        self.decoder_heap = [(0.0, i, server) for i, server in enumerate(self.decoders)]
        heapq.heapify(self.prefiller_heap)
        heapq.heapify(self.decoder_heap)

    def _update_prefiller_priority(self, server_idx: int):
        """Update the priority of a prefiller server in the heap."""
        server = self.prefillers[server_idx]
        # Priority based on active_tokens and active_kv_cache
        priority = server.active_tokens + server.active_kv_cache * 0.3
        # Remove old entry and add new one
        self.prefiller_heap = [(p, i, s) for p, i, s in self.prefiller_heap if i != server_idx]
        heapq.heappush(self.prefiller_heap, (priority, server_idx, server))

    def _update_decoder_priority(self, server_idx: int):
        """Update the priority of a decoder server in the heap."""
        server = self.decoders[server_idx]
        priority = server.active_tokens
        # Remove old entry and add new one
        self.decoder_heap = [(p, i, s) for p, i, s in self.decoder_heap if i != server_idx]
        heapq.heappush(self.decoder_heap, (priority, server_idx, server))

    def abort_prefiller_request(self, server_idx: int, request_id):  # Changed to synchronous
        """
        Mark a request as aborted. This helps to release the kv cache on the
        prefiller node.
        """
        # No lock needed - atomic operation
        if server_idx >= len(self.prefillers):
            return
        self.prefillers[server_idx].aborted_requests.add(request_id)

    def acquire_aborted_prefiller_requests(self, server_idx: int):  # Changed to synchronous
        """
        Get the set of aborted requests and clear it.
        This is used to release the kv cache on the prefiller node.
        """
        # No lock needed - atomic operation
        if server_idx >= len(self.prefillers):
            return set()
        aborted_requests = self.prefillers[server_idx].aborted_requests.copy()
        self.prefillers[server_idx].aborted_requests.clear()
        return aborted_requests

    async def next_req_id(self):
        async with self.req_id_lock:
            return str(uuid.uuid4())

    def select_prefiller(self, token_count):  # Changed to synchronous
        # No lock needed - entire function is atomic
        if not self.prefiller_heap:
            raise RuntimeError("No prefiller servers available")

        priority, chosen, server = heapq.heappop(self.prefiller_heap)

        # Update the chosen server atomically
        self.prefillers[chosen].active_tokens += token_count
        self.prefillers[chosen].active_kv_cache += token_count

        # Update priority and re-add to heap
        self._update_prefiller_priority(chosen)

        return chosen

    def release_prefiller(self, idx, token_count):  # Changed to synchronous
        # No lock needed - atomic operation
        if idx >= len(self.prefillers):
            return
        self.prefillers[idx].active_tokens -= token_count
        # Update priority queue after releasing
        self._update_prefiller_priority(idx)

    def release_prefiller_kv(self, idx, token_count):  # Changed to synchronous
        # No lock needed - atomic operation
        if idx >= len(self.prefillers):
            return
        if self.prefillers[idx].active_kv_cache > 0:
            self.prefillers[idx].active_kv_cache -= token_count
        # Update priority queue after releasing
        self._update_prefiller_priority(idx)

    def select_decoder(self, token_count):  # Changed to synchronous
        # No lock needed - entire function is atomic
        if not self.decoder_heap:
            raise RuntimeError("No decoder servers available")

        priority, chosen, server = heapq.heappop(self.decoder_heap)

        # Update the chosen server atomically
        self.decoders[chosen].active_tokens += token_count

        # Update priority and re-add to heap
        self._update_decoder_priority(chosen)

        return chosen

    def release_decoder(self, idx, token_count):  # Changed to synchronous
        # No lock needed - atomic operation
        if idx >= len(self.decoders):
            return
        self.decoders[idx].active_tokens -= token_count
        # Update priority queue after releasing
        self._update_decoder_priority(idx)

    # Omni_infer's calculate_input_scores function
    def calculate_prefill_scores(self, request_length: int) -> float:
        length_score = request_length / 4.0
        input_score = length_score * 0.0345 + 120.0745
        return input_score

    def calculate_decode_scores(self, request_length: int) -> float:
        return request_length

    async def add_instances(self, instance_type: str, instances: list[ServerState]) -> tuple[list[str], list[str]]:
        added_nodes, waiting_nodes = [], []
        for server in instances:
            is_valid = await self.node_listener.check_instance_status(server.client)
            if is_valid and instance_type == InstanceType.PREFILL:
                self.add_prefillers([server])
                added_nodes.append(str(server))
            elif is_valid and instance_type == InstanceType.DECODE:
                self.add_decoders([server])
                added_nodes.append(str(server))
            else:
                node = str(server)
                self.node_listener.waiting_nodes[node] = (instance_type, server, 0)
                waiting_nodes.append(node)
        return added_nodes, waiting_nodes

    def add_prefillers(self, instances: list[ServerState]) -> None:
        for server in instances:
            if server in self.tainted_prefillers:
                self.tainted_prefillers.remove(server)
                self.prefiller_heap = [
                    (0, idx, server) if srv == server else (priority, idx, srv)
                    for priority, idx, srv in self.prefiller_heap
                ]
                heapq.heapify(self.prefiller_heap)
            elif server not in self.prefillers:
                self.prefillers.append(server)
                # prefiller_heap: [(priority_0, 0, server_0)] -> [(priority_0, 0, server_0), (0, 1, server_1)]
                heapq.heappush(self.prefiller_heap, (0, len(self.prefillers) - 1, server))
        self.print_status(f"Add prefiller instances: {instances}.")

    def add_decoders(self, instances: list[ServerState]) -> None:
        for server in instances:
            if server in self.tainted_decoders:
                self.tainted_decoders.remove(server)
                self.decoder_heap = [
                    (0, idx, server) if srv == server else (priority, idx, srv)
                    for priority, idx, srv in self.decoder_heap
                ]
                heapq.heapify(self.decoder_heap)
            elif server not in self.decoders:
                self.decoders.append(server)
                # decoder_heap: [(priority_0, 0, server_0)] -> [(priority_0, 0, server_0), (0, 1, server_1)]
                heapq.heappush(self.decoder_heap, (0, len(self.decoders) - 1, server))
        self.print_status(f"Add decoder instances: {instances}.")

    def remove_prefillers(self, instances: list[ServerState]) -> bool:
        if not instances:
            return False

        if self.request_num > 0:
            logger.warning(f"Start to taint prefill instances {instances}.")
            self._taint_prefillers(instances)
            return True

        instances_to_remove = set(instances)
        self.prefillers = [server for server in self.prefillers if server not in instances_to_remove]
        prefiller_heap_copy = self.prefiller_heap.copy()
        prefiller_heap_copy.sort(key=lambda x: x[1])  # sorted by key: prefiller_idx
        prefiller_heap = []
        idx = 0
        for priority, _, server in prefiller_heap_copy:
            if server not in instances_to_remove:
                prefiller_heap.append((priority, idx, server))
                idx += 1

        # prefiller_heap: [(priority_0, 0, server_0), (priority_1, 1, server_1)] -> [(priority_1, 0, server_1)]
        self.prefiller_heap = prefiller_heap
        heapq.heapify(self.prefiller_heap)
        self.print_status(f"Remove prefiller instances: {instances}.")
        return False

    def remove_decoders(self, instances: list[ServerState]) -> bool:
        if not instances:
            return False

        if self.request_num > 0:
            logger.warning(f"Start to taint decode instances {instances}.")
            self._taint_decoders(instances)
            return True

        instances_to_remove = set(instances)
        self.decoders = [server for server in self.decoders if server not in instances_to_remove]
        decoder_heap_copy = self.decoder_heap.copy()
        decoder_heap_copy.sort(key=lambda x: x[1])  # sorted by key: decoder_idx
        decoder_heap = []
        idx = 0
        for priority, _, server in decoder_heap_copy:
            if server not in instances_to_remove:
                decoder_heap.append((priority, idx, server))
                idx += 1

        # decoder_heap: [(priority_0, 0, server_0), (priority_1, 1, server_1)] -> [(priority_1, 0, server_1)]
        self.decoder_heap = decoder_heap
        heapq.heapify(self.decoder_heap)
        self.print_status(f"Remove decoder instances: {instances}.")
        return False

    def _taint_prefillers(self, instances: list[ServerState]) -> None:
        instances_to_taint = set(instances)
        for server in self.prefillers:
            if server in instances_to_taint and server not in self.tainted_prefillers:
                self.tainted_prefillers.append(server)

        self.prefiller_heap = [
            (TAINT_PRIORITY, idx, srv) if srv in instances_to_taint else (priority, idx, srv)
            for priority, idx, srv in self.prefiller_heap
        ]
        heapq.heapify(self.prefiller_heap)

    def _taint_decoders(self, instances: list[ServerState]) -> None:
        instances_to_taint = set(instances)
        for server in self.decoders:
            if server in instances_to_taint and server not in self.tainted_decoders:
                self.tainted_decoders.append(server)

        self.decoder_heap = [
            (TAINT_PRIORITY, idx, srv) if srv in instances_to_taint else (priority, idx, srv)
            for priority, idx, srv in self.decoder_heap
        ]
        heapq.heapify(self.decoder_heap)

    def print_status(self, msg: str) -> None:
        status = {
            "prefill_instances": [str(server) for server in self.prefillers],
            "decode_instances": [str(server) for server in self.decoders],
        }
        print(f"{msg} Status: {status}")


proxy_state = None


class NodeListener:
    def __init__(self, proxy):
        self.proxy_state = proxy
        self.waiting_nodes: dict[str, tuple[str, Any, int]] = {}
        self.listening_thread = threading.Thread(target=self._node_listener, daemon=True)
        self.listening_thread.start()

    def _node_listener(self) -> None:
        while True:
            for node, (instance_type, server, check_times) in list(self.waiting_nodes.items()):
                is_valid = asyncio.run(self.check_instance_status(server.client))
                print(f"Checking instance {node}...")
                check_times += 1
                if is_valid:
                    if instance_type == InstanceType.PREFILL:
                        self.proxy_state.add_prefillers([server])
                    else:
                        self.proxy_state.add_decoders([server])
                    self.waiting_nodes.pop(node)
                elif check_times == global_args.max_waiting_retries:
                    print(f"Instance {node} was not added to the proxy.")
                    self.waiting_nodes.pop(node)
                else:
                    self.waiting_nodes[node] = (instance_type, server, check_times)

            if self.proxy_state.tainted_prefillers and not self.proxy_state.request_num:
                need_waiting = self.proxy_state.remove_prefillers(self.proxy_state.tainted_prefillers)
                if not need_waiting:
                    self.proxy_state.tainted_prefillers.clear()

            if self.proxy_state.tainted_decoders and not self.proxy_state.request_num:
                need_waiting = self.proxy_state.remove_decoders(self.proxy_state.tainted_decoders)
                if not need_waiting:
                    self.proxy_state.tainted_decoders.clear()
            time.sleep(global_args.waiting_retry_interval)

    @staticmethod
    async def check_instance_status(client: httpx.AsyncClient) -> bool:
        endpoint = "/models"
        headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
        try:
            response = await client.get(endpoint, headers=headers)
            response.raise_for_status()
            return True
        except (httpx.RequestError, httpx.HTTPStatusError):
            return False


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--prefiller-hosts", type=str, nargs="+", default=["localhost"])
    parser.add_argument("--prefiller-ports", type=int, nargs="+", default=[8001])
    parser.add_argument("--decoder-hosts", type=str, nargs="+", default=["localhost"])
    parser.add_argument("--decoder-ports", type=int, nargs="+", default=[8002])
    parser.add_argument("--max-retries", type=int, default=3, help="Maximum number of retries for HTTP requests")
    parser.add_argument(
        "--retry-delay", type=float, default=0.001, help="Base delay (seconds) for exponential backoff retries"
    )
    parser.add_argument(
        "--max-waiting-retries", type=int, default=3, help="Maximum number of retries for waiting nodes to be started"
    )
    parser.add_argument(
        "--waiting-retry-interval",
        type=float,
        default=10,
        help="Check interval (seconds) for waiting nodes to be started",
    )
    args = parser.parse_args()
    if len(args.prefiller_hosts) != len(args.prefiller_ports):
        raise ValueError("Number of prefiller hosts must match number of prefiller ports")
    if len(args.decoder_hosts) != len(args.decoder_ports):
        raise ValueError("Number of decoder hosts must match number of decoder ports")
    args.prefiller_instances = list(zip(args.prefiller_hosts, args.prefiller_ports))
    args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports))
    return args


@asynccontextmanager
async def lifespan(app: FastAPI):
    global proxy_state
    proxy_state = ProxyState(global_args.prefiller_instances, global_args.decoder_instances)
    print(f"Initialized {len(proxy_state.prefillers)} prefill clients and {len(proxy_state.decoders)} decode clients.")
    yield
    for p in proxy_state.prefillers:
        await p.client.aclose()
    for d in proxy_state.decoders:
        await d.client.aclose()


async def listen_for_disconnect(request: Request) -> None:
    """Return if a disconnect message is received"""
    while True:
        message = await request.receive()
        if message["type"] == "http.disconnect":
            break


def with_cancellation(handler_func):
    @functools.wraps(handler_func)
    async def wrapper(*args, **kwargs):
        request = kwargs["request"]
        handler_task = asyncio.create_task(handler_func(*args, **kwargs))
        cancellation_task = asyncio.create_task(listen_for_disconnect(request))
        done, pending = await asyncio.wait([handler_task, cancellation_task], return_when=asyncio.FIRST_COMPLETED)
        for task in pending:
            task.cancel()
        if handler_task in done:
            return handler_task.result()
        return None

    return wrapper


app = FastAPI(lifespan=lifespan)


async def send_request_to_service(
    client: httpx.AsyncClient,
    prefiller_id: int,
    endpoint: str,
    req_data: dict,
    request_id: str,
    max_retries: int = 3,
    base_delay: float = 0.2,
):
    aborted_requests = proxy_state.acquire_aborted_prefiller_requests(prefiller_id)
    req_data = req_data.copy()
    req_data["kv_transfer_params"] = {
        "do_remote_decode": True,
        "do_remote_prefill": False,
        "remote_engine_id": None,
        "remote_block_ids": None,
        "remote_host": None,
        "remote_port": None,
        "aborted_request": list(aborted_requests),
    }
    req_data["stream"] = False
    req_data["max_tokens"] = 1
    req_data["min_tokens"] = 1
    if "max_completion_tokens" in req_data:
        req_data["max_completion_tokens"] = 1
    if "stream_options" in req_data:
        del req_data["stream_options"]
    headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", "X-Request-Id": request_id}
    last_exc = None
    for attempt in range(1, max_retries + 1):
        try:
            response = await client.post(endpoint, json=req_data, headers=headers)
            response.raise_for_status()
            return response
        except (httpx.RequestError, httpx.HTTPStatusError) as e:
            logger.warning(f"Attempt {attempt} failed for {endpoint}: {str(e)}")
            last_exc = e
            if attempt < max_retries:
                await asyncio.sleep(base_delay * (2 ** (attempt - 1)))
            else:
                logger.error(f"All {max_retries} attempts failed for {endpoint}.")
                raise last_exc


async def stream_service_response_with_retry(
    client: httpx.AsyncClient,
    endpoint: str,
    req_data: dict,
    request_id: str,
    max_retries: int = 3,
    base_delay: float = 0.2,
):
    headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", "X-Request-Id": request_id}
    for attempt in range(1, max_retries + 1):
        try:
            async with client.stream("POST", endpoint, json=req_data, headers=headers) as response:
                response.raise_for_status()
                first_chunk_sent = False
                async for chunk in response.aiter_bytes():
                    first_chunk_sent = True
                    yield chunk
                return  # Success, exit after streaming
        except (httpx.RequestError, httpx.HTTPStatusError) as e:
            if attempt < max_retries:
                logger.warning(f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}")
                await asyncio.sleep(base_delay * (2 ** (attempt - 1)))
            else:
                logger.error(f"All {max_retries} attempts failed for streaming {endpoint}.")
                raise e
        except Exception as e:
            # If any chunk has been sent, do not retry, just log and drop
            if "first_chunk_sent" in locals() and first_chunk_sent:
                logger.error(f"Streaming to client interrupted after response started: {str(e)}")
                return
            else:
                if attempt < max_retries:
                    logger.warning(f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}")
                    await asyncio.sleep(base_delay * (2 ** (attempt - 1)))
                else:
                    logger.error(f"All {max_retries} attempts failed for streaming {endpoint}.")
                    raise e


async def _handle_select_instance(api: str, req_data: Any, request_length: int):
    prefiller_score = proxy_state.calculate_prefill_scores(request_length)
    logger.debug(f"Request length: {request_length}, Prefiller score: {prefiller_score}")
    request_id = await proxy_state.next_req_id()
    # Select prefiller
    prefiller_idx = proxy_state.select_prefiller(prefiller_score)
    prefiller = proxy_state.prefillers[prefiller_idx]
    # Send request to prefiller
    response = await send_request_to_service(
        prefiller.client,
        prefiller_idx,
        api,
        req_data,
        request_id,
        max_retries=global_args.max_retries,
        base_delay=global_args.retry_delay,
    )
    proxy_state.release_prefiller(prefiller_idx, prefiller_score)
    response_json = response.json()
    kv_transfer_params = response_json.get("kv_transfer_params", {})
    if kv_transfer_params:
        req_data["kv_transfer_params"] = kv_transfer_params
    # Select decoder
    decoder_score = proxy_state.calculate_decode_scores(request_length)
    logger.debug("Decoder score: %f", decoder_score)
    # Use the prefiller's kv_transfer_params to select decoder
    decoder_idx = proxy_state.select_decoder(decoder_score)
    decoder = proxy_state.decoders[decoder_idx]
    logger.debug("Using %s %s", prefiller.url, decoder.url)
    return InstanceInfo(
        request_id=request_id,
        prefiller_idx=prefiller_idx,
        prefiller_score=prefiller_score,
        prefiller=prefiller,
        decoder=decoder,
        decoder_idx=decoder_idx,
        decoder_score=decoder_score,
    )


@dataclass
class InstanceInfo:
    request_id: str
    prefiller_idx: int
    prefiller_score: float
    prefiller: ServerState
    decoder_idx: int
    decoder_score: float
    decoder: ServerState


async def _handle_completions(api: str, request: Request):
    try:
        proxy_state.request_num += 1
        req_data = await request.json()
        req_body = await request.body()
        request_length = len(req_body)
        instance_info = await _handle_select_instance(api, req_data, request_length)
        stream_flag = bool(req_data.get("stream", False))
        chat_flag = "messages" in req_data

        if "prompt" in req_data:
            origin_prompt = req_data["prompt"]
        elif chat_flag:
            messages = req_data["messages"]
            origin_prompt = messages[0].get("content", "")
        else:
            origin_prompt = ""
        # refer to vLLM sampling_params: max_token default value
        origin_max_tokens = req_data.get("max_tokens", 16)

        async def generate_stream():
            nonlocal instance_info
            generated_token = ""
            released_kv = False
            retry_count = 0
            retry = True
            completion_tokens = 0
            # Only one await per chunk, minimal logic in loop
            try:
                while retry:
                    retry = False
                    async for chunk in stream_service_response_with_retry(
                        instance_info.decoder.client,
                        api,
                        req_data,
                        request_id=instance_info.request_id,
                        max_retries=global_args.max_retries,
                        base_delay=global_args.retry_delay,
                    ):
                        if not released_kv and chunk:
                            proxy_state.release_prefiller_kv(instance_info.prefiller_idx, instance_info.prefiller_score)
                            released_kv = True
                        try:
                            chunk_str = chunk.decode("utf-8").strip()
                        except UnicodeDecodeError:
                            logger.debug(f"Skipping chunk: {chunk}")
                            yield chunk
                            continue
                        if not chunk_str:
                            continue
                        if chunk_str.startswith("data: "):
                            chunk_str = chunk_str[len("data: ") :]
                        try:
                            chunk_json = json.loads(chunk_str)
                        except json.JSONDecodeError:
                            # if chunk is [DONE], skip it.
                            logger.debug(f"Skipping chunk: {chunk_str}")
                            yield chunk
                            continue
                        choices = chunk_json.get("choices", [])
                        if not choices:
                            yield chunk
                            continue

                        choice = choices[0]
                        delta = choice.get("delta") or {}
                        message = choice.get("message") or {}
                        content = delta.get("content") or message.get("content") or choice.get("text") or ""
                        generated_token += content

                        stop_reason = choice.get("stop_reason")
                        usage = chunk_json.get("usage", {})
                        completion_tokens = (
                            (completion_tokens + 1)
                            if stream_flag
                            else (completion_tokens + usage.get("completion_tokens"))
                        )
                        if stop_reason == "recomputed":
                            retry = True
                            retry_count += 1
                            if chat_flag:
                                messages[0]["content"] = origin_prompt + generated_token
                            else:
                                req_data["prompt"] = origin_prompt + generated_token
                            req_data["max_tokens"] = origin_max_tokens - completion_tokens + retry_count
                            tmp_request_length = len(json.dumps(req_data).encode("utf-8"))
                            instance_info = await _handle_select_instance(api, req_data, tmp_request_length)
                            break
                        if retry_count > 0 and not stream_flag:
                            if chat_flag:
                                choice["message"]["content"] = generated_token
                            else:
                                choice["text"] = generated_token
                            chunk = json.dumps(chunk_json).encode("utf-8")
                        yield chunk
            except Exception as e:
                logger.error(
                    f"Error during streaming from decoder {instance_info.decoder.url}: {str(e)}; "
                    f"the aborted request {instance_info.request_id} will be routed to the target "
                    "prefiller when a new request is ready to dispatch to it"
                )
                proxy_state.abort_prefiller_request(instance_info.prefiller_idx, instance_info.request_id)
                proxy_state.release_prefiller_kv(instance_info.prefiller_idx, instance_info.prefiller_score)

            # After streaming done, release tokens
            proxy_state.release_decoder(instance_info.decoder_idx, instance_info.decoder_score)

        # Determine the correct media type based on stream flag
        media_type = "text/event-stream; charset=utf-8" if stream_flag else "application/json"
        return StreamingResponse(generate_stream(), media_type=media_type)
    except Exception as e:
        import traceback

        exc_info = sys.exc_info()
        print(f"Error occurred in disagg prefill proxy server - {api} endpoint")
        print(e)
        print("".join(traceback.format_exception(*exc_info)))
        raise
    finally:
        proxy_state.request_num -= 1


async def _handle_adjust_instances(adjust_mode: str, request: Request):
    try:
        req_data = await request.json()
        instance_type = req_data.get("type", "")
        instances = req_data.get("instances", [])
        if isinstance(instances, str):
            instances = [instances]
        instances = trans_instances(instances)
        all_msg = f"{adjust_mode} {instance_type} instances: {[str(server) for server in instances]}."

        if instance_type not in [InstanceType.PREFILL, InstanceType.DECODE]:
            return {
                "error": f"Instance type {instance_type} is not supported. "
                f"Only support '{InstanceType.PREFILL}' and '{InstanceType.DECODE}'."
            }

        if adjust_mode == "add":
            added_nodes, waiting_nodes = await proxy_state.add_instances(instance_type, instances)
            if waiting_nodes:
                all_msg = (
                    f"{adjust_mode} {instance_type} instances: {added_nodes}. "
                    f"Instances {waiting_nodes} are waiting to be added."
                )
        elif adjust_mode == "remove":
            if instance_type == InstanceType.PREFILL:
                need_waiting = proxy_state.remove_prefillers(instances)
            else:
                need_waiting = proxy_state.remove_decoders(instances)

            if need_waiting:
                all_msg = f"Instances {instances} are isolated and waiting to be removed."
        return {
            "message": all_msg,
            "current_prefill_instances": [str(prefiller) for prefiller in proxy_state.prefillers],
            "current_decode_instances": [str(decoder) for decoder in proxy_state.decoders],
        }
    except Exception as e:
        logger.error(f"Failed to {adjust_mode} instances: {e}")
        raise e


def trans_instances(instances: list[str]) -> list[ServerState]:
    server_list = []
    for instance in instances:
        h, p = instance.split(":")
        server_list.append(ServerState(h, int(p)))
    return server_list


@app.post("/v1/completions")
@with_cancellation
async def handle_completions(request: Request):
    return await _handle_completions("/completions", request)


@app.post("/v1/chat/completions")
@with_cancellation
async def handle_chat_completions(request: Request):
    return await _handle_completions("/chat/completions", request)


@app.get("/healthcheck")
async def healthcheck():
    return {
        "status": "ok",
        "prefill_instances": len(proxy_state.prefillers),
        "decode_instances": len(proxy_state.decoders),
    }


@app.post("/instances/add")
async def handle_add_instances(request: Request):
    return await _handle_adjust_instances("add", request)


@app.post("/instances/remove")
async def handle_remove_instances(request: Request):
    return await _handle_adjust_instances("remove", request)


if __name__ == "__main__":
    global global_args
    global_args = parse_args()
    import uvicorn

    uvicorn.run(app, host=global_args.host, port=global_args.port)