[feat] proxy support elastic scaling (#5063)

**[RFC]: Elastic Scaling Support for P/D Instances Based on KV Pool:**
https://github.com/vllm-project/vllm-ascend/issues/3380

### What this PR does / why we need it?
Support elastic scaling for P/D instances based on the Mooncake connector
deployment.

**Supported API routes**
* `/instances/add`: add prefill nodes or decode nodes to the list.
* `/instances/remove`: remove prefill nodes or decode nodes from the
list.

**Supported functions**
* Support **adding** prefill nodes or decode nodes:
  - If a prefill or decode server is deployed **after the proxy is deployed**,
it can use the `/instances/add` API to join the proxy. The prefill or decode
server sends a request to the proxy, and the proxy checks the node's status
until the node is available (see the registration sketch after this list).
* Support **removing** prefill nodes or decode nodes:
  - Use the `/instances/remove` API to **delete the node** from the proxy
server.
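As a rough illustration of the add flow, a newly started prefill server (or an external orchestrator) could register itself with the proxy as in the following minimal sketch. The proxy address `localhost:9000`, the instance address `127.0.0.1:8102`, and the `register_with_proxy` helper are illustrative assumptions; only the `/instances/add` route and its JSON payload come from this PR.

```python
# Hypothetical registration helper (not part of this PR): a freshly started
# prefill server announces itself to the proxy via the /instances/add route.
import httpx

PROXY_URL = "http://localhost:9000"  # assumed proxy address


def register_with_proxy(instance: str, instance_type: str = "prefill") -> dict:
    # The proxy checks the node's status; if the node is not reachable yet,
    # it keeps the node in a waiting list and retries in the background.
    resp = httpx.post(
        f"{PROXY_URL}/instances/add",
        json={"type": instance_type, "instances": [instance]},
        timeout=30.0,
    )
    resp.raise_for_status()
    return resp.json()


if __name__ == "__main__":
    print(register_with_proxy("127.0.0.1:8102"))
```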

### Does this PR introduce _any_ user-facing change?
For
`examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py`:

**Adds 2 parameters**

When nodes are added to the proxy, the proxy waits for them to start,
retrying up to a configured number of times.

| name | type | default | help |
| ----- | ---- | ---- | ---- |
| max-waiting-retries | int | 3 | Maximum number of retries for waiting nodes to be started |
| waiting-retry-interval | float | 10 | Check interval (seconds) for waiting nodes to be started |

For example:
```shell
python load_balance_proxy_server_example.py \
  --host 0.0.0.0 --port 9000 \
  --prefiller-hosts 127.0.0.1 127.0.0.1 \
  --prefiller-ports 8100 8101 \
  --decoder-hosts 127.0.0.1 127.0.0.1 \
  --decoder-ports 8200 8201 \
  --max-waiting-retries 3 \
  --waiting-retry-interval 10
```
**Adds 2 API routes**

* Add instances: `/instances/add`

For example, add 2 prefiller instances:
```shell
curl -X POST http://localhost:9000/instances/add \
  -H "Content-Type: application/json" \
  -d '{
        "type": "prefill",
        "instances": ["127.0.0.1:8102", "127.0.0.1:8103"]
      }'
```
Response:
```json
{"message": "add prefill instances: ['127.0.0.1:8102', '127.0.0.1:8103'].", "current_prefill_instances": ["127.0.0.1:8100", "127.0.0.1:8101", "127.0.0.1:8102", "127.0.0.1:8103"], "current_decode_instances": ["127.0.0.1:8200", "127.0.0.1:8201"]}
```
If the node '127.0.0.1:8103' has not been started:
```json
{"message": "add prefill instances: ['127.0.0.1:8102']. Instances ['127.0.0.1:8103'] are waiting to be added.", "current_prefill_instances": ["127.0.0.1:8100", "127.0.0.1:8101", "127.0.0.1:8102"], "current_decode_instances": ["127.0.0.1:8200", "127.0.0.1:8201"]}
```
* Remove instances: `/instances/remove`

For example, remove 1 decoder instance:
```shell
curl -X POST http://localhost:9000/instances/remove \
  -H "Content-Type: application/json" \
  -d '{
        "type": "decode",
        "instances": "127.0.0.1:8201"
      }'
```
Response:
```json
{"message": "remove decode instances: ['127.0.0.1:8201'].", "current_prefill_instances": ["127.0.0.1:8100", "127.0.0.1:8101"], "current_decode_instances": ["127.0.0.1:8200"]}
```
### How was this patch tested?
Run the proxy, then use the `/instances/add` API to add nodes and the
`/instances/remove` API to remove nodes, as in the sketch below.
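
For reference, a minimal manual smoke test along these lines can exercise both routes; it assumes the proxy from the example above is running on `localhost:9000`, that a prefill server is already up on `127.0.0.1:8102`, and that `127.0.0.1:8201` is a registered decode instance (all addresses and the `adjust` helper are illustrative).

```python
# Smoke-test sketch (not part of this PR): add a prefill instance, then remove
# a decode instance, checking the instance lists returned by the proxy.
import httpx

PROXY_URL = "http://localhost:9000"  # assumed proxy address


def adjust(route: str, instance_type: str, instances) -> dict:
    resp = httpx.post(
        f"{PROXY_URL}/instances/{route}",
        json={"type": instance_type, "instances": instances},
        timeout=30.0,
    )
    resp.raise_for_status()
    return resp.json()


# Add one prefill node. The assertion only holds if the node is already up;
# otherwise the proxy parks it in the waiting list and retries in the background.
result = adjust("add", "prefill", ["127.0.0.1:8102"])
assert "127.0.0.1:8102" in result["current_prefill_instances"]

# Remove one decode node and verify it no longer appears.
result = adjust("remove", "decode", "127.0.0.1:8201")
assert "127.0.0.1:8201" not in result["current_decode_instances"]
print("instance add/remove round-trip OK")
```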

* vLLM version: v0.11.0.rc3
* vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0.rc3
* vLLM version: v0.12.0
* vLLM main: ad32e3e19c

Signed-off-by: yuxinshan <syx_ctyg@126.com>
Signed-off-by: CalvinXKY <kyxiezju@163.com>

@@ -77,6 +77,34 @@
# This will return a JSON object with the status and the number of prefiller
# and decoder instances.
#
# Step 5: Add or Remove Prefiller or Decoder Instances (Optional)
# ---------------------------------------------------------------
# You can add or remove prefiller or decoder instances after the proxy is started.
# For example, add 2 prefiller instances:
#
# curl -X POST http://localhost:9000/instances/add \
# -H "Content-Type: application/json" \
# -d '{
# "type": "prefill",
# "instances": ["127.0.0.1:8102", "127.0.0.1:8103"]
# }'
#
# or remove 1 decoder instance:
#
# curl -X POST http://localhost:9000/instances/remove \
# -H "Content-Type: application/json" \
# -d '{
# "type": "decode",
# "instances": "127.0.0.1:8201"
# }'
#
# This will return a JSON object with the add/remove result
# and the current prefiller and decoder instances.
#
# When adding instances, if an instance has not started yet,
# the proxy will wait and retry until the instance starts
# or the maximum number of attempts is exceeded.
#
# Notes:
# - You can scale the number of prefiller and decoder servers as needed.
# - The proxy will round-robin requests to balance load.
@@ -92,10 +120,12 @@ import ipaddress
import json
import os
import sys
import threading
import time
import uuid
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, List
from typing import Any, List, Tuple, Dict
import httpx
from fastapi import FastAPI, Request
@@ -113,6 +143,12 @@ except ImportError:
    pass


@dataclass
class InstanceType:
    PREFILL: str = "prefill"
    DECODE: str = "decode"


class ServerState:
    def __init__(self, host, port):
@@ -136,10 +172,24 @@ class ServerState:
        self.aborted_requests = set()  # Track aborted requests
        # Removed individual server lock - will use global locks instead

    def __eq__(self, other):
        self_host = self.host.replace("localhost", "0.0.0.0").replace("127.0.0.1", "0.0.0.0")
        other_host = other.host.replace("localhost", "0.0.0.0").replace("127.0.0.1", "0.0.0.0")
        return self_host == other_host and str(self.port) == str(other.port)

    def __hash__(self):
        self_host = self.host.replace("localhost", "0.0.0.0").replace("127.0.0.1", "0.0.0.0")
        return hash((self_host, str(self.port)))

    def __repr__(self):
        return f"{self.host}:{self.port}"


class ProxyState:
    def __init__(self, prefiller_instances, decoder_instances):
        self.node_listener = NodeListener(self)
        self.prefillers: List[ServerState] = [
            ServerState(h, p) for h, p in prefiller_instances
        ]
@@ -264,10 +314,127 @@ class ProxyState:
    def calculate_decode_scores(self, request_length: int) -> float:
        return request_length

    async def add_instances(
            self, instance_type: str, instances: List[ServerState]
    ) -> Tuple[List[str], List[str]]:
        added_nodes, waiting_nodes = [], []
        for server in instances:
            is_valid = await self.node_listener.check_instance_status(server.client)
            if is_valid and instance_type == InstanceType.PREFILL:
                self.add_prefillers([server])
                added_nodes.append(str(server))
            elif is_valid and instance_type == InstanceType.DECODE:
                self.add_decoders([server])
                added_nodes.append(str(server))
            else:
                node = str(server)
                self.node_listener.waiting_nodes[node] = (instance_type, server, 0)
                waiting_nodes.append(node)
        return added_nodes, waiting_nodes

    def add_prefillers(self, instances: List[ServerState]) -> None:
        num_prefillers = len(self.prefillers)
        for idx, server in enumerate(instances):
            if server not in self.prefillers:
                self.prefillers.append(server)
                # prefiller_heap: [(priority_0, 0, server_0)] -> [(priority_0, 0, server_0), (0, 1, server_1)]
                heapq.heappush(self.prefiller_heap, (0, num_prefillers + idx, server))
        self.print_status(f"Add prefiller instances: {instances}.")

    def add_decoders(self, instances: List[ServerState]) -> None:
        num_decoders = len(self.decoders)
        for idx, server in enumerate(instances):
            if server not in self.decoders:
                self.decoders.append(server)
                # decoder_heap: [(priority_0, 0, server_0)] -> [(priority_0, 0, server_0), (0, 1, server_1)]
                heapq.heappush(self.decoder_heap, (0, num_decoders + idx, server))
        self.print_status(f"Add decoder instances: {instances}.")

    def remove_prefillers(self, instances: List[ServerState]) -> None:
        instances_to_remove = set(instances)
        self.prefillers = [server for server in self.prefillers if server not in instances_to_remove]
        prefiller_heap_copy = self.prefiller_heap.copy()
        prefiller_heap_copy.sort(key=lambda x: x[1])  # sorted by key: prefiller_idx
        prefiller_heap = []
        idx = 0
        for priority, _, server in prefiller_heap_copy:
            if server not in instances_to_remove:
                prefiller_heap.append((priority, idx, server))
                idx += 1
        # prefiller_heap: [(priority_0, 0, server_0), (priority_1, 1, server_1)] -> [(priority_1, 0, server_1)]
        self.prefiller_heap = prefiller_heap
        heapq.heapify(self.prefiller_heap)
        self.print_status(f"Remove prefiller instances: {instances}.")

    def remove_decoders(self, instances: List[ServerState]) -> None:
        instances_to_remove = set(instances)
        self.decoders = [server for server in self.decoders if server not in instances_to_remove]
        decoder_heap_copy = self.decoder_heap.copy()
        decoder_heap_copy.sort(key=lambda x: x[1])  # sorted by key: decoder_idx
        decoder_heap = []
        idx = 0
        for priority, _, server in decoder_heap_copy:
            if server not in instances_to_remove:
                decoder_heap.append((priority, idx, server))
                idx += 1
        # decoder_heap: [(priority_0, 0, server_0), (priority_1, 1, server_1)] -> [(priority_1, 0, server_1)]
        self.decoder_heap = decoder_heap
        heapq.heapify(self.decoder_heap)
        self.print_status(f"Remove decoder instances: {instances}.")

    def print_status(self, msg: str) -> None:
        status = {
            "prefill_instances": [str(server) for server in self.prefillers],
            "decode_instances": [str(server) for server in self.decoders]
        }
        print(f"{msg} Status: {status}")
proxy_state = None


class NodeListener:
    def __init__(self, proxy):
        self.proxy_state = proxy
        self.waiting_nodes: Dict[str, Tuple[str, Any, int]] = {}
        self.listening_thread = threading.Thread(target=self._node_listener, daemon=True)
        self.listening_thread.start()

    def _node_listener(self) -> None:
        while True:
            for node, (instance_type, server, check_times) in list(self.waiting_nodes.items()):
                is_valid = asyncio.run(self.check_instance_status(server.client))
                print(f"Checking instance {node}...")
                check_times += 1
                if is_valid:
                    if instance_type == InstanceType.PREFILL:
                        self.proxy_state.add_prefillers([server])
                    else:
                        self.proxy_state.add_decoders([server])
                    self.waiting_nodes.pop(node)
                elif check_times == global_args.max_waiting_retries:
                    print(f"Instance {node} was not added to the proxy.")
                    self.waiting_nodes.pop(node)
                else:
                    self.waiting_nodes[node] = (instance_type, server, check_times)
            time.sleep(global_args.waiting_retry_interval)

    @staticmethod
    async def check_instance_status(client: httpx.AsyncClient) -> bool:
        endpoint = "/models"
        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
        }
        try:
            response = await client.get(endpoint, headers=headers)
            response.raise_for_status()
            return True
        except (httpx.RequestError, httpx.HTTPStatusError):
            return False
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=8000)
@@ -294,6 +461,15 @@ def parse_args():
        type=float,
        default=0.001,
        help="Base delay (seconds) for exponential backoff retries")
    parser.add_argument("--max-waiting-retries",
                        type=int,
                        default=3,
                        help="Maximum number of retries for waiting nodes to be started")
    parser.add_argument(
        "--waiting-retry-interval",
        type=float,
        default=10,
        help="Check interval (seconds) for waiting nodes to be started")
    args = parser.parse_args()
    if len(args.prefiller_hosts) != len(args.prefiller_ports):
        raise ValueError(
@@ -637,6 +813,51 @@ async def _handle_completions(api: str, request: Request):
        raise


async def _handle_adjust_instances(adjust_mode: str, request: Request):
    try:
        req_data = await request.json()
        instance_type = req_data.get("type", "")
        instances = req_data.get("instances", [])
        if isinstance(instances, str):
            instances = [instances]
        instances = trans_instances(instances)
        all_msg = f"{adjust_mode} {instance_type} instances: " \
                  f"{[str(server) for server in instances]}."
        if instance_type not in [InstanceType.PREFILL, InstanceType.DECODE]:
            return {"error": f"Instance type {instance_type} is not supported. "
                             f"Only support '{InstanceType.PREFILL}' and '{InstanceType.DECODE}'."}
        if adjust_mode == "add":
            added_nodes, waiting_nodes = await proxy_state.add_instances(
                instance_type, instances
            )
            if waiting_nodes:
                all_msg = f"{adjust_mode} {instance_type} instances: {added_nodes}. " \
                          f"Instances {waiting_nodes} are waiting to be added."
        elif adjust_mode == "remove":
            if instance_type == InstanceType.PREFILL:
                proxy_state.remove_prefillers(instances)
            else:
                proxy_state.remove_decoders(instances)
        return {
            "message": all_msg,
            "current_prefill_instances": [str(prefiller) for prefiller in proxy_state.prefillers],
            "current_decode_instances": [str(decoder) for decoder in proxy_state.decoders]
        }
    except Exception as e:
        logger.error(f"Failed to {adjust_mode} instances: {e}")
        raise e


def trans_instances(instances: List[str]) -> List[ServerState]:
    server_list = []
    for instance in instances:
        h, p = instance.split(":")
        server_list.append(ServerState(h, int(p)))
    return server_list


@app.post("/v1/completions")
@with_cancellation
async def handle_completions(request: Request):
@@ -658,6 +879,16 @@ async def healthcheck():
    }


@app.post("/instances/add")
async def handle_add_instances(request: Request):
    return await _handle_adjust_instances("add", request)


@app.post("/instances/remove")
async def handle_remove_instances(request: Request):
    return await _handle_adjust_instances("remove", request)


if __name__ == '__main__':
    global global_args
    global_args = parse_args()