diff --git a/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py b/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py
index 68659974..5e7f6278 100644
--- a/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py
+++ b/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py
@@ -77,6 +77,34 @@
 # This will return a JSON object with the status and the number of prefiller
 # and decoder instances.
 #
+# Step 5: Add or Remove Prefiller or Decoder Instances (Optional)
+# ---------------------------------------------------------------
+# You can add or remove prefiller or decoder instances after the proxy is started.
+# For example, add 2 prefiller instances:
+#
+#   curl -X POST http://localhost:9000/instances/add \
+#       -H "Content-Type: application/json" \
+#       -d '{
+#           "type": "prefill",
+#           "instances": ["127.0.0.1:8102", "127.0.0.1:8103"]
+#       }'
+#
+# or remove 1 decoder instance:
+#
+#   curl -X POST http://localhost:9000/instances/remove \
+#       -H "Content-Type: application/json" \
+#       -d '{
+#           "type": "decode",
+#           "instances": "127.0.0.1:8201"
+#       }'
+#
+# This will return a JSON object describing the instances that were added or
+# removed, together with the current prefiller and decoder instances.
+#
+# When adding instances that have not started yet, the proxy will wait and
+# retry until the instances have started or the maximum number of attempts
+# is exceeded.
+#
 # Notes:
 # - You can scale the number of prefiller and decoder servers as needed.
 # - The proxy will round-robin requests to balance load.
@@ -92,10 +120,12 @@
 import ipaddress
 import json
 import os
 import sys
+import threading
+import time
 import uuid
 from contextlib import asynccontextmanager
 from dataclasses import dataclass
-from typing import Any, List
+from typing import Any, Dict, List, Tuple
 
 import httpx
 from fastapi import FastAPI, Request
@@ -113,6 +143,12 @@
 except ImportError:
     pass
 
 
+class InstanceType:
+    # Plain namespace of string constants. (A @dataclass added nothing here:
+    # the class is never instantiated, only the class attributes are read.)
+    PREFILL: str = "prefill"
+    DECODE: str = "decode"
+
+
 class ServerState:
 
     def __init__(self, host, port):
@@ -136,10 +172,24 @@ class ServerState:
         self.aborted_requests = set()  # Track aborted requests
         # Removed individual server lock - will use global locks instead
 
+    def __eq__(self, other):
+        """Two servers are equal when they resolve to the same host:port.
+
+        localhost/127.0.0.1/0.0.0.0 are normalized to one alias so an instance
+        can be removed regardless of which spelling was used to add it.
+        """
+        self_host = self.host.replace("localhost", "0.0.0.0").replace("127.0.0.1", "0.0.0.0")
+        other_host = other.host.replace("localhost", "0.0.0.0").replace("127.0.0.1", "0.0.0.0")
+        return self_host == other_host and str(self.port) == str(other.port)
+
+    def __hash__(self):
+        # Must stay consistent with __eq__ (same host normalization).
+        self_host = self.host.replace("localhost", "0.0.0.0").replace("127.0.0.1", "0.0.0.0")
+        return hash((self_host, str(self.port)))
+
+    def __repr__(self):
+        return f"{self.host}:{self.port}"
+
 
 class ProxyState:
 
     def __init__(self, prefiller_instances, decoder_instances):
+        # Background listener that keeps probing instances which were not yet
+        # reachable when /instances/add was called.
+        self.node_listener = NodeListener(self)
+
         self.prefillers: List[ServerState] = [
             ServerState(h, p) for h, p in prefiller_instances
         ]
@@ -264,10 +314,127 @@ class ProxyState:
 
     def calculate_decode_scores(self, request_length: int) -> float:
         return request_length
 
+    async def add_instances(
+        self, instance_type: str, instances: List[ServerState]
+    ) -> Tuple[List[str], List[str]]:
+        """Add reachable instances immediately; queue unreachable ones.
+
+        Returns (added_nodes, waiting_nodes) as lists of "host:port" strings.
+        """
+        added_nodes, waiting_nodes = [], []
+        for server in instances:
+            is_valid = await self.node_listener.check_instance_status(server.client)
+            if is_valid and instance_type == InstanceType.PREFILL:
+                self.add_prefillers([server])
+                added_nodes.append(str(server))
+            elif is_valid and instance_type == InstanceType.DECODE:
+                self.add_decoders([server])
+                added_nodes.append(str(server))
+            else:
+                # Not reachable yet: let the NodeListener thread keep probing.
+                node = str(server)
+                self.node_listener.waiting_nodes[node] = (instance_type, server, 0)
+                waiting_nodes.append(node)
+        return added_nodes, waiting_nodes
+
+    def add_prefillers(self, instances: List[ServerState]) -> None:
+        """Append new prefiller servers and push them onto the scheduling heap."""
+        num_prefillers = len(self.prefillers)
+        for idx, server in enumerate(instances):
+            if server not in self.prefillers:
+                self.prefillers.append(server)
+                # prefiller_heap: [(priority_0, 0, server_0)] -> [(priority_0, 0, server_0), (0, 1, server_1)]
+                heapq.heappush(self.prefiller_heap, (0, num_prefillers + idx, server))
+        self.print_status(f"Add prefiller instances: {instances}.")
+
+    def add_decoders(self, instances: List[ServerState]) -> None:
+        """Append new decoder servers and push them onto the scheduling heap."""
+        num_decoders = len(self.decoders)
+        for idx, server in enumerate(instances):
+            if server not in self.decoders:
+                self.decoders.append(server)
+                # decoder_heap: [(priority_0, 0, server_0)] -> [(priority_0, 0, server_0), (0, 1, server_1)]
+                heapq.heappush(self.decoder_heap, (0, num_decoders + idx, server))
+        self.print_status(f"Add decoder instances: {instances}.")
+
+    def remove_prefillers(self, instances: List[ServerState]) -> None:
+        """Drop servers from the prefiller list and rebuild the heap with
+        compact indices so future pushes stay consistent."""
+        instances_to_remove = set(instances)
+        self.prefillers = [server for server in self.prefillers if server not in instances_to_remove]
+        prefiller_heap_copy = self.prefiller_heap.copy()
+        prefiller_heap_copy.sort(key=lambda x: x[1])  # sorted by key: prefiller_idx
+        prefiller_heap = []
+        idx = 0
+        for priority, _, server in prefiller_heap_copy:
+            if server not in instances_to_remove:
+                prefiller_heap.append((priority, idx, server))
+                idx += 1
+
+        # prefiller_heap: [(priority_0, 0, server_0), (priority_1, 1, server_1)] -> [(priority_1, 0, server_1)]
+        self.prefiller_heap = prefiller_heap
+        heapq.heapify(self.prefiller_heap)
+        self.print_status(f"Remove prefiller instances: {instances}.")
+
+    def remove_decoders(self, instances: List[ServerState]) -> None:
+        """Drop servers from the decoder list and rebuild the heap with
+        compact indices so future pushes stay consistent."""
+        instances_to_remove = set(instances)
+        self.decoders = [server for server in self.decoders if server not in instances_to_remove]
+        decoder_heap_copy = self.decoder_heap.copy()
+        decoder_heap_copy.sort(key=lambda x: x[1])  # sorted by key: decoder_idx
+        decoder_heap = []
+        idx = 0
+        for priority, _, server in decoder_heap_copy:
+            if server not in instances_to_remove:
+                decoder_heap.append((priority, idx, server))
+                idx += 1
+
+        # decoder_heap: [(priority_0, 0, server_0), (priority_1, 1, server_1)] -> [(priority_1, 0, server_1)]
+        self.decoder_heap = decoder_heap
+        heapq.heapify(self.decoder_heap)
+        self.print_status(f"Remove decoder instances: {instances}.")
+
+    def print_status(self, msg: str) -> None:
+        """Print *msg* followed by the current instance topology."""
+        status = {
+            "prefill_instances": [str(server) for server in self.prefillers],
+            "decode_instances": [str(server) for server in self.decoders]
+        }
+        print(f"{msg} Status: {status}")
 
 
 proxy_state = None
 
 
+class NodeListener:
+    """Daemon thread that re-probes instances queued by /instances/add."""
+
+    def __init__(self, proxy):
+        self.proxy_state = proxy
+        # node ("host:port") -> (instance_type, ServerState, checks performed)
+        # NOTE(review): mutated from both the event-loop thread and this
+        # listener thread; iteration below works on a snapshot via list().
+        self.waiting_nodes: Dict[str, Tuple[str, Any, int]] = {}
+        self.listening_thread = threading.Thread(target=self._node_listener, daemon=True)
+        self.listening_thread.start()
+
+    def _node_listener(self) -> None:
+        while True:
+            for node, (instance_type, server, check_times) in list(self.waiting_nodes.items()):
+                print(f"Checking instance {node}...")
+                try:
+                    is_valid = asyncio.run(self.check_instance_status(server.client))
+                except Exception:
+                    # NOTE(review): server.client is an httpx.AsyncClient which may
+                    # already be bound to the main event loop; asyncio.run() from
+                    # this thread can then raise. Treat the node as not ready
+                    # instead of letting the daemon thread die silently.
+                    is_valid = False
+                check_times += 1
+                if is_valid:
+                    if instance_type == InstanceType.PREFILL:
+                        self.proxy_state.add_prefillers([server])
+                    else:
+                        self.proxy_state.add_decoders([server])
+                    self.waiting_nodes.pop(node)
+                elif check_times >= global_args.max_waiting_retries:
+                    # ">=" (not "==") so a zero/negative limit cannot retry forever.
+                    print(f"Instance {node} was not added to the proxy.")
+                    self.waiting_nodes.pop(node)
+                else:
+                    self.waiting_nodes[node] = (instance_type, server, check_times)
+            time.sleep(global_args.waiting_retry_interval)
+
+    @staticmethod
+    async def check_instance_status(client: httpx.AsyncClient) -> bool:
+        """Return True when the instance answers GET /models successfully."""
+        endpoint = "/models"
+        headers = {
+            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
+        }
+        try:
+            response = await client.get(endpoint, headers=headers)
+            response.raise_for_status()
+            return True
+        except (httpx.RequestError, httpx.HTTPStatusError):
+            return False
+
+
 def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--port", type=int, default=8000)
@@ -294,6 +461,15 @@ def parse_args():
                         type=float,
                         default=0.001,
                         help="Base delay (seconds) for exponential backoff retries")
+    parser.add_argument("--max-waiting-retries",
+                        type=int,
+                        default=3,
+                        help="Maximum number of retries for waiting nodes to be started")
+    parser.add_argument(
+        "--waiting-retry-interval",
+        type=float,
+        default=10,
+        help="Check interval (seconds) for waiting nodes to be started")
     args = parser.parse_args()
     if len(args.prefiller_hosts) != len(args.prefiller_ports):
         raise ValueError(
@@ -637,6 +813,51 @@ async def _handle_completions(api: str, request: Request):
         raise
 
 
+async def _handle_adjust_instances(adjust_mode: str, request: Request):
+    """Shared handler for the /instances/add and /instances/remove endpoints."""
+    try:
+        req_data = await request.json()
+        instance_type = req_data.get("type", "")
+        instances = req_data.get("instances", [])
+        if isinstance(instances, str):
+            # Allow a single "host:port" string as a convenience.
+            instances = [instances]
+        instances = trans_instances(instances)
+        all_msg = f"{adjust_mode} {instance_type} instances: " \
+                  f"{[str(server) for server in instances]}."
+
+        if instance_type not in [InstanceType.PREFILL, InstanceType.DECODE]:
+            return {"error": f"Instance type {instance_type} is not supported. "
+                             f"Only support '{InstanceType.PREFILL}' and '{InstanceType.DECODE}'."}
+
+        if adjust_mode == "add":
+            added_nodes, waiting_nodes = await proxy_state.add_instances(
+                instance_type, instances
+            )
+            if waiting_nodes:
+                all_msg = f"{adjust_mode} {instance_type} instances: {added_nodes}. " \
+                          f"Instances {waiting_nodes} are waiting to be added."
+        elif adjust_mode == "remove":
+            if instance_type == InstanceType.PREFILL:
+                proxy_state.remove_prefillers(instances)
+            else:
+                proxy_state.remove_decoders(instances)
+        return {
+            "message": all_msg,
+            "current_prefill_instances": [str(prefiller) for prefiller in proxy_state.prefillers],
+            "current_decode_instances": [str(decoder) for decoder in proxy_state.decoders]
+        }
+    except Exception as e:
+        logger.error(f"Failed to {adjust_mode} instances: {e}")
+        raise  # bare raise preserves the original traceback ("raise e" rewrites it)
+
+
+def trans_instances(instances: List[str]) -> List[ServerState]:
+    """Parse "host:port" strings into ServerState objects."""
+    server_list = []
+    for instance in instances:
+        # rsplit so hosts that themselves contain ":" still split off the port.
+        h, p = instance.rsplit(":", 1)
+        server_list.append(ServerState(h, int(p)))
+    return server_list
+
+
 @app.post("/v1/completions")
 @with_cancellation
 async def handle_completions(request: Request):
@@ -658,6 +879,16 @@ async def healthcheck():
     }
 
 
+@app.post("/instances/add")
+async def handle_add_instances(request: Request):
+    return await _handle_adjust_instances("add", request)
+
+
+@app.post("/instances/remove")
+async def handle_remove_instances(request: Request):
+    return await _handle_adjust_instances("remove", request)
+
+
 if __name__ == '__main__':
     global global_args
     global_args = parse_args()