[Feat] proxy delay to remove instances (#5934)

### What this PR does / why we need it?
For the proxy, we should remove instances when the proxy are not
processing requests.
But sometimes, We need to **isolate** some faulty nodes when a large
number of **requests** are coming in.
So we support to **isolate** faulty nodes by **lowering their priority**
and **deleted** them when the proxy does not process requests.

### Does this PR introduce _any_ user-facing change?
For
`examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py`,
when using `/instances/remove` API to delete the node from the proxy
server:
```txt
curl -X POST http://localhost:9000/instances/remove \
  -H "Content-Type: application/json" \
  -d '{
        "type": "decode",
        "instances": "127.0.0.1:8201"
      }'
```
There are 2 situations:
* 【New】When the proxy is processing requests, isolate the nodes and
remove them when the proxy is free.
```txt
{"message": "Instances ['127.0.0.1:8201'] are isolated and waiting to be removed.", "current_prefill_instances": ['127.0.0.1:8100', '127.0.0.1:8101'], "current_decode_instances": ['127.0.0.1:8200', '127.0.0.1:8201']}
```
* When the proxy is free, remove the nodes directly.
```txt
{"message": "remove decode instances: ['127.0.0.1:8201'].", "current_prefill_instances": ['127.0.0.1:8100', '127.0.0.1:8101'], "current_decode_instances": ['127.0.0.1:8200']}
```
### How was this patch tested?


- vLLM version: v0.13.0
- vLLM main:
11b6af5280

Signed-off-by: yuxinshan <syx_ctyg@126.com>
This commit is contained in:
yuxinshan
2026-01-26 16:29:45 +08:00
committed by GitHub
parent de095c5fed
commit 7d119df2a9

View File

@@ -130,9 +130,15 @@ from typing import Any
import httpx
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from vllm.logger import init_logger
logger = init_logger(__name__)
try:
from vllm.logger import init_logger
logger = init_logger(__name__)
except ImportError:
import logging
logger = logging.getLogger(__name__)
# Add uvloop for faster event loop if available
try:
@@ -149,6 +155,9 @@ class InstanceType:
DECODE: str = "decode"
TAINT_PRIORITY = 1e15
class ServerState:
def __init__(self, host, port):
self.host = host
@@ -186,6 +195,9 @@ class ServerState:
class ProxyState:
def __init__(self, prefiller_instances, decoder_instances):
self.request_num = 0
self.tainted_prefillers: list[ServerState] = []
self.tainted_decoders: list[ServerState] = []
self.node_listener = NodeListener(self)
self.prefillers: list[ServerState] = [ServerState(h, p) for h, p in prefiller_instances]
@@ -225,6 +237,8 @@ class ProxyState:
prefiller node.
"""
# No lock needed - atomic operation
if server_idx >= len(self.prefillers):
return
self.prefillers[server_idx].aborted_requests.add(request_id)
def aquire_aborted_prefiller_requests(self, server_idx: int): # Changed to synchronous
@@ -233,6 +247,8 @@ class ProxyState:
This is used to release kv cache in prefiller node.
"""
# No lock needed - atomic operation
if server_idx >= len(self.prefillers):
return set()
aborted_requests = self.prefillers[server_idx].aborted_requests.copy()
self.prefillers[server_idx].aborted_requests.clear()
return aborted_requests
@@ -259,12 +275,16 @@ class ProxyState:
def release_prefiller(self, idx, token_count): # Changed to synchronous
# No lock needed - atomic operation
if idx >= len(self.prefillers):
return
self.prefillers[idx].active_tokens -= token_count
# Update priority queue after releasing
self._update_prefiller_priority(idx)
def release_prefiller_kv(self, idx, token_count): # Changed to synchronous
# No lock needed - atomic operation
if idx >= len(self.prefillers):
return
if self.prefillers[idx].active_kv_cache > 0:
self.prefillers[idx].active_kv_cache -= token_count
# Update priority queue after releasing
@@ -287,6 +307,8 @@ class ProxyState:
def release_decoder(self, idx, token_count): # Changed to synchronous
# No lock needed - atomic operation
if idx >= len(self.decoders):
return
self.decoders[idx].active_tokens -= token_count
# Update priority queue after releasing
self._update_decoder_priority(idx)
@@ -317,24 +339,44 @@ class ProxyState:
return added_nodes, waiting_nodes
def add_prefillers(self, instances: list[ServerState]) -> None:
num_prefillers = len(self.prefillers)
for idx, server in enumerate(instances):
if server not in self.prefillers:
for server in instances:
if server in self.tainted_prefillers:
self.tainted_prefillers.remove(server)
self.prefiller_heap = [
(0, idx, server) if srv == server else (priority, idx, srv)
for priority, idx, srv in self.prefiller_heap
]
heapq.heapify(self.prefiller_heap)
elif server not in self.prefillers:
self.prefillers.append(server)
# prefiller_heap: [(priority_0, 0, server_0)] -> [(priority_0, 0, server_0), (0, 1, server_1)]
heapq.heappush(self.prefiller_heap, (0, num_prefillers + idx, server))
heapq.heappush(self.prefiller_heap, (0, len(self.prefillers) - 1, server))
self.print_status(f"Add prefiller instances: {instances}.")
def add_decoders(self, instances: list[ServerState]) -> None:
num_decoders = len(self.decoders)
for idx, server in enumerate(instances):
if server not in self.decoders:
for server in instances:
if server in self.tainted_decoders:
self.tainted_decoders.remove(server)
self.decoder_heap = [
(0, idx, server) if srv == server else (priority, idx, srv)
for priority, idx, srv in self.decoder_heap
]
heapq.heapify(self.decoder_heap)
elif server not in self.decoders:
self.decoders.append(server)
# decoder_heap: [(priority_0, 0, server_0)] -> [(priority_0, 0, server_0), (0, 1, server_1)]
heapq.heappush(self.decoder_heap, (0, num_decoders + idx, server))
heapq.heappush(self.decoder_heap, (0, len(self.decoders) - 1, server))
self.print_status(f"Add decoder instances: {instances}.")
def remove_prefillers(self, instances: list[ServerState]) -> None:
def remove_prefillers(self, instances: list[ServerState]) -> bool:
if not instances:
return False
if self.request_num > 0:
logger.warning(f"Start to taint prefill instances {instances}.")
self._taint_prefillers(instances)
return True
instances_to_remove = set(instances)
self.prefillers = [server for server in self.prefillers if server not in instances_to_remove]
prefiller_heap_copy = self.prefiller_heap.copy()
@@ -350,8 +392,17 @@ class ProxyState:
self.prefiller_heap = prefiller_heap
heapq.heapify(self.prefiller_heap)
self.print_status(f"Remove prefiller instances: {instances}.")
return False
def remove_decoders(self, instances: list[ServerState]) -> bool:
if not instances:
return False
if self.request_num > 0:
logger.warning(f"Start to taint decode instances {instances}.")
self._taint_decoders(instances)
return True
def remove_decoders(self, instances: list[ServerState]) -> None:
instances_to_remove = set(instances)
self.decoders = [server for server in self.decoders if server not in instances_to_remove]
decoder_heap_copy = self.decoder_heap.copy()
@@ -367,6 +418,31 @@ class ProxyState:
self.decoder_heap = decoder_heap
heapq.heapify(self.decoder_heap)
self.print_status(f"Remove decoder instances: {instances}.")
return False
def _taint_prefillers(self, instances: list[ServerState]) -> None:
instances_to_taint = set(instances)
for server in self.prefillers:
if server in instances_to_taint and server not in self.tainted_prefillers:
self.tainted_prefillers.append(server)
self.prefiller_heap = [
(TAINT_PRIORITY, idx, srv) if srv in instances_to_taint else (priority, idx, srv)
for priority, idx, srv in self.prefiller_heap
]
heapq.heapify(self.prefiller_heap)
def _taint_decoders(self, instances: list[ServerState]) -> None:
instances_to_taint = set(instances)
for server in self.decoders:
if server in instances_to_taint and server not in self.tainted_decoders:
self.tainted_decoders.append(server)
self.decoder_heap = [
(TAINT_PRIORITY, idx, srv) if srv in instances_to_taint else (priority, idx, srv)
for priority, idx, srv in self.decoder_heap
]
heapq.heapify(self.decoder_heap)
def print_status(self, msg: str) -> None:
status = {
@@ -403,6 +479,16 @@ class NodeListener:
self.waiting_nodes.pop(node)
else:
self.waiting_nodes[node] = (instance_type, server, check_times)
if self.proxy_state.tainted_prefillers and not self.proxy_state.request_num:
need_waiting = self.proxy_state.remove_prefillers(self.proxy_state.tainted_prefillers)
if not need_waiting:
self.proxy_state.tainted_prefillers.clear()
if self.proxy_state.tainted_decoders and not self.proxy_state.request_num:
need_waiting = self.proxy_state.remove_decoders(self.proxy_state.tainted_decoders)
if not need_waiting:
self.proxy_state.tainted_decoders.clear()
time.sleep(global_args.waiting_retry_interval)
@staticmethod
@@ -623,6 +709,7 @@ class InstanceInfo:
async def _handle_completions(api: str, request: Request):
try:
proxy_state.request_num += 1
req_data = await request.json()
req_body = await request.body()
request_length = len(req_body)
@@ -736,6 +823,8 @@ async def _handle_completions(api: str, request: Request):
print(e)
print("".join(traceback.format_exception(*exc_info)))
raise
finally:
proxy_state.request_num -= 1
async def _handle_adjust_instances(adjust_mode: str, request: Request):
@@ -763,9 +852,12 @@ async def _handle_adjust_instances(adjust_mode: str, request: Request):
)
elif adjust_mode == "remove":
if instance_type == InstanceType.PREFILL:
proxy_state.remove_prefillers(instances)
need_waiting = proxy_state.remove_prefillers(instances)
else:
proxy_state.remove_decoders(instances)
need_waiting = proxy_state.remove_decoders(instances)
if need_waiting:
all_msg = f"Instances {instances} are isolated and waiting to be removed."
return {
"message": all_msg,
"current_prefill_instances": [str(prefiller) for prefiller in proxy_state.prefillers],