**[RFC]: Elastic Scaling Support for P/D Instances Based on KV Pool:**
https://github.com/vllm-project/vllm-ascend/issues/3380
### What this PR does / why we need it?
Support elastic scaling for P/D instances based on mooncake conncetor
deplayment.
**Support API routes**
* `/instances/add`: add prefill nodes or decode nodes to the list.
* `/instances/remove`: remove prefill nodes or decode nodes from the
list.
**Support functions**
* Support **adding** prefill nodes or decode nodes.
- If prefill or decode server deployed **after the proxy deployed**,
server can use `/instances/add` API to join the proxy server. The
prefill server or decode server sends a signal to the proxy server, and
the proxy server will check the status of the node util the node is
available.
* Support **removing** prefill nodes or decode nodes:
- Support using `/instances/remove` API to **delete the node** from the
proxy server.
### Does this PR introduce _any_ user-facing change?
For
`examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py`:
**Add 2 params**
When adding nodes to the proxy, the proxy will wait the nodes to be
started util retrying a certain of times.
| name | type | default | help |
| ----- | ---- | ---- | ---- |
| max-waiting-retries | int | 3 | Maximum number of retries for waiting
nodes to be started |
| waiting-retry-interval | float | 10 | Check interval (seconds) for
waiting nodes to be started |
For example:
```shell
python load_balance_proxy_server_example.py \
--host 0.0.0.0 --port 9000 \
--prefiller-hosts 127.0.0.1 127.0.0.1 \
--prefiller-ports 8100 8101 \
--decoder-hosts 127.0.0.1 127.0.0.1 \
--decoder-ports 8200 8201 \
--max-waiting-retries 3 \
--waiting-retry-interval 10
```
**Add 2 API routings**
* Add instances: `instances/add`
For example, add 2 prefiller instances:
```shell
curl -X POST http://localhost:9000/instances/add \
-H "Content-Type: application/json" \
-d '{
"type": "prefill",
"instances": ["127.0.0.1:8102", "127.0.0.1:8103"]
}'
```
Response:
```shell
{"message": "add prefill instances: ['127.0.0.1:8102', '127.0.0.1:8103'].", "current_prefill_instances": ['127.0.0.1:8100', '127.0.0.1:8101', '127.0.0.1:8102', '127.0.0.1:8103'], "current_decode_instances": ['127.0.0.1:8200', '127.0.0.1:8201']}
```
If the node '127.0.0.1:8103' has not benn started:
```shell
{"message": "add prefill instances: ['127.0.0.1:8102']. Instances ['127.0.0.1:8103'] are waiting to be added.", "current_prefill_instances": ['127.0.0.1:8100', '127.0.0.1:8101', '127.0.0.1:8102'], "current_decode_instances": ['127.0.0.1:8200', '127.0.0.1:8201']}
```
* Remove instances: `instances/remove`
For example, remove 1 decoder instance:
```shell
curl -X POST http://localhost:9000/instances/remove \
-H "Content-Type: application/json" \
-d '{
"type": "decode",
"instances": "127.0.0.1:8201"
}'
```
Response:
```shell
{"message": "remove decode instances: ['127.0.0.1:8201'].", "current_prefill_instances": ['127.0.0.1:8100', '127.0.0.1:8101'], "current_decode_instances": ['127.0.0.1:8200']}
```
### How was this patch tested?
Run proxy and using `/instances/add` API to add nodes and
`/instances/remove` API to remove nodes
* vLLM version: v0.11.0.rc3
* vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0.rc3
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
Signed-off-by: yuxinshan <syx_ctyg@126.com>
Signed-off-by: CalvinXKY <kyxiezju@163.com>
898 lines
35 KiB
Python
898 lines
35 KiB
Python
# Adapted from https://github.com/vllm-project/vllm/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py
|
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
#
|
|
# Tutorial: Using the Load Balance Proxy Server Example
|
|
#
|
|
# This proxy server is designed to distribute requests between multiple
|
|
# "prefiller" and "decoder" backend servers for large language model inference.
|
|
# It is useful for scaling out inference workloads and balancing load across
|
|
# multiple backend instances.
|
|
#
|
|
# Features:
|
|
# - Load balances requests to multiple prefiller and decoder servers.
|
|
# - Supports OpenAI-compatible /v1/completions and /v1/chat/completions endpoints.
|
|
# - Streams responses from backend servers to clients.
|
|
#
|
|
# Prerequisites:
|
|
# - Python 3.10+
|
|
# - Install dependencies:
|
|
# pip install fastapi<0.124.0 httpx uvicorn vllm
|
|
#
|
|
# Step 1: Start Your Backend Servers
|
|
# ----------------------------------
|
|
# You need to have at least one prefiller and one decoder backend running.
|
|
# These can be mock servers or actual vLLM servers.
|
|
#
|
|
# For testing, you can use the provided mock server:
|
|
#
|
|
# vllm serve --host 0.0.0.0 --port 8100 ... # Prefiller 1
|
|
# vllm serve --host 0.0.0.0 --port 8101 ... # Prefiller 2
|
|
# vllm serve --host 0.0.0.0 --port 8200 ... # Decoder 1
|
|
# vllm serve --host 0.0.0.0 --port 8201 ... # Decoder 2
|
|
#
|
|
# Step 2: Start the Proxy Server
|
|
# ------------------------------
|
|
# Run the proxy server, specifying the host/port for each prefiller and decoder:
|
|
#
|
|
# python load_balance_proxy_server_example.py \
|
|
# --host 0.0.0.0 --port 9000 \
|
|
# --prefiller-hosts 127.0.0.1 127.0.0.1 \
|
|
# --prefiller-ports 8100 8101 \
|
|
# --decoder-hosts 127.0.0.1 127.0.0.1 \
|
|
# --decoder-ports 8200 8201
|
|
#
|
|
# This will start the proxy on port 9000, load balancing between two prefiller
|
|
# and two decoder servers.
|
|
#
|
|
# Step 3: Send a Request to the Proxy
|
|
# -----------------------------------
|
|
# You can now send OpenAI-compatible requests to the proxy. For example:
|
|
#
|
|
# curl -X POST http://localhost:9000/v1/completions \
|
|
# -H "Content-Type: application/json" \
|
|
# -d '{
|
|
# "model": "your-model",
|
|
# "prompt": "The quick brown fox jumps over the lazy dog",
|
|
# "max_tokens": 16
|
|
# }'
|
|
#
|
|
# Or for chat completions:
|
|
#
|
|
# curl -X POST http://localhost:9000/v1/chat/completions \
|
|
# -H "Content-Type: application/json" \
|
|
# -d '{
|
|
# "model": "your-model",
|
|
# "messages": [{"role": "user", "content": "Hello!"}],
|
|
# "max_tokens": 16
|
|
# }'
|
|
#
|
|
# Step 4: Health Check
|
|
# --------------------
|
|
# To check if the proxy is running and see how many backend instances are
|
|
# connected, use:
|
|
#
|
|
# curl http://localhost:9000/healthcheck
|
|
#
|
|
# This will return a JSON object with the status and the number of prefiller
|
|
# and decoder instances.
|
|
#
|
|
# Step 5: Add or Remove Prefiller or Decoder Instances (Optional)
|
|
# ---------------------------------------------------------------
|
|
# You can add or remove prefiller or decoder instances after the proxy is started.
|
|
# For example, add 2 prefiller instances:
|
|
#
|
|
# curl -X POST http://localhost:9000/instances/add \
|
|
# -H "Content-Type: application/json" \
|
|
# -d '{
|
|
# "type": "prefill",
|
|
# "instances": ["127.0.0.1:8102", "127.0.0.1:8103"]
|
|
# }'
|
|
#
|
|
# or remove 1 decoder instance:
|
|
#
|
|
# curl -X POST http://localhost:9000/instances/remove \
|
|
# -H "Content-Type: application/json" \
|
|
# -d '{
|
|
# "type": "decode",
|
|
# "instances": "127.0.0.1:8201"
|
|
# }'
|
|
#
|
|
# This will return a JSON object with the adding or removing info
|
|
# and the current prefiller and decoder instances.
|
|
#
|
|
# When adding instances, if the instances are not started,
|
|
# the proxy will wait and try until the instances to be started
|
|
# or exceeding the number of attempts
|
|
#
|
|
# Notes:
|
|
# - You can scale the number of prefiller and decoder servers as needed.
|
|
# - The proxy will round-robin requests to balance load.
|
|
# - For production, ensure your backend servers are robust and secure.
|
|
#
|
|
# For more details, see the code and comments in this file.
|
|
|
|
import argparse
|
|
import asyncio
|
|
import functools
|
|
import heapq
|
|
import ipaddress
|
|
import json
|
|
import os
|
|
import sys
|
|
import threading
|
|
import time
|
|
import uuid
|
|
from contextlib import asynccontextmanager
|
|
from dataclasses import dataclass
|
|
from typing import Any, List, Tuple, Dict
|
|
|
|
import httpx
|
|
from fastapi import FastAPI, Request
|
|
from fastapi.responses import StreamingResponse
|
|
from vllm.logger import init_logger
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
# Add uvloop for faster event loop if available
|
|
try:
|
|
import uvloop
|
|
|
|
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
|
except ImportError:
|
|
pass
|
|
|
|
|
|
@dataclass
|
|
class InstanceType:
|
|
PREFILL: str = "prefill"
|
|
DECODE: str = "decode"
|
|
|
|
|
|
class ServerState:
|
|
|
|
def __init__(self, host, port):
|
|
self.host = host
|
|
self.port = port
|
|
self.url = f'http://{host}:{port}/v1'
|
|
try:
|
|
ip = ipaddress.ip_address(self.host)
|
|
if isinstance(ip, ipaddress.IPv6Address):
|
|
self.url = f'http://[{host}]:{port}/v1'
|
|
except Exception:
|
|
pass
|
|
self.client = httpx.AsyncClient(timeout=None,
|
|
base_url=self.url,
|
|
limits=httpx.Limits(
|
|
max_connections=100000,
|
|
max_keepalive_connections=100000))
|
|
self.active_tokens = 0
|
|
self.active_kv_cache = 0 # Only for prefiller
|
|
self.active_requests = 0 # Number of active requests
|
|
self.aborted_requests = set() # Track aborted requests
|
|
# Removed individual server lock - will use global locks instead
|
|
|
|
def __eq__(self, other):
|
|
self_host = self.host.replace("localhost", "0.0.0.0").replace("127.0.0.1", "0.0.0.0")
|
|
other_host = other.host.replace("localhost", "0.0.0.0").replace("127.0.0.1", "0.0.0.0")
|
|
return self_host == other_host and str(self.port) == str(other.port)
|
|
|
|
def __hash__(self):
|
|
self_host = self.host.replace("localhost", "0.0.0.0").replace("127.0.0.1", "0.0.0.0")
|
|
return hash((self_host, str(self.port)))
|
|
|
|
def __repr__(self):
|
|
return f"{self.host}:{self.port}"
|
|
|
|
|
|
class ProxyState:
|
|
|
|
def __init__(self, prefiller_instances, decoder_instances):
|
|
self.node_listener = NodeListener(self)
|
|
|
|
self.prefillers: List[ServerState] = [
|
|
ServerState(h, p) for h, p in prefiller_instances
|
|
]
|
|
self.decoders: List[ServerState] = [
|
|
ServerState(h, p) for h, p in decoder_instances
|
|
]
|
|
self.req_to_prefiller = {}
|
|
self.req_id_lock = asyncio.Lock()
|
|
# Removed selection locks - no longer needed for synchronous methods
|
|
|
|
# Initialize priority queues for efficient server selection
|
|
# Each entry is (priority_score, server_index, server_reference)
|
|
# Lower priority score = higher priority (less loaded)
|
|
self.prefiller_heap = [(0, i, server)
|
|
for i, server in enumerate(self.prefillers)]
|
|
self.decoder_heap = [(0, i, server)
|
|
for i, server in enumerate(self.decoders)]
|
|
heapq.heapify(self.prefiller_heap)
|
|
heapq.heapify(self.decoder_heap)
|
|
|
|
def _update_prefiller_priority(self, server_idx: int):
|
|
"""Update the priority of a prefiller server in the heap."""
|
|
server = self.prefillers[server_idx]
|
|
# Priority based on active_tokens and active_kv_cache
|
|
priority = server.active_tokens + server.active_kv_cache * 0.3
|
|
# Remove old entry and add new one
|
|
self.prefiller_heap = [(p, i, s) for p, i, s in self.prefiller_heap
|
|
if i != server_idx]
|
|
heapq.heappush(self.prefiller_heap,
|
|
(priority, server_idx, server)) # type: ignore
|
|
|
|
def _update_decoder_priority(self, server_idx: int):
|
|
"""Update the priority of a decoder server in the heap."""
|
|
server = self.decoders[server_idx]
|
|
priority = server.active_tokens
|
|
# Remove old entry and add new one
|
|
self.decoder_heap = [(p, i, s) for p, i, s in self.decoder_heap
|
|
if i != server_idx]
|
|
heapq.heappush(self.decoder_heap,
|
|
(priority, server_idx, server)) # type: ignore
|
|
|
|
def abort_prefiller_request(self, server_idx: int,
|
|
request_id): # Changed to synchronous
|
|
"""
|
|
Mark a request as aborted. This will helps to release kv cache in
|
|
prefiller node.
|
|
"""
|
|
# No lock needed - atomic operation
|
|
self.prefillers[server_idx].aborted_requests.add(request_id)
|
|
|
|
def aquire_aborted_prefiller_requests(
|
|
self, server_idx: int): # Changed to synchronous
|
|
"""
|
|
Get the set of aborted requests and clear it.
|
|
This is used to release kv cache in prefiller node.
|
|
"""
|
|
# No lock needed - atomic operation
|
|
aborted_requests = self.prefillers[server_idx].aborted_requests.copy()
|
|
self.prefillers[server_idx].aborted_requests.clear()
|
|
return aborted_requests
|
|
|
|
async def next_req_id(self):
|
|
async with self.req_id_lock:
|
|
return str(uuid.uuid4())
|
|
|
|
def select_prefiller(self, token_count): # Changed to synchronous
|
|
# No lock needed - entire function is atomic
|
|
if not self.prefiller_heap:
|
|
raise RuntimeError("No prefiller servers available")
|
|
|
|
priority, chosen, server = heapq.heappop(self.prefiller_heap)
|
|
|
|
# Update the chosen server atomically
|
|
self.prefillers[chosen].active_tokens += token_count
|
|
self.prefillers[chosen].active_kv_cache += token_count
|
|
|
|
# Update priority and re-add to heap
|
|
self._update_prefiller_priority(chosen)
|
|
|
|
return chosen
|
|
|
|
def release_prefiller(self, idx, token_count): # Changed to synchronous
|
|
# No lock needed - atomic operation
|
|
self.prefillers[idx].active_tokens -= token_count
|
|
# Update priority queue after releasing
|
|
self._update_prefiller_priority(idx)
|
|
|
|
def release_prefiller_kv(self, idx, token_count): # Changed to synchronous
|
|
# No lock needed - atomic operation
|
|
if self.prefillers[idx].active_kv_cache > 0:
|
|
self.prefillers[idx].active_kv_cache -= token_count
|
|
# Update priority queue after releasing
|
|
self._update_prefiller_priority(idx)
|
|
|
|
def select_decoder(self, token_count): # Changed to synchronous
|
|
# No lock needed - entire function is atomic
|
|
if not self.decoder_heap:
|
|
raise RuntimeError("No decoder servers available")
|
|
|
|
priority, chosen, server = heapq.heappop(self.decoder_heap)
|
|
|
|
# Update the chosen server atomically
|
|
self.decoders[chosen].active_tokens += token_count
|
|
|
|
# Update priority and re-add to heap
|
|
self._update_decoder_priority(chosen)
|
|
|
|
return chosen
|
|
|
|
def release_decoder(self, idx, token_count): # Changed to synchronous
|
|
# No lock needed - atomic operation
|
|
self.decoders[idx].active_tokens -= token_count
|
|
# Update priority queue after releasing
|
|
self._update_decoder_priority(idx)
|
|
|
|
# Omni_infer's calculate_input_scores function
|
|
def calculate_prefill_scores(self, request_length: int) -> float:
|
|
length_score = request_length / 4.0
|
|
input_score = length_score * 0.0345 + 120.0745
|
|
return input_score
|
|
|
|
def calculate_decode_scores(self, request_length: int) -> float:
|
|
return request_length
|
|
|
|
async def add_instances(
|
|
self, instance_type: str, instances: List[ServerState]
|
|
) -> Tuple[List[str], List[str]]:
|
|
added_nodes, waiting_nodes = [], []
|
|
for server in instances:
|
|
is_valid = await self.node_listener.check_instance_status(server.client)
|
|
if is_valid and instance_type == InstanceType.PREFILL:
|
|
self.add_prefillers([server])
|
|
added_nodes.append(str(server))
|
|
elif is_valid and instance_type == InstanceType.DECODE:
|
|
self.add_decoders([server])
|
|
added_nodes.append(str(server))
|
|
else:
|
|
node = str(server)
|
|
self.node_listener.waiting_nodes[node] = (instance_type, server, 0)
|
|
waiting_nodes.append(node)
|
|
return added_nodes, waiting_nodes
|
|
|
|
def add_prefillers(self, instances: List[ServerState]) -> None:
|
|
num_prefillers = len(self.prefillers)
|
|
for idx, server in enumerate(instances):
|
|
if server not in self.prefillers:
|
|
self.prefillers.append(server)
|
|
# prefiller_heap: [(priority_0, 0, server_0)] -> [(priority_0, 0, server_0), (0, 1, server_1)]
|
|
heapq.heappush(self.prefiller_heap, (0, num_prefillers + idx, server))
|
|
self.print_status(f"Add prefiller instances: {instances}.")
|
|
|
|
def add_decoders(self, instances: List[ServerState]) -> None:
|
|
num_decoders = len(self.decoders)
|
|
for idx, server in enumerate(instances):
|
|
if server not in self.decoders:
|
|
self.decoders.append(server)
|
|
# decoder_heap: [(priority_0, 0, server_0)] -> [(priority_0, 0, server_0), (0, 1, server_1)]
|
|
heapq.heappush(self.decoder_heap, (0, num_decoders + idx, server))
|
|
self.print_status(f"Add decoder instances: {instances}.")
|
|
|
|
def remove_prefillers(self, instances: List[ServerState]) -> None:
|
|
instances_to_remove = set(instances)
|
|
self.prefillers = [server for server in self.prefillers if server not in instances_to_remove]
|
|
prefiller_heap_copy = self.prefiller_heap.copy()
|
|
prefiller_heap_copy.sort(key=lambda x: x[1]) # sorted by key: prefiller_idx
|
|
prefiller_heap = []
|
|
idx = 0
|
|
for priority, _, server in prefiller_heap_copy:
|
|
if server not in instances_to_remove:
|
|
prefiller_heap.append((priority, idx, server))
|
|
idx += 1
|
|
|
|
# prefiller_heap: [(priority_0, 0, server_0), (priority_1, 1, server_1)] -> [(priority_1, 0, server_1)]
|
|
self.prefiller_heap = prefiller_heap
|
|
heapq.heapify(self.prefiller_heap)
|
|
self.print_status(f"Remove prefiller instances: {instances}.")
|
|
|
|
def remove_decoders(self, instances: List[ServerState]) -> None:
|
|
instances_to_remove = set(instances)
|
|
self.decoders = [server for server in self.decoders if server not in instances_to_remove]
|
|
decoder_heap_copy = self.decoder_heap.copy()
|
|
decoder_heap_copy.sort(key=lambda x: x[1]) # sorted by key: decoder_idx
|
|
decoder_heap = []
|
|
idx = 0
|
|
for priority, _, server in decoder_heap_copy:
|
|
if server not in instances_to_remove:
|
|
decoder_heap.append((priority, idx, server))
|
|
idx += 1
|
|
|
|
# decoder_heap: [(priority_0, 0, server_0), (priority_1, 1, server_1)] -> [(priority_1, 0, server_1)]
|
|
self.decoder_heap = decoder_heap
|
|
heapq.heapify(self.decoder_heap)
|
|
self.print_status(f"Remove decoder instances: {instances}.")
|
|
|
|
def print_status(self, msg: str) -> None:
|
|
status = {
|
|
"prefill_instances": [str(server) for server in self.prefillers],
|
|
"decode_instances": [str(server) for server in self.decoders]
|
|
}
|
|
print(f"{msg} Status: {status}")
|
|
|
|
|
|
proxy_state = None
|
|
|
|
|
|
class NodeListener:
|
|
def __init__(self, proxy):
|
|
self.proxy_state = proxy
|
|
self.waiting_nodes: Dict[str, Tuple[str, Any, int]] = {}
|
|
self.listening_thread = threading.Thread(target=self._node_listener, daemon=True)
|
|
self.listening_thread.start()
|
|
|
|
def _node_listener(self) -> None:
|
|
while True:
|
|
for node, (instance_type, server, check_times) in list(self.waiting_nodes.items()):
|
|
is_valid = asyncio.run(self.check_instance_status(server.client))
|
|
print(f"Checking instance {node}...")
|
|
check_times += 1
|
|
if is_valid:
|
|
if instance_type == InstanceType.PREFILL:
|
|
self.proxy_state.add_prefillers([server])
|
|
else:
|
|
self.proxy_state.add_decoders([server])
|
|
self.waiting_nodes.pop(node)
|
|
elif check_times == global_args.max_waiting_retries:
|
|
print(f"Instance {node} was not added to the proxy.")
|
|
self.waiting_nodes.pop(node)
|
|
else:
|
|
self.waiting_nodes[node] = (instance_type, server, check_times)
|
|
time.sleep(global_args.waiting_retry_interval)
|
|
|
|
@staticmethod
|
|
async def check_instance_status(client: httpx.AsyncClient) -> bool:
|
|
endpoint = "/models"
|
|
headers = {
|
|
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
|
|
}
|
|
try:
|
|
response = await client.get(endpoint, headers=headers)
|
|
response.raise_for_status()
|
|
return True
|
|
except (httpx.RequestError, httpx.HTTPStatusError):
|
|
return False
|
|
|
|
|
|
def parse_args():
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument("--port", type=int, default=8000)
|
|
parser.add_argument("--host", type=str, default="localhost")
|
|
parser.add_argument("--prefiller-hosts",
|
|
type=str,
|
|
nargs="+",
|
|
default=["localhost"])
|
|
parser.add_argument("--prefiller-ports",
|
|
type=int,
|
|
nargs="+",
|
|
default=[8001])
|
|
parser.add_argument("--decoder-hosts",
|
|
type=str,
|
|
nargs="+",
|
|
default=["localhost"])
|
|
parser.add_argument("--decoder-ports", type=int, nargs="+", default=[8002])
|
|
parser.add_argument("--max-retries",
|
|
type=int,
|
|
default=3,
|
|
help="Maximum number of retries for HTTP requests")
|
|
parser.add_argument(
|
|
"--retry-delay",
|
|
type=float,
|
|
default=0.001,
|
|
help="Base delay (seconds) for exponential backoff retries")
|
|
parser.add_argument("--max-waiting-retries",
|
|
type=int,
|
|
default=3,
|
|
help="Maximum number of retries for waiting nodes to be started")
|
|
parser.add_argument(
|
|
"--waiting-retry-interval",
|
|
type=float,
|
|
default=10,
|
|
help="Check interval (seconds) for waiting nodes to be started")
|
|
args = parser.parse_args()
|
|
if len(args.prefiller_hosts) != len(args.prefiller_ports):
|
|
raise ValueError(
|
|
"Number of prefiller hosts must match number of prefiller ports")
|
|
if len(args.decoder_hosts) != len(args.decoder_ports):
|
|
raise ValueError(
|
|
"Number of decoder hosts must match number of decoder ports")
|
|
args.prefiller_instances = list(
|
|
zip(args.prefiller_hosts, args.prefiller_ports))
|
|
args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports))
|
|
return args
|
|
|
|
|
|
@asynccontextmanager
|
|
async def lifespan(app: FastAPI):
|
|
global proxy_state
|
|
proxy_state = ProxyState(global_args.prefiller_instances,
|
|
global_args.decoder_instances)
|
|
print(
|
|
f"Initialized {len(proxy_state.prefillers)} prefill clients and {len(proxy_state.decoders)} decode clients."
|
|
)
|
|
yield
|
|
for p in proxy_state.prefillers:
|
|
await p.client.aclose()
|
|
for d in proxy_state.decoders:
|
|
await d.client.aclose()
|
|
|
|
|
|
async def listen_for_disconnect(request: Request) -> None:
|
|
"""Return if a disconnect message is received"""
|
|
while True:
|
|
message = await request.receive()
|
|
if message["type"] == "http.disconnect":
|
|
break
|
|
|
|
|
|
def with_cancellation(handler_func):
|
|
|
|
@functools.wraps(handler_func)
|
|
async def wrapper(*args, **kwargs):
|
|
request = kwargs["request"]
|
|
handler_task = asyncio.create_task(handler_func(*args, **kwargs))
|
|
cancellation_task = asyncio.create_task(listen_for_disconnect(request))
|
|
done, pending = await asyncio.wait([handler_task, cancellation_task],
|
|
return_when=asyncio.FIRST_COMPLETED)
|
|
for task in pending:
|
|
task.cancel()
|
|
if handler_task in done:
|
|
return handler_task.result()
|
|
return None
|
|
|
|
return wrapper
|
|
|
|
|
|
app = FastAPI(lifespan=lifespan)
|
|
|
|
|
|
async def send_request_to_service(client: httpx.AsyncClient,
|
|
prefiller_id: int,
|
|
endpoint: str,
|
|
req_data: dict,
|
|
request_id: str,
|
|
max_retries: int = 3,
|
|
base_delay: float = 0.2):
|
|
aborted_requests = proxy_state.aquire_aborted_prefiller_requests(
|
|
prefiller_id)
|
|
req_data = req_data.copy()
|
|
req_data['kv_transfer_params'] = {
|
|
"do_remote_decode": True,
|
|
"do_remote_prefill": False,
|
|
"remote_engine_id": None,
|
|
"remote_block_ids": None,
|
|
"remote_host": None,
|
|
"remote_port": None,
|
|
"aborted_request": list(aborted_requests),
|
|
}
|
|
req_data["stream"] = False
|
|
req_data["max_tokens"] = 1
|
|
req_data["min_tokens"] = 1
|
|
if "max_completion_tokens" in req_data:
|
|
req_data["max_completion_tokens"] = 1
|
|
if "stream_options" in req_data:
|
|
del req_data["stream_options"]
|
|
headers = {
|
|
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
|
|
"X-Request-Id": request_id
|
|
}
|
|
last_exc = None
|
|
for attempt in range(1, max_retries + 1):
|
|
try:
|
|
response = await client.post(endpoint,
|
|
json=req_data,
|
|
headers=headers)
|
|
response.raise_for_status()
|
|
return response
|
|
except (httpx.RequestError, httpx.HTTPStatusError) as e:
|
|
logger.warning(
|
|
f"Attempt {attempt} failed for {endpoint}: {str(e)}")
|
|
last_exc = e
|
|
if attempt < max_retries:
|
|
await asyncio.sleep(base_delay * (2**(attempt - 1)))
|
|
else:
|
|
logger.error(
|
|
f"All {max_retries} attempts failed for {endpoint}.")
|
|
raise last_exc
|
|
|
|
|
|
async def stream_service_response_with_retry(client: httpx.AsyncClient,
|
|
endpoint: str,
|
|
req_data: dict,
|
|
request_id: str,
|
|
max_retries: int = 3,
|
|
base_delay: float = 0.2):
|
|
headers = {
|
|
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
|
|
"X-Request-Id": request_id
|
|
}
|
|
for attempt in range(1, max_retries + 1):
|
|
try:
|
|
async with client.stream("POST",
|
|
endpoint,
|
|
json=req_data,
|
|
headers=headers) as response:
|
|
response.raise_for_status()
|
|
first_chunk_sent = False
|
|
async for chunk in response.aiter_bytes():
|
|
first_chunk_sent = True
|
|
yield chunk
|
|
return # Success, exit after streaming
|
|
except (httpx.RequestError, httpx.HTTPStatusError) as e:
|
|
if attempt < max_retries:
|
|
logger.warning(
|
|
f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}"
|
|
)
|
|
await asyncio.sleep(base_delay * (2**(attempt - 1)))
|
|
else:
|
|
logger.error(
|
|
f"All {max_retries} attempts failed for streaming {endpoint}."
|
|
)
|
|
raise e
|
|
except Exception as e:
|
|
# If any chunk has been sent, do not retry, just log and drop
|
|
if 'first_chunk_sent' in locals() and first_chunk_sent:
|
|
logger.error(
|
|
f"Streaming to client interrupted after response started: {str(e)}"
|
|
)
|
|
return
|
|
else:
|
|
if attempt < max_retries:
|
|
logger.warning(
|
|
f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}"
|
|
)
|
|
await asyncio.sleep(base_delay * (2**(attempt - 1)))
|
|
else:
|
|
logger.error(
|
|
f"All {max_retries} attempts failed for streaming {endpoint}."
|
|
)
|
|
raise e
|
|
|
|
|
|
async def _handle_select_instance(api: str, req_data: Any,
|
|
request_length: int):
|
|
prefiller_score = proxy_state.calculate_prefill_scores(request_length)
|
|
logger.debug(
|
|
f"Request length: {request_length}, Prefiller score: {prefiller_score}"
|
|
)
|
|
request_id = await proxy_state.next_req_id()
|
|
# Select prefiller
|
|
prefiller_idx = proxy_state.select_prefiller(prefiller_score)
|
|
prefiller = proxy_state.prefillers[prefiller_idx]
|
|
# Send request to prefiller
|
|
response = await send_request_to_service(
|
|
prefiller.client,
|
|
prefiller_idx,
|
|
api,
|
|
req_data,
|
|
request_id,
|
|
max_retries=global_args.max_retries,
|
|
base_delay=global_args.retry_delay)
|
|
proxy_state.release_prefiller(prefiller_idx, prefiller_score)
|
|
response_json = response.json()
|
|
kv_transfer_params = response_json.get('kv_transfer_params', {})
|
|
if kv_transfer_params:
|
|
req_data["kv_transfer_params"] = kv_transfer_params
|
|
# Select decoder
|
|
decoder_score = proxy_state.calculate_decode_scores(request_length)
|
|
logger.debug("Decoder score: %f", decoder_score)
|
|
# Use the prefiller's kv_transfer_params to select decoder
|
|
decoder_idx = proxy_state.select_decoder(decoder_score)
|
|
decoder = proxy_state.decoders[decoder_idx]
|
|
logger.debug("Using %s %s", prefiller.url, decoder.url)
|
|
return InstanceInfo(request_id=request_id,
|
|
prefiller_idx=prefiller_idx,
|
|
prefiller_score=prefiller_score,
|
|
prefiller=prefiller,
|
|
decoder=decoder,
|
|
decoder_idx=decoder_idx,
|
|
decoder_score=decoder_score)
|
|
|
|
|
|
@dataclass
|
|
class InstanceInfo:
|
|
request_id: str
|
|
prefiller_idx: int
|
|
prefiller_score: float
|
|
prefiller: ServerState
|
|
decoder_idx: int
|
|
decoder_score: float
|
|
decoder: ServerState
|
|
|
|
|
|
async def _handle_completions(api: str, request: Request):
|
|
try:
|
|
req_data = await request.json()
|
|
req_body = await request.body()
|
|
request_length = len(req_body)
|
|
instance_info = await _handle_select_instance(api, req_data,
|
|
request_length)
|
|
stream_flag = bool(req_data.get("stream", False))
|
|
chat_flag = "messages" in req_data
|
|
|
|
if "prompt" in req_data:
|
|
origin_prompt = req_data["prompt"]
|
|
elif chat_flag:
|
|
messages = req_data["messages"]
|
|
origin_prompt = messages[0].get("content", "")
|
|
else:
|
|
origin_prompt = ""
|
|
# refer to vLLM sampling_params: max_token default value
|
|
origin_max_tokens = req_data.get("max_tokens", 16)
|
|
|
|
async def generate_stream():
|
|
nonlocal instance_info
|
|
generated_token = ""
|
|
released_kv = False
|
|
retry_count = 0
|
|
retry = True
|
|
completion_tokens = 0
|
|
# Only one await per chunk, minimal logic in loop
|
|
try:
|
|
while retry:
|
|
retry = False
|
|
async for chunk in stream_service_response_with_retry(
|
|
instance_info.decoder.client,
|
|
api,
|
|
req_data,
|
|
request_id=instance_info.request_id,
|
|
max_retries=global_args.max_retries,
|
|
base_delay=global_args.retry_delay):
|
|
if not released_kv and chunk:
|
|
proxy_state.release_prefiller_kv(
|
|
instance_info.prefiller_idx,
|
|
instance_info.prefiller_score)
|
|
released_kv = True
|
|
try:
|
|
chunk_str = chunk.decode("utf-8").strip()
|
|
except UnicodeDecodeError:
|
|
logger.debug(
|
|
f"Skipping chunk: {chunk}")
|
|
yield chunk
|
|
continue
|
|
if not chunk_str:
|
|
continue
|
|
if chunk_str.startswith("data: "):
|
|
chunk_str = chunk_str[len("data: "):]
|
|
try:
|
|
chunk_json = json.loads(chunk_str)
|
|
except json.JSONDecodeError:
|
|
# if chunk is [done], skip it.
|
|
logger.debug(
|
|
f"Skipping chunk: {chunk_str}")
|
|
yield chunk
|
|
continue
|
|
choices = chunk_json.get("choices", [])
|
|
if not choices:
|
|
yield chunk
|
|
continue
|
|
|
|
choice = choices[0]
|
|
delta = choice.get("delta") or {}
|
|
message = choice.get("message") or {}
|
|
content = (
|
|
delta.get("content")
|
|
or message.get("content")
|
|
or choice.get("text")
|
|
or ""
|
|
)
|
|
generated_token += content
|
|
|
|
stop_reason = choice.get(
|
|
"stop_reason")
|
|
usage = chunk_json.get("usage", {})
|
|
completion_tokens = (completion_tokens + 1) if stream_flag else \
|
|
(completion_tokens + usage.get("completion_tokens"))
|
|
if stop_reason == "recomputed":
|
|
retry = True
|
|
retry_count += 1
|
|
if chat_flag:
|
|
messages[0][
|
|
"content"] = origin_prompt + generated_token
|
|
else:
|
|
req_data[
|
|
"prompt"] = origin_prompt + generated_token
|
|
req_data[
|
|
"max_tokens"] = origin_max_tokens - completion_tokens + retry_count
|
|
tmp_request_length = len(
|
|
json.dumps(req_data).encode("utf-8"))
|
|
instance_info = await _handle_select_instance(
|
|
api, req_data, tmp_request_length)
|
|
break
|
|
if retry_count > 0 and not stream_flag:
|
|
if chat_flag:
|
|
choice["message"][
|
|
"content"] = generated_token
|
|
else:
|
|
choice["text"] = generated_token
|
|
chunk = json.dumps(chunk_json).encode("utf-8")
|
|
yield chunk
|
|
except Exception as e:
|
|
logger.error(
|
|
f"Error during streaming from decoder {instance_info.decoder.url}: {str(e)} the aborted request {instance_info.request_id} will be routing to the target prefiller when new request is ready to dispatch to it"
|
|
)
|
|
proxy_state.abort_prefiller_request(
|
|
instance_info.prefiller_idx, instance_info.request_id)
|
|
proxy_state.release_prefiller_kv(instance_info.prefiller_idx,
|
|
instance_info.prefiller_score)
|
|
|
|
# After streaming done, release tokens
|
|
proxy_state.release_decoder(instance_info.decoder_idx,
|
|
instance_info.decoder_score)
|
|
|
|
return StreamingResponse(generate_stream(),
|
|
media_type="application/json")
|
|
except Exception as e:
|
|
import traceback
|
|
exc_info = sys.exc_info()
|
|
print("Error occurred in disagg prefill proxy server"
|
|
f" - {api} endpoint")
|
|
print(e)
|
|
print("".join(traceback.format_exception(*exc_info)))
|
|
raise
|
|
|
|
|
|
async def _handle_adjust_instances(adjust_mode: str, request: Request):
|
|
try:
|
|
req_data = await request.json()
|
|
instance_type = req_data.get("type", "")
|
|
instances = req_data.get("instances", [])
|
|
if isinstance(instances, str):
|
|
instances = [instances]
|
|
instances = trans_instances(instances)
|
|
all_msg = f"{adjust_mode} {instance_type} instances: " \
|
|
f"{[str(server) for server in instances]}."
|
|
|
|
if instance_type not in [InstanceType.PREFILL, InstanceType.DECODE]:
|
|
return {"error": f"Instance type {instance_type} is not supported. "
|
|
f"Only support '{InstanceType.PREFILL}' and '{InstanceType.DECODE}'."}
|
|
|
|
if adjust_mode == "add":
|
|
added_nodes, waiting_nodes = await proxy_state.add_instances(
|
|
instance_type, instances
|
|
)
|
|
if waiting_nodes:
|
|
all_msg = f"{adjust_mode} {instance_type} instances: {added_nodes}. " \
|
|
f"Instances {waiting_nodes} are waiting to be added."
|
|
elif adjust_mode == "remove":
|
|
if instance_type == InstanceType.PREFILL:
|
|
proxy_state.remove_prefillers(instances)
|
|
else:
|
|
proxy_state.remove_decoders(instances)
|
|
return {
|
|
"message": all_msg,
|
|
"current_prefill_instances": [str(prefiller) for prefiller in proxy_state.prefillers],
|
|
"current_decode_instances": [str(decoder) for decoder in proxy_state.decoders]
|
|
}
|
|
except Exception as e:
|
|
logger.error(f"Failed to {adjust_mode} instances: {e}")
|
|
raise e
|
|
|
|
|
|
def trans_instances(instances: List[str]) -> List[ServerState]:
|
|
server_list = []
|
|
for instance in instances:
|
|
h, p = instance.split(":")
|
|
server_list.append(ServerState(h, int(p)))
|
|
return server_list
|
|
|
|
|
|
@app.post("/v1/completions")
|
|
@with_cancellation
|
|
async def handle_completions(request: Request):
|
|
return await _handle_completions("/completions", request)
|
|
|
|
|
|
@app.post("/v1/chat/completions")
|
|
@with_cancellation
|
|
async def handle_chat_completions(request: Request):
|
|
return await _handle_completions("/chat/completions", request)
|
|
|
|
|
|
@app.get("/healthcheck")
|
|
async def healthcheck():
|
|
return {
|
|
"status": "ok",
|
|
"prefill_instances": len(proxy_state.prefillers),
|
|
"decode_instances": len(proxy_state.decoders)
|
|
}
|
|
|
|
|
|
@app.post("/instances/add")
|
|
async def handle_add_instances(request: Request):
|
|
return await _handle_adjust_instances("add", request)
|
|
|
|
|
|
@app.post("/instances/remove")
|
|
async def handle_remove_instances(request: Request):
|
|
return await _handle_adjust_instances("remove", request)
|
|
|
|
|
|
if __name__ == '__main__':
|
|
global global_args
|
|
global_args = parse_args()
|
|
import uvicorn
|
|
|
|
uvicorn.run(app, host=global_args.host, port=global_args.port)
|