[Feat] proxy delay to remove instances (#5934)
### What this PR does / why we need it?
For the proxy, we should remove instances when the proxy are not
processing requests.
But sometimes, We need to **isolate** some faulty nodes when a large
number of **requests** are coming in.
So we support to **isolate** faulty nodes by **lowering their priority**
and **deleted** them when the proxy does not process requests.
### Does this PR introduce _any_ user-facing change?
For
`examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py`,
when using `/instances/remove` API to delete the node from the proxy
server:
```txt
curl -X POST http://localhost:9000/instances/remove \
-H "Content-Type: application/json" \
-d '{
"type": "decode",
"instances": "127.0.0.1:8201"
}'
```
There are 2 situations:
* 【New】When the proxy is processing requests, isolate the nodes and
remove them when the proxy is free.
```txt
{"message": "Instances ['127.0.0.1:8201'] are isolated and waiting to be removed.", "current_prefill_instances": ['127.0.0.1:8100', '127.0.0.1:8101'], "current_decode_instances": ['127.0.0.1:8200', '127.0.0.1:8201']}
```
* When the proxy is free, remove the nodes directly.
```txt
{"message": "remove decode instances: ['127.0.0.1:8201'].", "current_prefill_instances": ['127.0.0.1:8100', '127.0.0.1:8101'], "current_decode_instances": ['127.0.0.1:8200']}
```
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
11b6af5280
Signed-off-by: yuxinshan <syx_ctyg@126.com>
This commit is contained in:
@@ -130,9 +130,15 @@ from typing import Any
|
||||
import httpx
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
try:
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
except ImportError:
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Add uvloop for faster event loop if available
|
||||
try:
|
||||
@@ -149,6 +155,9 @@ class InstanceType:
|
||||
DECODE: str = "decode"
|
||||
|
||||
|
||||
TAINT_PRIORITY = 1e15
|
||||
|
||||
|
||||
class ServerState:
|
||||
def __init__(self, host, port):
|
||||
self.host = host
|
||||
@@ -186,6 +195,9 @@ class ServerState:
|
||||
|
||||
class ProxyState:
|
||||
def __init__(self, prefiller_instances, decoder_instances):
|
||||
self.request_num = 0
|
||||
self.tainted_prefillers: list[ServerState] = []
|
||||
self.tainted_decoders: list[ServerState] = []
|
||||
self.node_listener = NodeListener(self)
|
||||
|
||||
self.prefillers: list[ServerState] = [ServerState(h, p) for h, p in prefiller_instances]
|
||||
@@ -225,6 +237,8 @@ class ProxyState:
|
||||
prefiller node.
|
||||
"""
|
||||
# No lock needed - atomic operation
|
||||
if server_idx >= len(self.prefillers):
|
||||
return
|
||||
self.prefillers[server_idx].aborted_requests.add(request_id)
|
||||
|
||||
def aquire_aborted_prefiller_requests(self, server_idx: int): # Changed to synchronous
|
||||
@@ -233,6 +247,8 @@ class ProxyState:
|
||||
This is used to release kv cache in prefiller node.
|
||||
"""
|
||||
# No lock needed - atomic operation
|
||||
if server_idx >= len(self.prefillers):
|
||||
return set()
|
||||
aborted_requests = self.prefillers[server_idx].aborted_requests.copy()
|
||||
self.prefillers[server_idx].aborted_requests.clear()
|
||||
return aborted_requests
|
||||
@@ -259,12 +275,16 @@ class ProxyState:
|
||||
|
||||
def release_prefiller(self, idx, token_count): # Changed to synchronous
|
||||
# No lock needed - atomic operation
|
||||
if idx >= len(self.prefillers):
|
||||
return
|
||||
self.prefillers[idx].active_tokens -= token_count
|
||||
# Update priority queue after releasing
|
||||
self._update_prefiller_priority(idx)
|
||||
|
||||
def release_prefiller_kv(self, idx, token_count): # Changed to synchronous
|
||||
# No lock needed - atomic operation
|
||||
if idx >= len(self.prefillers):
|
||||
return
|
||||
if self.prefillers[idx].active_kv_cache > 0:
|
||||
self.prefillers[idx].active_kv_cache -= token_count
|
||||
# Update priority queue after releasing
|
||||
@@ -287,6 +307,8 @@ class ProxyState:
|
||||
|
||||
def release_decoder(self, idx, token_count): # Changed to synchronous
|
||||
# No lock needed - atomic operation
|
||||
if idx >= len(self.decoders):
|
||||
return
|
||||
self.decoders[idx].active_tokens -= token_count
|
||||
# Update priority queue after releasing
|
||||
self._update_decoder_priority(idx)
|
||||
@@ -317,24 +339,44 @@ class ProxyState:
|
||||
return added_nodes, waiting_nodes
|
||||
|
||||
def add_prefillers(self, instances: list[ServerState]) -> None:
|
||||
num_prefillers = len(self.prefillers)
|
||||
for idx, server in enumerate(instances):
|
||||
if server not in self.prefillers:
|
||||
for server in instances:
|
||||
if server in self.tainted_prefillers:
|
||||
self.tainted_prefillers.remove(server)
|
||||
self.prefiller_heap = [
|
||||
(0, idx, server) if srv == server else (priority, idx, srv)
|
||||
for priority, idx, srv in self.prefiller_heap
|
||||
]
|
||||
heapq.heapify(self.prefiller_heap)
|
||||
elif server not in self.prefillers:
|
||||
self.prefillers.append(server)
|
||||
# prefiller_heap: [(priority_0, 0, server_0)] -> [(priority_0, 0, server_0), (0, 1, server_1)]
|
||||
heapq.heappush(self.prefiller_heap, (0, num_prefillers + idx, server))
|
||||
heapq.heappush(self.prefiller_heap, (0, len(self.prefillers) - 1, server))
|
||||
self.print_status(f"Add prefiller instances: {instances}.")
|
||||
|
||||
def add_decoders(self, instances: list[ServerState]) -> None:
|
||||
num_decoders = len(self.decoders)
|
||||
for idx, server in enumerate(instances):
|
||||
if server not in self.decoders:
|
||||
for server in instances:
|
||||
if server in self.tainted_decoders:
|
||||
self.tainted_decoders.remove(server)
|
||||
self.decoder_heap = [
|
||||
(0, idx, server) if srv == server else (priority, idx, srv)
|
||||
for priority, idx, srv in self.decoder_heap
|
||||
]
|
||||
heapq.heapify(self.decoder_heap)
|
||||
elif server not in self.decoders:
|
||||
self.decoders.append(server)
|
||||
# decoder_heap: [(priority_0, 0, server_0)] -> [(priority_0, 0, server_0), (0, 1, server_1)]
|
||||
heapq.heappush(self.decoder_heap, (0, num_decoders + idx, server))
|
||||
heapq.heappush(self.decoder_heap, (0, len(self.decoders) - 1, server))
|
||||
self.print_status(f"Add decoder instances: {instances}.")
|
||||
|
||||
def remove_prefillers(self, instances: list[ServerState]) -> None:
|
||||
def remove_prefillers(self, instances: list[ServerState]) -> bool:
|
||||
if not instances:
|
||||
return False
|
||||
|
||||
if self.request_num > 0:
|
||||
logger.warning(f"Start to taint prefill instances {instances}.")
|
||||
self._taint_prefillers(instances)
|
||||
return True
|
||||
|
||||
instances_to_remove = set(instances)
|
||||
self.prefillers = [server for server in self.prefillers if server not in instances_to_remove]
|
||||
prefiller_heap_copy = self.prefiller_heap.copy()
|
||||
@@ -350,8 +392,17 @@ class ProxyState:
|
||||
self.prefiller_heap = prefiller_heap
|
||||
heapq.heapify(self.prefiller_heap)
|
||||
self.print_status(f"Remove prefiller instances: {instances}.")
|
||||
return False
|
||||
|
||||
def remove_decoders(self, instances: list[ServerState]) -> bool:
|
||||
if not instances:
|
||||
return False
|
||||
|
||||
if self.request_num > 0:
|
||||
logger.warning(f"Start to taint decode instances {instances}.")
|
||||
self._taint_decoders(instances)
|
||||
return True
|
||||
|
||||
def remove_decoders(self, instances: list[ServerState]) -> None:
|
||||
instances_to_remove = set(instances)
|
||||
self.decoders = [server for server in self.decoders if server not in instances_to_remove]
|
||||
decoder_heap_copy = self.decoder_heap.copy()
|
||||
@@ -367,6 +418,31 @@ class ProxyState:
|
||||
self.decoder_heap = decoder_heap
|
||||
heapq.heapify(self.decoder_heap)
|
||||
self.print_status(f"Remove decoder instances: {instances}.")
|
||||
return False
|
||||
|
||||
def _taint_prefillers(self, instances: list[ServerState]) -> None:
|
||||
instances_to_taint = set(instances)
|
||||
for server in self.prefillers:
|
||||
if server in instances_to_taint and server not in self.tainted_prefillers:
|
||||
self.tainted_prefillers.append(server)
|
||||
|
||||
self.prefiller_heap = [
|
||||
(TAINT_PRIORITY, idx, srv) if srv in instances_to_taint else (priority, idx, srv)
|
||||
for priority, idx, srv in self.prefiller_heap
|
||||
]
|
||||
heapq.heapify(self.prefiller_heap)
|
||||
|
||||
def _taint_decoders(self, instances: list[ServerState]) -> None:
|
||||
instances_to_taint = set(instances)
|
||||
for server in self.decoders:
|
||||
if server in instances_to_taint and server not in self.tainted_decoders:
|
||||
self.tainted_decoders.append(server)
|
||||
|
||||
self.decoder_heap = [
|
||||
(TAINT_PRIORITY, idx, srv) if srv in instances_to_taint else (priority, idx, srv)
|
||||
for priority, idx, srv in self.decoder_heap
|
||||
]
|
||||
heapq.heapify(self.decoder_heap)
|
||||
|
||||
def print_status(self, msg: str) -> None:
|
||||
status = {
|
||||
@@ -403,6 +479,16 @@ class NodeListener:
|
||||
self.waiting_nodes.pop(node)
|
||||
else:
|
||||
self.waiting_nodes[node] = (instance_type, server, check_times)
|
||||
|
||||
if self.proxy_state.tainted_prefillers and not self.proxy_state.request_num:
|
||||
need_waiting = self.proxy_state.remove_prefillers(self.proxy_state.tainted_prefillers)
|
||||
if not need_waiting:
|
||||
self.proxy_state.tainted_prefillers.clear()
|
||||
|
||||
if self.proxy_state.tainted_decoders and not self.proxy_state.request_num:
|
||||
need_waiting = self.proxy_state.remove_decoders(self.proxy_state.tainted_decoders)
|
||||
if not need_waiting:
|
||||
self.proxy_state.tainted_decoders.clear()
|
||||
time.sleep(global_args.waiting_retry_interval)
|
||||
|
||||
@staticmethod
|
||||
@@ -623,6 +709,7 @@ class InstanceInfo:
|
||||
|
||||
async def _handle_completions(api: str, request: Request):
|
||||
try:
|
||||
proxy_state.request_num += 1
|
||||
req_data = await request.json()
|
||||
req_body = await request.body()
|
||||
request_length = len(req_body)
|
||||
@@ -736,6 +823,8 @@ async def _handle_completions(api: str, request: Request):
|
||||
print(e)
|
||||
print("".join(traceback.format_exception(*exc_info)))
|
||||
raise
|
||||
finally:
|
||||
proxy_state.request_num -= 1
|
||||
|
||||
|
||||
async def _handle_adjust_instances(adjust_mode: str, request: Request):
|
||||
@@ -763,9 +852,12 @@ async def _handle_adjust_instances(adjust_mode: str, request: Request):
|
||||
)
|
||||
elif adjust_mode == "remove":
|
||||
if instance_type == InstanceType.PREFILL:
|
||||
proxy_state.remove_prefillers(instances)
|
||||
need_waiting = proxy_state.remove_prefillers(instances)
|
||||
else:
|
||||
proxy_state.remove_decoders(instances)
|
||||
need_waiting = proxy_state.remove_decoders(instances)
|
||||
|
||||
if need_waiting:
|
||||
all_msg = f"Instances {instances} are isolated and waiting to be removed."
|
||||
return {
|
||||
"message": all_msg,
|
||||
"current_prefill_instances": [str(prefiller) for prefiller in proxy_state.prefillers],
|
||||
|
||||
Reference in New Issue
Block a user