[feat] proxy support elastic scaling (#5063)
**[RFC]: Elastic Scaling Support for P/D Instances Based on KV Pool:**
https://github.com/vllm-project/vllm-ascend/issues/3380
### What this PR does / why we need it?
Support elastic scaling for P/D instances based on Mooncake connector
deployment.
**Support API routes**
* `/instances/add`: add prefill nodes or decode nodes to the list.
* `/instances/remove`: remove prefill nodes or decode nodes from the
list.
**Support functions**
* Support **adding** prefill nodes or decode nodes.
- If a prefill or decode server is deployed **after the proxy is deployed**,
the server can use the `/instances/add` API to join the proxy server. The
prefill server or decode server sends a signal to the proxy server, and
the proxy server will check the status of the node until the node is
available.
* Support **removing** prefill nodes or decode nodes:
- Support using `/instances/remove` API to **delete the node** from the
proxy server.
### Does this PR introduce _any_ user-facing change?
For
`examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py`:
**Add 2 params**
When adding nodes to the proxy, the proxy will wait for the nodes to
start, retrying up to a configured number of times.
| name | type | default | help |
| ----- | ---- | ---- | ---- |
| max-waiting-retries | int | 3 | Maximum number of retries for waiting
nodes to be started |
| waiting-retry-interval | float | 10 | Check interval (seconds) for
waiting nodes to be started |
For example:
```shell
python load_balance_proxy_server_example.py \
--host 0.0.0.0 --port 9000 \
--prefiller-hosts 127.0.0.1 127.0.0.1 \
--prefiller-ports 8100 8101 \
--decoder-hosts 127.0.0.1 127.0.0.1 \
--decoder-ports 8200 8201 \
--max-waiting-retries 3 \
--waiting-retry-interval 10
```
**Add 2 API routings**
* Add instances: `instances/add`
For example, add 2 prefiller instances:
```shell
curl -X POST http://localhost:9000/instances/add \
-H "Content-Type: application/json" \
-d '{
"type": "prefill",
"instances": ["127.0.0.1:8102", "127.0.0.1:8103"]
}'
```
Response:
```shell
{"message": "add prefill instances: ['127.0.0.1:8102', '127.0.0.1:8103'].", "current_prefill_instances": ['127.0.0.1:8100', '127.0.0.1:8101', '127.0.0.1:8102', '127.0.0.1:8103'], "current_decode_instances": ['127.0.0.1:8200', '127.0.0.1:8201']}
```
If the node '127.0.0.1:8103' has not been started:
```shell
{"message": "add prefill instances: ['127.0.0.1:8102']. Instances ['127.0.0.1:8103'] are waiting to be added.", "current_prefill_instances": ['127.0.0.1:8100', '127.0.0.1:8101', '127.0.0.1:8102'], "current_decode_instances": ['127.0.0.1:8200', '127.0.0.1:8201']}
```
* Remove instances: `instances/remove`
For example, remove 1 decoder instance:
```shell
curl -X POST http://localhost:9000/instances/remove \
-H "Content-Type: application/json" \
-d '{
"type": "decode",
"instances": "127.0.0.1:8201"
}'
```
Response:
```shell
{"message": "remove decode instances: ['127.0.0.1:8201'].", "current_prefill_instances": ['127.0.0.1:8100', '127.0.0.1:8101'], "current_decode_instances": ['127.0.0.1:8200']}
```
### How was this patch tested?
Ran the proxy and used the `/instances/add` API to add nodes and the
`/instances/remove` API to remove nodes.
* vLLM version: v0.11.0.rc3
* vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0.rc3
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
Signed-off-by: yuxinshan <syx_ctyg@126.com>
Signed-off-by: CalvinXKY <kyxiezju@163.com>
This commit is contained in:
@@ -77,6 +77,34 @@
|
|||||||
# This will return a JSON object with the status and the number of prefiller
|
# This will return a JSON object with the status and the number of prefiller
|
||||||
# and decoder instances.
|
# and decoder instances.
|
||||||
#
|
#
|
||||||
|
# Step 5: Add or Remove Prefiller or Decoder Instances (Optional)
|
||||||
|
# ---------------------------------------------------------------
|
||||||
|
# You can add or remove prefiller or decoder instances after the proxy is started.
|
||||||
|
# For example, add 2 prefiller instances:
|
||||||
|
#
|
||||||
|
# curl -X POST http://localhost:9000/instances/add \
|
||||||
|
# -H "Content-Type: application/json" \
|
||||||
|
# -d '{
|
||||||
|
# "type": "prefill",
|
||||||
|
# "instances": ["127.0.0.1:8102", "127.0.0.1:8103"]
|
||||||
|
# }'
|
||||||
|
#
|
||||||
|
# or remove 1 decoder instance:
|
||||||
|
#
|
||||||
|
# curl -X POST http://localhost:9000/instances/remove \
|
||||||
|
# -H "Content-Type: application/json" \
|
||||||
|
# -d '{
|
||||||
|
# "type": "decode",
|
||||||
|
# "instances": "127.0.0.1:8201"
|
||||||
|
# }'
|
||||||
|
#
|
||||||
|
# This will return a JSON object with the adding or removing info
|
||||||
|
# and the current prefiller and decoder instances.
|
||||||
|
#
|
||||||
|
# When adding instances, if the instances are not started,
|
||||||
|
# the proxy will wait and retry until the instances have started
|
||||||
|
# or until the maximum number of attempts is exceeded.
|
||||||
|
#
|
||||||
# Notes:
|
# Notes:
|
||||||
# - You can scale the number of prefiller and decoder servers as needed.
|
# - You can scale the number of prefiller and decoder servers as needed.
|
||||||
# - The proxy will round-robin requests to balance load.
|
# - The proxy will round-robin requests to balance load.
|
||||||
@@ -92,10 +120,12 @@ import ipaddress
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, List
|
from typing import Any, List, Tuple, Dict
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi import FastAPI, Request
|
from fastapi import FastAPI, Request
|
||||||
@@ -113,6 +143,12 @@ except ImportError:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class InstanceType:
    """Names of the two instance roles the proxy can manage.

    The values are the strings accepted in the ``"type"`` field of the
    ``/instances/add`` and ``/instances/remove`` request bodies.
    """

    # Role of an instance that serves the prefill stage.
    PREFILL: str = "prefill"
    # Role of an instance that serves the decode stage.
    DECODE: str = "decode"
|
||||||
|
|
||||||
|
|
||||||
class ServerState:
|
class ServerState:
|
||||||
|
|
||||||
def __init__(self, host, port):
|
def __init__(self, host, port):
|
||||||
@@ -136,10 +172,24 @@ class ServerState:
|
|||||||
self.aborted_requests = set() # Track aborted requests
|
self.aborted_requests = set() # Track aborted requests
|
||||||
# Removed individual server lock - will use global locks instead
|
# Removed individual server lock - will use global locks instead
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
self_host = self.host.replace("localhost", "0.0.0.0").replace("127.0.0.1", "0.0.0.0")
|
||||||
|
other_host = other.host.replace("localhost", "0.0.0.0").replace("127.0.0.1", "0.0.0.0")
|
||||||
|
return self_host == other_host and str(self.port) == str(other.port)
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
self_host = self.host.replace("localhost", "0.0.0.0").replace("127.0.0.1", "0.0.0.0")
|
||||||
|
return hash((self_host, str(self.port)))
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"{self.host}:{self.port}"
|
||||||
|
|
||||||
|
|
||||||
class ProxyState:
|
class ProxyState:
|
||||||
|
|
||||||
def __init__(self, prefiller_instances, decoder_instances):
|
def __init__(self, prefiller_instances, decoder_instances):
|
||||||
|
self.node_listener = NodeListener(self)
|
||||||
|
|
||||||
self.prefillers: List[ServerState] = [
|
self.prefillers: List[ServerState] = [
|
||||||
ServerState(h, p) for h, p in prefiller_instances
|
ServerState(h, p) for h, p in prefiller_instances
|
||||||
]
|
]
|
||||||
@@ -264,10 +314,127 @@ class ProxyState:
|
|||||||
def calculate_decode_scores(self, request_length: int) -> float:
|
def calculate_decode_scores(self, request_length: int) -> float:
|
||||||
return request_length
|
return request_length
|
||||||
|
|
||||||
|
async def add_instances(
    self, instance_type: str, instances: List[ServerState]
) -> Tuple[List[str], List[str]]:
    """Register reachable instances now and queue the rest for later.

    Each server is probed through the node listener's status check. A
    server that answers is added to the prefiller or decoder pool according
    to ``instance_type``; every other server is stored in the listener's
    ``waiting_nodes`` with a check count of 0.

    Returns:
        ``(added, waiting)``, both lists of ``str(server)`` values.
    """
    added: List[str] = []
    pending: List[str] = []
    for server in instances:
        reachable = await self.node_listener.check_instance_status(server.client)

        # Choose the pool to join; None means the server must wait.
        register = None
        if reachable:
            if instance_type == InstanceType.PREFILL:
                register = self.add_prefillers
            elif instance_type == InstanceType.DECODE:
                register = self.add_decoders

        if register is None:
            name = str(server)
            self.node_listener.waiting_nodes[name] = (instance_type, server, 0)
            pending.append(name)
            continue

        register([server])
        added.append(str(server))
    return added, pending
|
||||||
|
|
||||||
|
def add_prefillers(self, instances: List[ServerState]) -> None:
    """Append new prefiller servers and push them onto the prefiller heap.

    Servers already present in ``self.prefillers`` are skipped. Each new
    server gets a heap entry ``(0, position, server)`` where ``position``
    is its index in ``self.prefillers``. The removal path renumbers the
    heap contiguously in list order, so the index must follow the actual
    append position rather than the server's offset within ``instances``;
    otherwise a skipped duplicate would leave a gap in the numbering.
    """
    for server in instances:
        if server in self.prefillers:
            continue
        position = len(self.prefillers)
        self.prefillers.append(server)
        # prefiller_heap: [(priority_0, 0, server_0)] -> [(priority_0, 0, server_0), (0, 1, server_1)]
        heapq.heappush(self.prefiller_heap, (0, position, server))
    self.print_status(f"Add prefiller instances: {instances}.")
|
||||||
|
|
||||||
|
def add_decoders(self, instances: List[ServerState]) -> None:
    """Append new decoder servers and push them onto the decoder heap.

    Servers already present in ``self.decoders`` are skipped. Each new
    server gets a heap entry ``(0, position, server)`` where ``position``
    is its index in ``self.decoders``. The removal path renumbers the heap
    contiguously in list order, so the index must follow the actual append
    position rather than the server's offset within ``instances``;
    otherwise a skipped duplicate would leave a gap in the numbering.
    """
    for server in instances:
        if server in self.decoders:
            continue
        position = len(self.decoders)
        self.decoders.append(server)
        # decoder_heap: [(priority_0, 0, server_0)] -> [(priority_0, 0, server_0), (0, 1, server_1)]
        heapq.heappush(self.decoder_heap, (0, position, server))
    self.print_status(f"Add decoder instances: {instances}.")
|
||||||
|
|
||||||
|
def remove_prefillers(self, instances: List[ServerState]) -> None:
    """Drop the given prefillers from the server list and the heap.

    Surviving heap entries keep their priority and are renumbered
    0..n-1 in their previous index order.
    """
    doomed = set(instances)
    self.prefillers = [srv for srv in self.prefillers if srv not in doomed]

    # Walk the old entries in index order so renumbering stays stable.
    by_old_index = sorted(self.prefiller_heap, key=lambda entry: entry[1])
    survivors = [
        (priority, srv)
        for priority, _, srv in by_old_index
        if srv not in doomed
    ]

    # prefiller_heap: [(priority_0, 0, server_0), (priority_1, 1, server_1)] -> [(priority_1, 0, server_1)]
    self.prefiller_heap = [
        (priority, new_idx, srv)
        for new_idx, (priority, srv) in enumerate(survivors)
    ]
    heapq.heapify(self.prefiller_heap)
    self.print_status(f"Remove prefiller instances: {instances}.")
|
||||||
|
|
||||||
|
def remove_decoders(self, instances: List[ServerState]) -> None:
    """Drop the given decoders from the server list and the heap.

    Surviving heap entries keep their priority and are renumbered
    0..n-1 in their previous index order.
    """
    doomed = set(instances)
    self.decoders = [srv for srv in self.decoders if srv not in doomed]

    # Walk the old entries in index order so renumbering stays stable.
    by_old_index = sorted(self.decoder_heap, key=lambda entry: entry[1])
    survivors = [
        (priority, srv)
        for priority, _, srv in by_old_index
        if srv not in doomed
    ]

    # decoder_heap: [(priority_0, 0, server_0), (priority_1, 1, server_1)] -> [(priority_1, 0, server_1)]
    self.decoder_heap = [
        (priority, new_idx, srv)
        for new_idx, (priority, srv) in enumerate(survivors)
    ]
    heapq.heapify(self.decoder_heap)
    self.print_status(f"Remove decoder instances: {instances}.")
|
||||||
|
|
||||||
|
def print_status(self, msg: str) -> None:
|
||||||
|
status = {
|
||||||
|
"prefill_instances": [str(server) for server in self.prefillers],
|
||||||
|
"decode_instances": [str(server) for server in self.decoders]
|
||||||
|
}
|
||||||
|
print(f"{msg} Status: {status}")
|
||||||
|
|
||||||
|
|
||||||
proxy_state = None
|
proxy_state = None
|
||||||
|
|
||||||
|
|
||||||
|
class NodeListener:
    """Background poller that admits waiting instances once they respond.

    ``waiting_nodes`` maps a node name (``str(server)``) to a tuple of
    ``(instance_type, server, checks_done)``. A daemon thread started in
    ``__init__`` re-probes every waiting node, sleeping
    ``global_args.waiting_retry_interval`` seconds between passes, and
    gives up on a node after ``global_args.max_waiting_retries`` checks.
    """

    def __init__(self, proxy):
        self.proxy_state = proxy
        self.waiting_nodes: Dict[str, Tuple[str, Any, int]] = {}
        self.listening_thread = threading.Thread(target=self._node_listener, daemon=True)
        self.listening_thread.start()

    def _node_listener(self) -> None:
        """Thread body: poll the waiting nodes forever."""
        while True:
            self._check_waiting_nodes()
            time.sleep(global_args.waiting_retry_interval)

    def _check_waiting_nodes(self) -> None:
        """Run one probing pass over a snapshot of the waiting nodes."""
        for node, (instance_type, server, check_times) in list(self.waiting_nodes.items()):
            print(f"Checking instance {node}...")
            try:
                is_valid = asyncio.run(self.check_instance_status(server.client))
            except Exception as exc:
                # check_instance_status only handles httpx errors. Anything
                # else would end this daemon thread without a trace and
                # leave all waiting nodes stuck, so count it as a failed
                # check instead.
                # NOTE(review): asyncio.run creates a new event loop per
                # call while server.client may also be used on the server's
                # loop — confirm the client tolerates that.
                print(f"Checking instance {node} failed: {exc}")
                is_valid = False
            check_times += 1
            if is_valid:
                if instance_type == InstanceType.PREFILL:
                    self.proxy_state.add_prefillers([server])
                else:
                    self.proxy_state.add_decoders([server])
                # Default avoids KeyError if the entry vanished meanwhile.
                self.waiting_nodes.pop(node, None)
            elif check_times >= global_args.max_waiting_retries:
                # ">=" rather than "==" so a limit of 0, or one already
                # passed, still ends the wait instead of retrying forever.
                print(f"Instance {node} was not added to the proxy.")
                self.waiting_nodes.pop(node, None)
            else:
                self.waiting_nodes[node] = (instance_type, server, check_times)

    @staticmethod
    async def check_instance_status(client: httpx.AsyncClient) -> bool:
        """Return True when ``GET /models`` on ``client`` succeeds.

        A request error or an error HTTP status yields False.
        """
        endpoint = "/models"
        # NOTE(review): the header reads "Bearer None" when OPENAI_API_KEY
        # is unset — confirm the instances accept that.
        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
        }
        try:
            response = await client.get(endpoint, headers=headers)
            response.raise_for_status()
            return True
        except (httpx.RequestError, httpx.HTTPStatusError):
            return False
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--port", type=int, default=8000)
|
parser.add_argument("--port", type=int, default=8000)
|
||||||
@@ -294,6 +461,15 @@ def parse_args():
|
|||||||
type=float,
|
type=float,
|
||||||
default=0.001,
|
default=0.001,
|
||||||
help="Base delay (seconds) for exponential backoff retries")
|
help="Base delay (seconds) for exponential backoff retries")
|
||||||
|
parser.add_argument("--max-waiting-retries",
|
||||||
|
type=int,
|
||||||
|
default=3,
|
||||||
|
help="Maximum number of retries for waiting nodes to be started")
|
||||||
|
parser.add_argument(
|
||||||
|
"--waiting-retry-interval",
|
||||||
|
type=float,
|
||||||
|
default=10,
|
||||||
|
help="Check interval (seconds) for waiting nodes to be started")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
if len(args.prefiller_hosts) != len(args.prefiller_ports):
|
if len(args.prefiller_hosts) != len(args.prefiller_ports):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -637,6 +813,51 @@ async def _handle_completions(api: str, request: Request):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_adjust_instances(adjust_mode: str, request: Request):
    """Add or remove instances as described by a JSON request body.

    The body carries ``"type"`` (``"prefill"`` or ``"decode"``) and
    ``"instances"``, either one ``"host:port"`` string or a list of them.

    Args:
        adjust_mode: ``"add"`` or ``"remove"``.
        request: The incoming FastAPI request.

    Returns:
        A dict with ``message`` and the current prefill and decode
        instances, or a dict with ``error`` when the type is unsupported.

    Raises:
        Exception: Any failure is logged and re-raised unchanged.
    """
    try:
        req_data = await request.json()
        instance_type = req_data.get("type", "")

        # Reject an unsupported type before building any ServerState, so a
        # bad request never constructs server objects it will not use.
        if instance_type not in [InstanceType.PREFILL, InstanceType.DECODE]:
            return {"error": f"Instance type {instance_type} is not supported. "
                             f"Only support '{InstanceType.PREFILL}' and '{InstanceType.DECODE}'."}

        instances = req_data.get("instances", [])
        if isinstance(instances, str):
            instances = [instances]
        instances = trans_instances(instances)
        all_msg = f"{adjust_mode} {instance_type} instances: " \
                  f"{[str(server) for server in instances]}."

        if adjust_mode == "add":
            added_nodes, waiting_nodes = await proxy_state.add_instances(
                instance_type, instances
            )
            if waiting_nodes:
                all_msg = f"{adjust_mode} {instance_type} instances: {added_nodes}. " \
                          f"Instances {waiting_nodes} are waiting to be added."
        elif adjust_mode == "remove":
            if instance_type == InstanceType.PREFILL:
                proxy_state.remove_prefillers(instances)
            else:
                proxy_state.remove_decoders(instances)
        return {
            "message": all_msg,
            "current_prefill_instances": [str(prefiller) for prefiller in proxy_state.prefillers],
            "current_decode_instances": [str(decoder) for decoder in proxy_state.decoders]
        }
    except Exception as e:
        logger.error(f"Failed to {adjust_mode} instances: {e}")
        # Bare raise keeps the original traceback intact.
        raise
|
||||||
|
|
||||||
|
|
||||||
|
def trans_instances(instances: List[str]) -> List[ServerState]:
    """Convert ``"host:port"`` strings into ``ServerState`` objects.

    The string is split on its last colon, so a host that itself contains
    colons (for example a bracketed IPv6 literal such as ``[::1]:8100``)
    is accepted. The host part is passed through unchanged.

    Raises:
        ValueError: If an entry has no colon, an empty host, an empty
            port, or a non-integer port.
    """
    server_list = []
    for instance in instances:
        host, sep, port = instance.rpartition(":")
        if not (sep and host and port):
            raise ValueError(
                f"Invalid instance address '{instance}', expected 'host:port'.")
        server_list.append(ServerState(host, int(port)))
    return server_list
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/completions")
|
@app.post("/v1/completions")
|
||||||
@with_cancellation
|
@with_cancellation
|
||||||
async def handle_completions(request: Request):
|
async def handle_completions(request: Request):
|
||||||
@@ -658,6 +879,16 @@ async def healthcheck():
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/instances/add")
async def handle_add_instances(request: Request):
    """Route handler: add the instances listed in the request body."""
    outcome = await _handle_adjust_instances("add", request)
    return outcome
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/instances/remove")
async def handle_remove_instances(request: Request):
    """Route handler: remove the instances listed in the request body."""
    outcome = await _handle_adjust_instances("remove", request)
    return outcome
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
global global_args
|
global global_args
|
||||||
global_args = parse_args()
|
global_args = parse_args()
|
||||||
|
|||||||
Reference in New Issue
Block a user