[feat] proxy support elastic scaling (#5063)
**[RFC]: Elastic Scaling Support for P/D Instances Based on KV Pool:**
https://github.com/vllm-project/vllm-ascend/issues/3380
### What this PR does / why we need it?
Support elastic scaling for P/D instances based on Mooncake connector
deployment.
**Support API routes**
* `/instances/add`: add prefill or decode nodes to the proxy's instance list.
* `/instances/remove`: remove prefill or decode nodes from the proxy's
instance list.
**Support functions**
* Support **adding** prefill nodes or decode nodes.
  - If a prefill or decode server is deployed **after the proxy has started**,
it can join the proxy through the `/instances/add` API: the server sends a
registration request to the proxy, and the proxy checks the node's status
until the node is available (see the sketch after this list).
* Support **removing** prefill nodes or decode nodes.
  - Use the `/instances/remove` API to **delete a node** from the proxy
server.
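For example, a freshly started prefill server (or an operator script) could register the node with a running proxy like this. This is a minimal sketch, assuming the proxy from the example below listens on `localhost:9000`; the `register_with_proxy` helper is illustrative, not part of this PR:
```python
import httpx

def register_with_proxy(proxy_url: str, instance_type: str, address: str) -> dict:
    """Ask the proxy to add one node; the proxy keeps checking the node until it is up."""
    resp = httpx.post(
        f"{proxy_url}/instances/add",
        json={"type": instance_type, "instances": [address]},
        timeout=30.0,
    )
    resp.raise_for_status()
    return resp.json()

# Register a prefill server listening on 127.0.0.1:8102:
print(register_with_proxy("http://localhost:9000", "prefill", "127.0.0.1:8102"))
```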
### Does this PR introduce _any_ user-facing change?
For
`examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py`:
**Add 2 params**
When adding nodes to the proxy, the proxy waits for the nodes to start,
retrying up to a configured number of times.

| name | type | default | help |
| ----- | ---- | ---- | ---- |
| max-waiting-retries | int | 3 | Maximum number of retries while waiting for nodes to start |
| waiting-retry-interval | float | 10 | Check interval (seconds) while waiting for nodes to start |
For example:
```shell
python load_balance_proxy_server_example.py \
--host 0.0.0.0 --port 9000 \
--prefiller-hosts 127.0.0.1 127.0.0.1 \
--prefiller-ports 8100 8101 \
--decoder-hosts 127.0.0.1 127.0.0.1 \
--decoder-ports 8200 8201 \
--max-waiting-retries 3 \
--waiting-retry-interval 10
```
**Add 2 API routes**
* Add instances: `/instances/add`
For example, add 2 prefiller instances:
```shell
curl -X POST http://localhost:9000/instances/add \
-H "Content-Type: application/json" \
-d '{
"type": "prefill",
"instances": ["127.0.0.1:8102", "127.0.0.1:8103"]
}'
```
Response:
```json
{"message": "add prefill instances: ['127.0.0.1:8102', '127.0.0.1:8103'].", "current_prefill_instances": ["127.0.0.1:8100", "127.0.0.1:8101", "127.0.0.1:8102", "127.0.0.1:8103"], "current_decode_instances": ["127.0.0.1:8200", "127.0.0.1:8201"]}
```
If the node '127.0.0.1:8103' has not been started:
```json
{"message": "add prefill instances: ['127.0.0.1:8102']. Instances ['127.0.0.1:8103'] are waiting to be added.", "current_prefill_instances": ["127.0.0.1:8100", "127.0.0.1:8101", "127.0.0.1:8102"], "current_decode_instances": ["127.0.0.1:8200", "127.0.0.1:8201"]}
```
* Remove instances: `/instances/remove`
For example, remove 1 decoder instance:
```shell
curl -X POST http://localhost:9000/instances/remove \
-H "Content-Type: application/json" \
-d '{
"type": "decode",
"instances": "127.0.0.1:8201"
}'
```
Response:
```json
{"message": "remove decode instances: ['127.0.0.1:8201'].", "current_prefill_instances": ["127.0.0.1:8100", "127.0.0.1:8101"], "current_decode_instances": ["127.0.0.1:8200"]}
```
### How was this patch tested?
Ran the proxy, then used the `/instances/add` API to add nodes and the
`/instances/remove` API to remove them.
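A quick smoke test of that flow might look like this (illustrative sketch only; it assumes the proxy from the example above is running on `localhost:9000` with the instances listed there):
```python
import httpx

BASE = "http://localhost:9000"

# Add two prefill nodes; any node that is not reachable yet goes to the waiting list.
resp = httpx.post(f"{BASE}/instances/add",
                  json={"type": "prefill",
                        "instances": ["127.0.0.1:8102", "127.0.0.1:8103"]})
print(resp.json()["message"])

# Remove one decode node and verify it is gone from the returned list.
resp = httpx.post(f"{BASE}/instances/remove",
                  json={"type": "decode", "instances": "127.0.0.1:8201"})
assert "127.0.0.1:8201" not in resp.json()["current_decode_instances"]
```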
* vLLM version: v0.11.0.rc3
* vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0.rc3
* vLLM version: v0.12.0
* vLLM main: ad32e3e19c
Signed-off-by: yuxinshan <syx_ctyg@126.com>
Signed-off-by: CalvinXKY <kyxiezju@163.com>
@@ -77,6 +77,34 @@
 # This will return a JSON object with the status and the number of prefiller
 # and decoder instances.
 #
+# Step 5: Add or Remove Prefiller or Decoder Instances (Optional)
+# ---------------------------------------------------------------
+# You can add or remove prefiller or decoder instances after the proxy is started.
+# For example, add 2 prefiller instances:
+#
+# curl -X POST http://localhost:9000/instances/add \
+#      -H "Content-Type: application/json" \
+#      -d '{
+#            "type": "prefill",
+#            "instances": ["127.0.0.1:8102", "127.0.0.1:8103"]
+#          }'
+#
+# or remove 1 decoder instance:
+#
+# curl -X POST http://localhost:9000/instances/remove \
+#      -H "Content-Type: application/json" \
+#      -d '{
+#            "type": "decode",
+#            "instances": "127.0.0.1:8201"
+#          }'
+#
+# This will return a JSON object with the adding or removing info
+# and the current prefiller and decoder instances.
+#
+# When adding instances, if an instance has not started yet, the proxy
+# will keep checking it until it starts or the maximum number of
+# attempts is exceeded.
+#
 # Notes:
 # - You can scale the number of prefiller and decoder servers as needed.
 # - The proxy will round-robin requests to balance load.
@@ -92,10 +120,12 @@ import ipaddress
 import json
 import os
 import sys
+import threading
+import time
 import uuid
 from contextlib import asynccontextmanager
 from dataclasses import dataclass
-from typing import Any, List
+from typing import Any, List, Tuple, Dict
 
 import httpx
 from fastapi import FastAPI, Request
@@ -113,6 +143,12 @@ except ImportError:
     pass
 
 
+@dataclass
+class InstanceType:
+    PREFILL: str = "prefill"
+    DECODE: str = "decode"
+
+
 class ServerState:
 
     def __init__(self, host, port):
@@ -136,10 +172,24 @@ class ServerState:
         self.aborted_requests = set()  # Track aborted requests
         # Removed individual server lock - will use global locks instead
 
+    def __eq__(self, other):
+        self_host = self.host.replace("localhost", "0.0.0.0").replace("127.0.0.1", "0.0.0.0")
+        other_host = other.host.replace("localhost", "0.0.0.0").replace("127.0.0.1", "0.0.0.0")
+        return self_host == other_host and str(self.port) == str(other.port)
+
+    def __hash__(self):
+        self_host = self.host.replace("localhost", "0.0.0.0").replace("127.0.0.1", "0.0.0.0")
+        return hash((self_host, str(self.port)))
+
+    def __repr__(self):
+        return f"{self.host}:{self.port}"
+
 
 class ProxyState:
 
     def __init__(self, prefiller_instances, decoder_instances):
+        self.node_listener = NodeListener(self)
+
         self.prefillers: List[ServerState] = [
             ServerState(h, p) for h, p in prefiller_instances
         ]
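The `__eq__`/`__hash__` pair in this hunk treats loopback aliases as the same host, so re-adding a node under a different spelling is deduplicated. A standalone sketch of the normalization rule (not code from the patch):
```python
def normalize_host(host: str) -> str:
    # "localhost" and "127.0.0.1" both collapse to "0.0.0.0", matching the patch above.
    return host.replace("localhost", "0.0.0.0").replace("127.0.0.1", "0.0.0.0")

assert normalize_host("localhost") == normalize_host("127.0.0.1") == "0.0.0.0"
# Hence ServerState("localhost", 8100) == ServerState("127.0.0.1", 8100) after this patch.
```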
@@ -264,10 +314,127 @@ class ProxyState:
     def calculate_decode_scores(self, request_length: int) -> float:
         return request_length
 
+    async def add_instances(
+            self, instance_type: str, instances: List[ServerState]
+    ) -> Tuple[List[str], List[str]]:
+        added_nodes, waiting_nodes = [], []
+        for server in instances:
+            is_valid = await self.node_listener.check_instance_status(server.client)
+            if is_valid and instance_type == InstanceType.PREFILL:
+                self.add_prefillers([server])
+                added_nodes.append(str(server))
+            elif is_valid and instance_type == InstanceType.DECODE:
+                self.add_decoders([server])
+                added_nodes.append(str(server))
+            else:
+                node = str(server)
+                self.node_listener.waiting_nodes[node] = (instance_type, server, 0)
+                waiting_nodes.append(node)
+        return added_nodes, waiting_nodes
+
+    def add_prefillers(self, instances: List[ServerState]) -> None:
+        num_prefillers = len(self.prefillers)
+        for idx, server in enumerate(instances):
+            if server not in self.prefillers:
+                self.prefillers.append(server)
+                # prefiller_heap: [(priority_0, 0, server_0)] -> [(priority_0, 0, server_0), (0, 1, server_1)]
+                heapq.heappush(self.prefiller_heap, (0, num_prefillers + idx, server))
+        self.print_status(f"Add prefiller instances: {instances}.")
+
+    def add_decoders(self, instances: List[ServerState]) -> None:
+        num_decoders = len(self.decoders)
+        for idx, server in enumerate(instances):
+            if server not in self.decoders:
+                self.decoders.append(server)
+                # decoder_heap: [(priority_0, 0, server_0)] -> [(priority_0, 0, server_0), (0, 1, server_1)]
+                heapq.heappush(self.decoder_heap, (0, num_decoders + idx, server))
+        self.print_status(f"Add decoder instances: {instances}.")
+
+    def remove_prefillers(self, instances: List[ServerState]) -> None:
+        instances_to_remove = set(instances)
+        self.prefillers = [server for server in self.prefillers if server not in instances_to_remove]
+        prefiller_heap_copy = self.prefiller_heap.copy()
+        prefiller_heap_copy.sort(key=lambda x: x[1])  # sorted by key: prefiller_idx
+        prefiller_heap = []
+        idx = 0
+        for priority, _, server in prefiller_heap_copy:
+            if server not in instances_to_remove:
+                prefiller_heap.append((priority, idx, server))
+                idx += 1
+
+        # prefiller_heap: [(priority_0, 0, server_0), (priority_1, 1, server_1)] -> [(priority_1, 0, server_1)]
+        self.prefiller_heap = prefiller_heap
+        heapq.heapify(self.prefiller_heap)
+        self.print_status(f"Remove prefiller instances: {instances}.")
+
+    def remove_decoders(self, instances: List[ServerState]) -> None:
+        instances_to_remove = set(instances)
+        self.decoders = [server for server in self.decoders if server not in instances_to_remove]
+        decoder_heap_copy = self.decoder_heap.copy()
+        decoder_heap_copy.sort(key=lambda x: x[1])  # sorted by key: decoder_idx
+        decoder_heap = []
+        idx = 0
+        for priority, _, server in decoder_heap_copy:
+            if server not in instances_to_remove:
+                decoder_heap.append((priority, idx, server))
+                idx += 1
+
+        # decoder_heap: [(priority_0, 0, server_0), (priority_1, 1, server_1)] -> [(priority_1, 0, server_1)]
+        self.decoder_heap = decoder_heap
+        heapq.heapify(self.decoder_heap)
+        self.print_status(f"Remove decoder instances: {instances}.")
+
+    def print_status(self, msg: str) -> None:
+        status = {
+            "prefill_instances": [str(server) for server in self.prefillers],
+            "decode_instances": [str(server) for server in self.decoders]
+        }
+        print(f"{msg} Status: {status}")
+
 
 proxy_state = None
 
+
+class NodeListener:
+
+    def __init__(self, proxy):
+        self.proxy_state = proxy
+        self.waiting_nodes: Dict[str, Tuple[str, Any, int]] = {}
+        self.listening_thread = threading.Thread(target=self._node_listener, daemon=True)
+        self.listening_thread.start()
+
+    def _node_listener(self) -> None:
+        while True:
+            for node, (instance_type, server, check_times) in list(self.waiting_nodes.items()):
+                is_valid = asyncio.run(self.check_instance_status(server.client))
+                print(f"Checking instance {node}...")
+                check_times += 1
+                if is_valid:
+                    if instance_type == InstanceType.PREFILL:
+                        self.proxy_state.add_prefillers([server])
+                    else:
+                        self.proxy_state.add_decoders([server])
+                    self.waiting_nodes.pop(node)
+                elif check_times == global_args.max_waiting_retries:
+                    print(f"Instance {node} was not added to the proxy.")
+                    self.waiting_nodes.pop(node)
+                else:
+                    self.waiting_nodes[node] = (instance_type, server, check_times)
+            time.sleep(global_args.waiting_retry_interval)
+
+    @staticmethod
+    async def check_instance_status(client: httpx.AsyncClient) -> bool:
+        endpoint = "/models"
+        headers = {
+            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
+        }
+        try:
+            response = await client.get(endpoint, headers=headers)
+            response.raise_for_status()
+            return True
+        except (httpx.RequestError, httpx.HTTPStatusError):
+            return False
+
+
 def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--port", type=int, default=8000)
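The removal path above filters the heap and re-assigns dense indices before `heapify`, which keeps the `(priority, index, server)` tuples comparable and the tie-breaking indices unique for future pushes. The same pattern in isolation (standalone sketch, not code from the patch):
```python
import heapq

heap = [(0, 0, "s0"), (2, 1, "s1"), (1, 2, "s2")]  # (priority, index, server)
to_remove = {"s1"}

rebuilt, idx = [], 0
for priority, _, server in sorted(heap, key=lambda x: x[1]):  # walk in index order
    if server not in to_remove:
        rebuilt.append((priority, idx, server))  # re-index densely: 0, 1, 2, ...
        idx += 1
heapq.heapify(rebuilt)
print(rebuilt)  # [(0, 0, 's0'), (1, 1, 's2')]
```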
@@ -294,6 +461,15 @@ def parse_args():
                         type=float,
                         default=0.001,
                         help="Base delay (seconds) for exponential backoff retries")
+    parser.add_argument("--max-waiting-retries",
+                        type=int,
+                        default=3,
+                        help="Maximum number of retries for waiting nodes to be started")
+    parser.add_argument(
+        "--waiting-retry-interval",
+        type=float,
+        default=10,
+        help="Check interval (seconds) for waiting nodes to be started")
     args = parser.parse_args()
     if len(args.prefiller_hosts) != len(args.prefiller_ports):
         raise ValueError(
@@ -637,6 +813,51 @@ async def _handle_completions(api: str, request: Request):
         raise
 
 
+async def _handle_adjust_instances(adjust_mode: str, request: Request):
+    try:
+        req_data = await request.json()
+        instance_type = req_data.get("type", "")
+        instances = req_data.get("instances", [])
+        if isinstance(instances, str):
+            instances = [instances]
+        instances = trans_instances(instances)
+        all_msg = f"{adjust_mode} {instance_type} instances: " \
+                  f"{[str(server) for server in instances]}."
+
+        if instance_type not in [InstanceType.PREFILL, InstanceType.DECODE]:
+            return {"error": f"Instance type {instance_type} is not supported. "
+                             f"Only support '{InstanceType.PREFILL}' and '{InstanceType.DECODE}'."}
+
+        if adjust_mode == "add":
+            added_nodes, waiting_nodes = await proxy_state.add_instances(
+                instance_type, instances
+            )
+            if waiting_nodes:
+                all_msg = f"{adjust_mode} {instance_type} instances: {added_nodes}. " \
+                          f"Instances {waiting_nodes} are waiting to be added."
+        elif adjust_mode == "remove":
+            if instance_type == InstanceType.PREFILL:
+                proxy_state.remove_prefillers(instances)
+            else:
+                proxy_state.remove_decoders(instances)
+        return {
+            "message": all_msg,
+            "current_prefill_instances": [str(prefiller) for prefiller in proxy_state.prefillers],
+            "current_decode_instances": [str(decoder) for decoder in proxy_state.decoders]
+        }
+    except Exception as e:
+        logger.error(f"Failed to {adjust_mode} instances: {e}")
+        raise e
+
+
+def trans_instances(instances: List[str]) -> List[ServerState]:
+    server_list = []
+    for instance in instances:
+        h, p = instance.split(":")
+        server_list.append(ServerState(h, int(p)))
+    return server_list
+
+
 @app.post("/v1/completions")
 @with_cancellation
 async def handle_completions(request: Request):
@@ -658,6 +879,16 @@ async def healthcheck():
     }
 
 
+@app.post("/instances/add")
+async def handle_add_instances(request: Request):
+    return await _handle_adjust_instances("add", request)
+
+
+@app.post("/instances/remove")
+async def handle_remove_instances(request: Request):
+    return await _handle_adjust_instances("remove", request)
+
+
 if __name__ == '__main__':
     global global_args
     global_args = parse_args()