[Feat] proxy delay to remove instances (#5934)
### What this PR does / why we need it?
For the proxy, we should remove instances only when the proxy is not
processing requests.
But sometimes we need to **isolate** some faulty nodes while a large
number of **requests** are coming in.
So we now support **isolating** faulty nodes by **lowering their priority**
and **deleting** them once the proxy is no longer processing requests.
### Does this PR introduce _any_ user-facing change?
For
`examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py`,
when using `/instances/remove` API to delete the node from the proxy
server:
```txt
curl -X POST http://localhost:9000/instances/remove \
-H "Content-Type: application/json" \
-d '{
"type": "decode",
"instances": "127.0.0.1:8201"
}'
```
There are 2 situations:
* **[New]** When the proxy is processing requests, the nodes are isolated and
removed once the proxy becomes free.
```txt
{"message": "Instances ['127.0.0.1:8201'] are isolated and waiting to be removed.", "current_prefill_instances": ['127.0.0.1:8100', '127.0.0.1:8101'], "current_decode_instances": ['127.0.0.1:8200', '127.0.0.1:8201']}
```
* When the proxy is free, remove the nodes directly.
```txt
{"message": "remove decode instances: ['127.0.0.1:8201'].", "current_prefill_instances": ['127.0.0.1:8100', '127.0.0.1:8101'], "current_decode_instances": ['127.0.0.1:8200']}
```
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
11b6af5280
Signed-off-by: yuxinshan <syx_ctyg@126.com>
This commit is contained in:
@@ -130,9 +130,15 @@ from typing import Any
|
|||||||
import httpx
|
import httpx
|
||||||
from fastapi import FastAPI, Request
|
from fastapi import FastAPI, Request
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
|
try:
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
except ImportError:
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Add uvloop for faster event loop if available
|
# Add uvloop for faster event loop if available
|
||||||
try:
|
try:
|
||||||
@@ -149,6 +155,9 @@ class InstanceType:
|
|||||||
DECODE: str = "decode"
|
DECODE: str = "decode"
|
||||||
|
|
||||||
|
|
||||||
|
TAINT_PRIORITY = 1e15
|
||||||
|
|
||||||
|
|
||||||
class ServerState:
|
class ServerState:
|
||||||
def __init__(self, host, port):
|
def __init__(self, host, port):
|
||||||
self.host = host
|
self.host = host
|
||||||
@@ -186,6 +195,9 @@ class ServerState:
|
|||||||
|
|
||||||
class ProxyState:
|
class ProxyState:
|
||||||
def __init__(self, prefiller_instances, decoder_instances):
|
def __init__(self, prefiller_instances, decoder_instances):
|
||||||
|
self.request_num = 0
|
||||||
|
self.tainted_prefillers: list[ServerState] = []
|
||||||
|
self.tainted_decoders: list[ServerState] = []
|
||||||
self.node_listener = NodeListener(self)
|
self.node_listener = NodeListener(self)
|
||||||
|
|
||||||
self.prefillers: list[ServerState] = [ServerState(h, p) for h, p in prefiller_instances]
|
self.prefillers: list[ServerState] = [ServerState(h, p) for h, p in prefiller_instances]
|
||||||
@@ -225,6 +237,8 @@ class ProxyState:
|
|||||||
prefiller node.
|
prefiller node.
|
||||||
"""
|
"""
|
||||||
# No lock needed - atomic operation
|
# No lock needed - atomic operation
|
||||||
|
if server_idx >= len(self.prefillers):
|
||||||
|
return
|
||||||
self.prefillers[server_idx].aborted_requests.add(request_id)
|
self.prefillers[server_idx].aborted_requests.add(request_id)
|
||||||
|
|
||||||
def aquire_aborted_prefiller_requests(self, server_idx: int): # Changed to synchronous
|
def aquire_aborted_prefiller_requests(self, server_idx: int): # Changed to synchronous
|
||||||
@@ -233,6 +247,8 @@ class ProxyState:
|
|||||||
This is used to release kv cache in prefiller node.
|
This is used to release kv cache in prefiller node.
|
||||||
"""
|
"""
|
||||||
# No lock needed - atomic operation
|
# No lock needed - atomic operation
|
||||||
|
if server_idx >= len(self.prefillers):
|
||||||
|
return set()
|
||||||
aborted_requests = self.prefillers[server_idx].aborted_requests.copy()
|
aborted_requests = self.prefillers[server_idx].aborted_requests.copy()
|
||||||
self.prefillers[server_idx].aborted_requests.clear()
|
self.prefillers[server_idx].aborted_requests.clear()
|
||||||
return aborted_requests
|
return aborted_requests
|
||||||
@@ -259,12 +275,16 @@ class ProxyState:
|
|||||||
|
|
||||||
def release_prefiller(self, idx, token_count): # Changed to synchronous
|
def release_prefiller(self, idx, token_count): # Changed to synchronous
|
||||||
# No lock needed - atomic operation
|
# No lock needed - atomic operation
|
||||||
|
if idx >= len(self.prefillers):
|
||||||
|
return
|
||||||
self.prefillers[idx].active_tokens -= token_count
|
self.prefillers[idx].active_tokens -= token_count
|
||||||
# Update priority queue after releasing
|
# Update priority queue after releasing
|
||||||
self._update_prefiller_priority(idx)
|
self._update_prefiller_priority(idx)
|
||||||
|
|
||||||
def release_prefiller_kv(self, idx, token_count): # Changed to synchronous
|
def release_prefiller_kv(self, idx, token_count): # Changed to synchronous
|
||||||
# No lock needed - atomic operation
|
# No lock needed - atomic operation
|
||||||
|
if idx >= len(self.prefillers):
|
||||||
|
return
|
||||||
if self.prefillers[idx].active_kv_cache > 0:
|
if self.prefillers[idx].active_kv_cache > 0:
|
||||||
self.prefillers[idx].active_kv_cache -= token_count
|
self.prefillers[idx].active_kv_cache -= token_count
|
||||||
# Update priority queue after releasing
|
# Update priority queue after releasing
|
||||||
@@ -287,6 +307,8 @@ class ProxyState:
|
|||||||
|
|
||||||
def release_decoder(self, idx, token_count): # Changed to synchronous
|
def release_decoder(self, idx, token_count): # Changed to synchronous
|
||||||
# No lock needed - atomic operation
|
# No lock needed - atomic operation
|
||||||
|
if idx >= len(self.decoders):
|
||||||
|
return
|
||||||
self.decoders[idx].active_tokens -= token_count
|
self.decoders[idx].active_tokens -= token_count
|
||||||
# Update priority queue after releasing
|
# Update priority queue after releasing
|
||||||
self._update_decoder_priority(idx)
|
self._update_decoder_priority(idx)
|
||||||
@@ -317,24 +339,44 @@ class ProxyState:
|
|||||||
return added_nodes, waiting_nodes
|
return added_nodes, waiting_nodes
|
||||||
|
|
||||||
def add_prefillers(self, instances: list[ServerState]) -> None:
|
def add_prefillers(self, instances: list[ServerState]) -> None:
|
||||||
num_prefillers = len(self.prefillers)
|
for server in instances:
|
||||||
for idx, server in enumerate(instances):
|
if server in self.tainted_prefillers:
|
||||||
if server not in self.prefillers:
|
self.tainted_prefillers.remove(server)
|
||||||
|
self.prefiller_heap = [
|
||||||
|
(0, idx, server) if srv == server else (priority, idx, srv)
|
||||||
|
for priority, idx, srv in self.prefiller_heap
|
||||||
|
]
|
||||||
|
heapq.heapify(self.prefiller_heap)
|
||||||
|
elif server not in self.prefillers:
|
||||||
self.prefillers.append(server)
|
self.prefillers.append(server)
|
||||||
# prefiller_heap: [(priority_0, 0, server_0)] -> [(priority_0, 0, server_0), (0, 1, server_1)]
|
# prefiller_heap: [(priority_0, 0, server_0)] -> [(priority_0, 0, server_0), (0, 1, server_1)]
|
||||||
heapq.heappush(self.prefiller_heap, (0, num_prefillers + idx, server))
|
heapq.heappush(self.prefiller_heap, (0, len(self.prefillers) - 1, server))
|
||||||
self.print_status(f"Add prefiller instances: {instances}.")
|
self.print_status(f"Add prefiller instances: {instances}.")
|
||||||
|
|
||||||
def add_decoders(self, instances: list[ServerState]) -> None:
|
def add_decoders(self, instances: list[ServerState]) -> None:
|
||||||
num_decoders = len(self.decoders)
|
for server in instances:
|
||||||
for idx, server in enumerate(instances):
|
if server in self.tainted_decoders:
|
||||||
if server not in self.decoders:
|
self.tainted_decoders.remove(server)
|
||||||
|
self.decoder_heap = [
|
||||||
|
(0, idx, server) if srv == server else (priority, idx, srv)
|
||||||
|
for priority, idx, srv in self.decoder_heap
|
||||||
|
]
|
||||||
|
heapq.heapify(self.decoder_heap)
|
||||||
|
elif server not in self.decoders:
|
||||||
self.decoders.append(server)
|
self.decoders.append(server)
|
||||||
# decoder_heap: [(priority_0, 0, server_0)] -> [(priority_0, 0, server_0), (0, 1, server_1)]
|
# decoder_heap: [(priority_0, 0, server_0)] -> [(priority_0, 0, server_0), (0, 1, server_1)]
|
||||||
heapq.heappush(self.decoder_heap, (0, num_decoders + idx, server))
|
heapq.heappush(self.decoder_heap, (0, len(self.decoders) - 1, server))
|
||||||
self.print_status(f"Add decoder instances: {instances}.")
|
self.print_status(f"Add decoder instances: {instances}.")
|
||||||
|
|
||||||
def remove_prefillers(self, instances: list[ServerState]) -> None:
|
def remove_prefillers(self, instances: list[ServerState]) -> bool:
|
||||||
|
if not instances:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if self.request_num > 0:
|
||||||
|
logger.warning(f"Start to taint prefill instances {instances}.")
|
||||||
|
self._taint_prefillers(instances)
|
||||||
|
return True
|
||||||
|
|
||||||
instances_to_remove = set(instances)
|
instances_to_remove = set(instances)
|
||||||
self.prefillers = [server for server in self.prefillers if server not in instances_to_remove]
|
self.prefillers = [server for server in self.prefillers if server not in instances_to_remove]
|
||||||
prefiller_heap_copy = self.prefiller_heap.copy()
|
prefiller_heap_copy = self.prefiller_heap.copy()
|
||||||
@@ -350,8 +392,17 @@ class ProxyState:
|
|||||||
self.prefiller_heap = prefiller_heap
|
self.prefiller_heap = prefiller_heap
|
||||||
heapq.heapify(self.prefiller_heap)
|
heapq.heapify(self.prefiller_heap)
|
||||||
self.print_status(f"Remove prefiller instances: {instances}.")
|
self.print_status(f"Remove prefiller instances: {instances}.")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def remove_decoders(self, instances: list[ServerState]) -> bool:
|
||||||
|
if not instances:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if self.request_num > 0:
|
||||||
|
logger.warning(f"Start to taint decode instances {instances}.")
|
||||||
|
self._taint_decoders(instances)
|
||||||
|
return True
|
||||||
|
|
||||||
def remove_decoders(self, instances: list[ServerState]) -> None:
|
|
||||||
instances_to_remove = set(instances)
|
instances_to_remove = set(instances)
|
||||||
self.decoders = [server for server in self.decoders if server not in instances_to_remove]
|
self.decoders = [server for server in self.decoders if server not in instances_to_remove]
|
||||||
decoder_heap_copy = self.decoder_heap.copy()
|
decoder_heap_copy = self.decoder_heap.copy()
|
||||||
@@ -367,6 +418,31 @@ class ProxyState:
|
|||||||
self.decoder_heap = decoder_heap
|
self.decoder_heap = decoder_heap
|
||||||
heapq.heapify(self.decoder_heap)
|
heapq.heapify(self.decoder_heap)
|
||||||
self.print_status(f"Remove decoder instances: {instances}.")
|
self.print_status(f"Remove decoder instances: {instances}.")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _taint_prefillers(self, instances: list[ServerState]) -> None:
|
||||||
|
instances_to_taint = set(instances)
|
||||||
|
for server in self.prefillers:
|
||||||
|
if server in instances_to_taint and server not in self.tainted_prefillers:
|
||||||
|
self.tainted_prefillers.append(server)
|
||||||
|
|
||||||
|
self.prefiller_heap = [
|
||||||
|
(TAINT_PRIORITY, idx, srv) if srv in instances_to_taint else (priority, idx, srv)
|
||||||
|
for priority, idx, srv in self.prefiller_heap
|
||||||
|
]
|
||||||
|
heapq.heapify(self.prefiller_heap)
|
||||||
|
|
||||||
|
def _taint_decoders(self, instances: list[ServerState]) -> None:
|
||||||
|
instances_to_taint = set(instances)
|
||||||
|
for server in self.decoders:
|
||||||
|
if server in instances_to_taint and server not in self.tainted_decoders:
|
||||||
|
self.tainted_decoders.append(server)
|
||||||
|
|
||||||
|
self.decoder_heap = [
|
||||||
|
(TAINT_PRIORITY, idx, srv) if srv in instances_to_taint else (priority, idx, srv)
|
||||||
|
for priority, idx, srv in self.decoder_heap
|
||||||
|
]
|
||||||
|
heapq.heapify(self.decoder_heap)
|
||||||
|
|
||||||
def print_status(self, msg: str) -> None:
|
def print_status(self, msg: str) -> None:
|
||||||
status = {
|
status = {
|
||||||
@@ -403,6 +479,16 @@ class NodeListener:
|
|||||||
self.waiting_nodes.pop(node)
|
self.waiting_nodes.pop(node)
|
||||||
else:
|
else:
|
||||||
self.waiting_nodes[node] = (instance_type, server, check_times)
|
self.waiting_nodes[node] = (instance_type, server, check_times)
|
||||||
|
|
||||||
|
if self.proxy_state.tainted_prefillers and not self.proxy_state.request_num:
|
||||||
|
need_waiting = self.proxy_state.remove_prefillers(self.proxy_state.tainted_prefillers)
|
||||||
|
if not need_waiting:
|
||||||
|
self.proxy_state.tainted_prefillers.clear()
|
||||||
|
|
||||||
|
if self.proxy_state.tainted_decoders and not self.proxy_state.request_num:
|
||||||
|
need_waiting = self.proxy_state.remove_decoders(self.proxy_state.tainted_decoders)
|
||||||
|
if not need_waiting:
|
||||||
|
self.proxy_state.tainted_decoders.clear()
|
||||||
time.sleep(global_args.waiting_retry_interval)
|
time.sleep(global_args.waiting_retry_interval)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -623,6 +709,7 @@ class InstanceInfo:
|
|||||||
|
|
||||||
async def _handle_completions(api: str, request: Request):
|
async def _handle_completions(api: str, request: Request):
|
||||||
try:
|
try:
|
||||||
|
proxy_state.request_num += 1
|
||||||
req_data = await request.json()
|
req_data = await request.json()
|
||||||
req_body = await request.body()
|
req_body = await request.body()
|
||||||
request_length = len(req_body)
|
request_length = len(req_body)
|
||||||
@@ -736,6 +823,8 @@ async def _handle_completions(api: str, request: Request):
|
|||||||
print(e)
|
print(e)
|
||||||
print("".join(traceback.format_exception(*exc_info)))
|
print("".join(traceback.format_exception(*exc_info)))
|
||||||
raise
|
raise
|
||||||
|
finally:
|
||||||
|
proxy_state.request_num -= 1
|
||||||
|
|
||||||
|
|
||||||
async def _handle_adjust_instances(adjust_mode: str, request: Request):
|
async def _handle_adjust_instances(adjust_mode: str, request: Request):
|
||||||
@@ -763,9 +852,12 @@ async def _handle_adjust_instances(adjust_mode: str, request: Request):
|
|||||||
)
|
)
|
||||||
elif adjust_mode == "remove":
|
elif adjust_mode == "remove":
|
||||||
if instance_type == InstanceType.PREFILL:
|
if instance_type == InstanceType.PREFILL:
|
||||||
proxy_state.remove_prefillers(instances)
|
need_waiting = proxy_state.remove_prefillers(instances)
|
||||||
else:
|
else:
|
||||||
proxy_state.remove_decoders(instances)
|
need_waiting = proxy_state.remove_decoders(instances)
|
||||||
|
|
||||||
|
if need_waiting:
|
||||||
|
all_msg = f"Instances {instances} are isolated and waiting to be removed."
|
||||||
return {
|
return {
|
||||||
"message": all_msg,
|
"message": all_msg,
|
||||||
"current_prefill_instances": [str(prefiller) for prefiller in proxy_state.prefillers],
|
"current_prefill_instances": [str(prefiller) for prefiller in proxy_state.prefillers],
|
||||||
|
|||||||
Reference in New Issue
Block a user