From 7d119df2a9f4b4314db37287b4536982f1c7dee5 Mon Sep 17 00:00:00 2001 From: yuxinshan <82206277+yuxinshan@users.noreply.github.com> Date: Mon, 26 Jan 2026 16:29:45 +0800 Subject: [PATCH] [Feat] proxy delay to remove instances (#5934) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What this PR does / why we need it? For the proxy, we should remove instances when the proxy are not processing requests. But sometimes, We need to **isolate** some faulty nodes when a large number of **requests** are coming in. So we support to **isolate** faulty nodes by **lowering their priority** and **deleted** them when the proxy does not process requests. ### Does this PR introduce _any_ user-facing change? For `examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py`, when using `/instances/remove` API to delete the node from the proxy server: ```txt curl -X POST http://localhost:9000/instances/remove \ -H "Content-Type: application/json" \ -d '{ "type": "decode", "instances": "127.0.0.1:8201" }' ``` There are 2 situations: * 【New】When the proxy is processing requests, isolate the nodes and remove them when the proxy is free. ```txt {"message": "Instances ['127.0.0.1:8201'] are isolated and waiting to be removed.", "current_prefill_instances": ['127.0.0.1:8100', '127.0.0.1:8101'], "current_decode_instances": ['127.0.0.1:8200', '127.0.0.1:8201']} ``` * When the proxy is free, remove the nodes directly. ```txt {"message": "remove decode instances: ['127.0.0.1:8201'].", "current_prefill_instances": ['127.0.0.1:8100', '127.0.0.1:8101'], "current_decode_instances": ['127.0.0.1:8200']} ``` ### How was this patch tested? - vLLM version: v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/11b6af5280d6d6dfb8953af16e67b25f819b3be9 Signed-off-by: yuxinshan --- .../load_balance_proxy_server_example.py | 120 ++++++++++++++++-- 1 file changed, 106 insertions(+), 14 deletions(-) diff --git a/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py b/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py index 96cb7bac..0af2c962 100644 --- a/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py +++ b/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py @@ -130,9 +130,15 @@ from typing import Any import httpx from fastapi import FastAPI, Request from fastapi.responses import StreamingResponse -from vllm.logger import init_logger -logger = init_logger(__name__) +try: + from vllm.logger import init_logger + + logger = init_logger(__name__) +except ImportError: + import logging + + logger = logging.getLogger(__name__) # Add uvloop for faster event loop if available try: @@ -149,6 +155,9 @@ class InstanceType: DECODE: str = "decode" +TAINT_PRIORITY = 1e15 + + class ServerState: def __init__(self, host, port): self.host = host @@ -186,6 +195,9 @@ class ServerState: class ProxyState: def __init__(self, prefiller_instances, decoder_instances): + self.request_num = 0 + self.tainted_prefillers: list[ServerState] = [] + self.tainted_decoders: list[ServerState] = [] self.node_listener = NodeListener(self) self.prefillers: list[ServerState] = [ServerState(h, p) for h, p in prefiller_instances] @@ -225,6 +237,8 @@ class ProxyState: prefiller node. """ # No lock needed - atomic operation + if server_idx >= len(self.prefillers): + return self.prefillers[server_idx].aborted_requests.add(request_id) def aquire_aborted_prefiller_requests(self, server_idx: int): # Changed to synchronous @@ -233,6 +247,8 @@ class ProxyState: This is used to release kv cache in prefiller node. """ # No lock needed - atomic operation + if server_idx >= len(self.prefillers): + return set() aborted_requests = self.prefillers[server_idx].aborted_requests.copy() self.prefillers[server_idx].aborted_requests.clear() return aborted_requests @@ -259,12 +275,16 @@ class ProxyState: def release_prefiller(self, idx, token_count): # Changed to synchronous # No lock needed - atomic operation + if idx >= len(self.prefillers): + return self.prefillers[idx].active_tokens -= token_count # Update priority queue after releasing self._update_prefiller_priority(idx) def release_prefiller_kv(self, idx, token_count): # Changed to synchronous # No lock needed - atomic operation + if idx >= len(self.prefillers): + return if self.prefillers[idx].active_kv_cache > 0: self.prefillers[idx].active_kv_cache -= token_count # Update priority queue after releasing @@ -287,6 +307,8 @@ class ProxyState: def release_decoder(self, idx, token_count): # Changed to synchronous # No lock needed - atomic operation + if idx >= len(self.decoders): + return self.decoders[idx].active_tokens -= token_count # Update priority queue after releasing self._update_decoder_priority(idx) @@ -317,24 +339,44 @@ class ProxyState: return added_nodes, waiting_nodes def add_prefillers(self, instances: list[ServerState]) -> None: - num_prefillers = len(self.prefillers) - for idx, server in enumerate(instances): - if server not in self.prefillers: + for server in instances: + if server in self.tainted_prefillers: + self.tainted_prefillers.remove(server) + self.prefiller_heap = [ + (0, idx, server) if srv == server else (priority, idx, srv) + for priority, idx, srv in self.prefiller_heap + ] + heapq.heapify(self.prefiller_heap) + elif server not in self.prefillers: self.prefillers.append(server) # prefiller_heap: [(priority_0, 0, server_0)] -> [(priority_0, 0, server_0), (0, 1, server_1)] - heapq.heappush(self.prefiller_heap, (0, num_prefillers + idx, server)) + heapq.heappush(self.prefiller_heap, (0, len(self.prefillers) - 1, server)) self.print_status(f"Add prefiller instances: {instances}.") def add_decoders(self, instances: list[ServerState]) -> None: - num_decoders = len(self.decoders) - for idx, server in enumerate(instances): - if server not in self.decoders: + for server in instances: + if server in self.tainted_decoders: + self.tainted_decoders.remove(server) + self.decoder_heap = [ + (0, idx, server) if srv == server else (priority, idx, srv) + for priority, idx, srv in self.decoder_heap + ] + heapq.heapify(self.decoder_heap) + elif server not in self.decoders: self.decoders.append(server) # decoder_heap: [(priority_0, 0, server_0)] -> [(priority_0, 0, server_0), (0, 1, server_1)] - heapq.heappush(self.decoder_heap, (0, num_decoders + idx, server)) + heapq.heappush(self.decoder_heap, (0, len(self.decoders) - 1, server)) self.print_status(f"Add decoder instances: {instances}.") - def remove_prefillers(self, instances: list[ServerState]) -> None: + def remove_prefillers(self, instances: list[ServerState]) -> bool: + if not instances: + return False + + if self.request_num > 0: + logger.warning(f"Start to taint prefill instances {instances}.") + self._taint_prefillers(instances) + return True + instances_to_remove = set(instances) self.prefillers = [server for server in self.prefillers if server not in instances_to_remove] prefiller_heap_copy = self.prefiller_heap.copy() @@ -350,8 +392,17 @@ class ProxyState: self.prefiller_heap = prefiller_heap heapq.heapify(self.prefiller_heap) self.print_status(f"Remove prefiller instances: {instances}.") + return False + + def remove_decoders(self, instances: list[ServerState]) -> bool: + if not instances: + return False + + if self.request_num > 0: + logger.warning(f"Start to taint decode instances {instances}.") + self._taint_decoders(instances) + return True - def remove_decoders(self, instances: list[ServerState]) -> None: instances_to_remove = set(instances) self.decoders = [server for server in self.decoders if server not in instances_to_remove] decoder_heap_copy = self.decoder_heap.copy() @@ -367,6 +418,31 @@ class ProxyState: self.decoder_heap = decoder_heap heapq.heapify(self.decoder_heap) self.print_status(f"Remove decoder instances: {instances}.") + return False + + def _taint_prefillers(self, instances: list[ServerState]) -> None: + instances_to_taint = set(instances) + for server in self.prefillers: + if server in instances_to_taint and server not in self.tainted_prefillers: + self.tainted_prefillers.append(server) + + self.prefiller_heap = [ + (TAINT_PRIORITY, idx, srv) if srv in instances_to_taint else (priority, idx, srv) + for priority, idx, srv in self.prefiller_heap + ] + heapq.heapify(self.prefiller_heap) + + def _taint_decoders(self, instances: list[ServerState]) -> None: + instances_to_taint = set(instances) + for server in self.decoders: + if server in instances_to_taint and server not in self.tainted_decoders: + self.tainted_decoders.append(server) + + self.decoder_heap = [ + (TAINT_PRIORITY, idx, srv) if srv in instances_to_taint else (priority, idx, srv) + for priority, idx, srv in self.decoder_heap + ] + heapq.heapify(self.decoder_heap) def print_status(self, msg: str) -> None: status = { @@ -403,6 +479,16 @@ class NodeListener: self.waiting_nodes.pop(node) else: self.waiting_nodes[node] = (instance_type, server, check_times) + + if self.proxy_state.tainted_prefillers and not self.proxy_state.request_num: + need_waiting = self.proxy_state.remove_prefillers(self.proxy_state.tainted_prefillers) + if not need_waiting: + self.proxy_state.tainted_prefillers.clear() + + if self.proxy_state.tainted_decoders and not self.proxy_state.request_num: + need_waiting = self.proxy_state.remove_decoders(self.proxy_state.tainted_decoders) + if not need_waiting: + self.proxy_state.tainted_decoders.clear() time.sleep(global_args.waiting_retry_interval) @staticmethod @@ -623,6 +709,7 @@ class InstanceInfo: async def _handle_completions(api: str, request: Request): try: + proxy_state.request_num += 1 req_data = await request.json() req_body = await request.body() request_length = len(req_body) @@ -736,6 +823,8 @@ async def _handle_completions(api: str, request: Request): print(e) print("".join(traceback.format_exception(*exc_info))) raise + finally: + proxy_state.request_num -= 1 async def _handle_adjust_instances(adjust_mode: str, request: Request): @@ -763,9 +852,12 @@ async def _handle_adjust_instances(adjust_mode: str, request: Request): ) elif adjust_mode == "remove": if instance_type == InstanceType.PREFILL: - proxy_state.remove_prefillers(instances) + need_waiting = proxy_state.remove_prefillers(instances) else: - proxy_state.remove_decoders(instances) + need_waiting = proxy_state.remove_decoders(instances) + + if need_waiting: + all_msg = f"Instances {instances} are isolated and waiting to be removed." return { "message": all_msg, "current_prefill_instances": [str(prefiller) for prefiller in proxy_state.prefillers],