# Adapted from https://github.com/vllm-project/vllm/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py
# SPDX-License-Identifier: Apache-2.0
#
# Tutorial: Using the Load Balance Proxy Server Example
#
# This proxy server is designed to distribute requests between multiple
# "prefiller" and "decoder" backend servers for large language model inference.
# It is useful for scaling out inference workloads and balancing load across
# multiple backend instances.
#
# Features:
# - Load balances requests to multiple prefiller and decoder servers.
# - Supports OpenAI-compatible /v1/completions and /v1/chat/completions endpoints.
# - Streams responses from backend servers to clients.
#
# Prerequisites:
# - Python 3.8+
# - Install dependencies:
#     pip install "fastapi<0.124.0" httpx uvicorn vllm
#
# Step 1: Start Your Backend Servers
# ----------------------------------
# You need to have at least one prefiller and one decoder backend running.
# These can be mock servers or actual vLLM servers.
#
# For example, start a vLLM server for each role (remaining flags elided):
#
# vllm serve --host 0.0.0.0 --port 8100 ... # Prefiller 1
# vllm serve --host 0.0.0.0 --port 8101 ... # Prefiller 2
# vllm serve --host 0.0.0.0 --port 8200 ... # Decoder 1
# vllm serve --host 0.0.0.0 --port 8201 ... # Decoder 2
#
# Step 2: Start the Proxy Server
# ------------------------------
# Run the proxy server, specifying the host/port for each prefiller and decoder:
#
# python load_balance_proxy_server_example.py \
# --host 0.0.0.0 --port 9000 \
# --prefiller-hosts 127.0.0.1 127.0.0.1 \
# --prefiller-ports 8100 8101 \
# --decoder-hosts 127.0.0.1 127.0.0.1 \
# --decoder-ports 8200 8201
#
# This will start the proxy on port 9000, load balancing between two prefiller
# and two decoder servers.
#
# Step 3: Send a Request to the Proxy
# -----------------------------------
# You can now send OpenAI-compatible requests to the proxy. For example:
#
# curl -X POST http://localhost:9000/v1/completions \
# -H "Content-Type: application/json" \
# -d '{
# "model": "your-model",
# "prompt": "The quick brown fox jumps over the lazy dog",
# "max_tokens": 16
# }'
#
# Or for chat completions:
#
# curl -X POST http://localhost:9000/v1/chat/completions \
# -H "Content-Type: application/json" \
# -d '{
# "model": "your-model",
# "messages": [{"role": "user", "content": "Hello!"}],
# "max_tokens": 16
# }'
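#
# Equivalently, any OpenAI-compatible client can point at the proxy. A
# minimal sketch using the official `openai` Python package (the model name
# is a placeholder):
#
#   from openai import OpenAI
#
#   client = OpenAI(base_url="http://localhost:9000/v1", api_key="EMPTY")
#   resp = client.chat.completions.create(
#       model="your-model",
#       messages=[{"role": "user", "content": "Hello!"}],
#       max_tokens=16,
#   )
#   print(resp.choices[0].message.content)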
#
# Step 4: Health Check
# --------------------
# To check if the proxy is running and see how many backend instances are
# connected, use:
#
# curl http://localhost:9000/healthcheck
#
# This will return a JSON object with the status and the number of prefiller
# and decoder instances.
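#
# With the two-prefiller/two-decoder topology above, this returns:
#
#   {"status": "ok", "prefill_instances": 2, "decode_instances": 2}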
#
# Notes:
# - You can scale the number of prefiller and decoder servers as needed.
# - The proxy routes each request to the least-loaded backend, tracked via
#   priority heaps (see ProxyState below).
# - For production, ensure your backend servers are robust and secure.
#
# For more details, see the code and comments in this file.
import argparse
import asyncio
import functools
import heapq
import ipaddress
import json
import os
import sys
import uuid
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, List

import httpx
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from vllm.logger import init_logger

logger = init_logger(__name__)

# Use uvloop for a faster event loop if available
try:
    import uvloop

    asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
except ImportError:
    pass

class ServerState:

    def __init__(self, host, port):
        self.host = host
        self.port = port
        self.url = f'http://{host}:{port}/v1'
        try:
            ip = ipaddress.ip_address(self.host)
            if isinstance(ip, ipaddress.IPv6Address):
                # IPv6 literals must be bracketed inside URLs
                self.url = f'http://[{host}]:{port}/v1'
        except ValueError:
            # Not an IP literal (e.g. a hostname); keep the plain form
            pass
        self.client = httpx.AsyncClient(timeout=None,
                                        base_url=self.url,
                                        limits=httpx.Limits(
                                            max_connections=100000,
                                            max_keepalive_connections=100000))
        self.active_tokens = 0
        self.active_kv_cache = 0  # Only for prefiller
        self.active_requests = 0  # Number of active requests
        self.aborted_requests = set()  # Track aborted requests
        # Removed individual server lock - will use global locks instead

class ProxyState:

    def __init__(self, prefiller_instances, decoder_instances):
        self.prefillers: List[ServerState] = [
            ServerState(h, p) for h, p in prefiller_instances
        ]
        self.decoders: List[ServerState] = [
            ServerState(h, p) for h, p in decoder_instances
        ]
        self.req_to_prefiller = {}
        self.req_id_lock = asyncio.Lock()
        # Removed selection locks - no longer needed for synchronous methods
        # Initialize priority queues for efficient server selection.
        # Each entry is (priority_score, server_index, server_reference);
        # a lower priority score means higher priority (less loaded).
        self.prefiller_heap = [(0, i, server)
                               for i, server in enumerate(self.prefillers)]
        self.decoder_heap = [(0, i, server)
                             for i, server in enumerate(self.decoders)]
        heapq.heapify(self.prefiller_heap)
        heapq.heapify(self.decoder_heap)

    def _update_prefiller_priority(self, server_idx: int):
        """Update the priority of a prefiller server in the heap."""
        server = self.prefillers[server_idx]
        # Priority based on active_tokens and active_kv_cache
        priority = server.active_tokens + server.active_kv_cache * 0.3
        # Drop the stale entry, restore the heap invariant, then re-add
        self.prefiller_heap = [(p, i, s) for p, i, s in self.prefiller_heap
                               if i != server_idx]
        heapq.heapify(self.prefiller_heap)
        heapq.heappush(self.prefiller_heap,
                       (priority, server_idx, server))  # type: ignore

    def _update_decoder_priority(self, server_idx: int):
        """Update the priority of a decoder server in the heap."""
        server = self.decoders[server_idx]
        priority = server.active_tokens
        # Drop the stale entry, restore the heap invariant, then re-add
        self.decoder_heap = [(p, i, s) for p, i, s in self.decoder_heap
                             if i != server_idx]
        heapq.heapify(self.decoder_heap)
        heapq.heappush(self.decoder_heap,
                       (priority, server_idx, server))  # type: ignore
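
    # Worked example of the weighting above: a prefiller carrying 1000 active
    # tokens plus 500 tokens of pending KV cache scores 1000 + 500 * 0.3 =
    # 1150, while one carrying 1200 active tokens and no pending KV cache
    # scores 1200, so the first server is selected next.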

    def abort_prefiller_request(self, server_idx: int,
                                request_id):  # Changed to synchronous
        """
        Mark a request as aborted. This helps release the KV cache on the
        prefiller node.
        """
        # No lock needed - atomic operation
        self.prefillers[server_idx].aborted_requests.add(request_id)

    def acquire_aborted_prefiller_requests(
            self, server_idx: int):  # Changed to synchronous
        """
        Get the set of aborted requests and clear it.
        This is used to release the KV cache on the prefiller node.
        """
        # No lock needed - atomic operation
        aborted_requests = self.prefillers[server_idx].aborted_requests.copy()
        self.prefillers[server_idx].aborted_requests.clear()
        return aborted_requests

    async def next_req_id(self):
        async with self.req_id_lock:
            return str(uuid.uuid4())

    def select_prefiller(self, token_count):  # Changed to synchronous
        # No lock needed - entire function is atomic
        if not self.prefiller_heap:
            raise RuntimeError("No prefiller servers available")
        priority, chosen, server = heapq.heappop(self.prefiller_heap)
        # Update the chosen server atomically
        self.prefillers[chosen].active_tokens += token_count
        self.prefillers[chosen].active_kv_cache += token_count
        # Update priority and re-add to heap
        self._update_prefiller_priority(chosen)
        return chosen

    def release_prefiller(self, idx, token_count):  # Changed to synchronous
        # No lock needed - atomic operation
        self.prefillers[idx].active_tokens -= token_count
        # Update priority queue after releasing
        self._update_prefiller_priority(idx)

    def release_prefiller_kv(self, idx, token_count):  # Changed to synchronous
        # No lock needed - atomic operation
        if self.prefillers[idx].active_kv_cache > 0:
            self.prefillers[idx].active_kv_cache -= token_count
            # Update priority queue after releasing
            self._update_prefiller_priority(idx)

    def select_decoder(self, token_count):  # Changed to synchronous
        # No lock needed - entire function is atomic
        if not self.decoder_heap:
            raise RuntimeError("No decoder servers available")
        priority, chosen, server = heapq.heappop(self.decoder_heap)
        # Update the chosen server atomically
        self.decoders[chosen].active_tokens += token_count
        # Update priority and re-add to heap
        self._update_decoder_priority(chosen)
        return chosen

    def release_decoder(self, idx, token_count):  # Changed to synchronous
        # No lock needed - atomic operation
        self.decoders[idx].active_tokens -= token_count
        # Update priority queue after releasing
        self._update_decoder_priority(idx)

    # Omni_infer's calculate_input_scores function
    def calculate_prefill_scores(self, request_length: int) -> float:
        length_score = request_length / 4.0
        input_score = length_score * 0.0345 + 120.0745
        return input_score

    def calculate_decode_scores(self, request_length: int) -> float:
        return request_length
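

# A quick worked example of calculate_prefill_scores: a 4096-byte request
# body gives length_score = 4096 / 4.0 = 1024.0 and an input score of
# 1024.0 * 0.0345 + 120.0745 = 155.4025, so larger requests add more load
# to the prefiller they are assigned to.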
proxy_state = None


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--prefiller-hosts",
                        type=str,
                        nargs="+",
                        default=["localhost"])
    parser.add_argument("--prefiller-ports",
                        type=int,
                        nargs="+",
                        default=[8001])
    parser.add_argument("--decoder-hosts",
                        type=str,
                        nargs="+",
                        default=["localhost"])
    parser.add_argument("--decoder-ports", type=int, nargs="+", default=[8002])
    parser.add_argument("--max-retries",
                        type=int,
                        default=3,
                        help="Maximum number of retries for HTTP requests")
    parser.add_argument(
        "--retry-delay",
        type=float,
        default=0.001,
        help="Base delay (seconds) for exponential backoff retries")
    args = parser.parse_args()
    if len(args.prefiller_hosts) != len(args.prefiller_ports):
        raise ValueError(
            "Number of prefiller hosts must match number of prefiller ports")
    if len(args.decoder_hosts) != len(args.decoder_ports):
        raise ValueError(
            "Number of decoder hosts must match number of decoder ports")
    args.prefiller_instances = list(
        zip(args.prefiller_hosts, args.prefiller_ports))
    args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports))
    return args


@asynccontextmanager
async def lifespan(app: FastAPI):
    global proxy_state
    proxy_state = ProxyState(global_args.prefiller_instances,
                             global_args.decoder_instances)
    print(f"Initialized {len(proxy_state.prefillers)} prefill clients "
          f"and {len(proxy_state.decoders)} decode clients.")
    yield
    for p in proxy_state.prefillers:
        await p.client.aclose()
    for d in proxy_state.decoders:
        await d.client.aclose()


async def listen_for_disconnect(request: Request) -> None:
    """Return if a disconnect message is received."""
    while True:
        message = await request.receive()
        if message["type"] == "http.disconnect":
            break


def with_cancellation(handler_func):

    @functools.wraps(handler_func)
    async def wrapper(*args, **kwargs):
        request = kwargs["request"]
        handler_task = asyncio.create_task(handler_func(*args, **kwargs))
        cancellation_task = asyncio.create_task(listen_for_disconnect(request))
        done, pending = await asyncio.wait([handler_task, cancellation_task],
                                           return_when=asyncio.FIRST_COMPLETED)
        for task in pending:
            task.cancel()
        if handler_task in done:
            return handler_task.result()
        return None

    return wrapper
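

# with_cancellation wraps the POST handlers defined below so that a client
# disconnect cancels in-flight proxy work instead of letting it run on.
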
app = FastAPI(lifespan=lifespan)


async def send_request_to_service(client: httpx.AsyncClient,
                                  prefiller_id: int,
                                  endpoint: str,
                                  req_data: dict,
                                  request_id: str,
                                  max_retries: int = 3,
                                  base_delay: float = 0.2):
    aborted_requests = proxy_state.acquire_aborted_prefiller_requests(
        prefiller_id)
    req_data = req_data.copy()
    req_data['kv_transfer_params'] = {
        "do_remote_decode": True,
        "do_remote_prefill": False,
        "remote_engine_id": None,
        "remote_block_ids": None,
        "remote_host": None,
        "remote_port": None,
        "aborted_request": list(aborted_requests),
    }
    # The prefiller only needs to populate the KV cache, so generation is
    # capped at a single token.
    req_data["stream"] = False
    req_data["max_tokens"] = 1
req_data [ " min_tokens " ] = 1
2025-11-18 16:33:00 +08:00
if " max_completion_tokens " in req_data :
req_data [ " max_completion_tokens " ] = 1
2025-08-04 10:35:53 +08:00
if " stream_options " in req_data :
del req_data [ " stream_options " ]
headers = {
" Authorization " : f " Bearer { os . environ . get ( ' OPENAI_API_KEY ' ) } " ,
" X-Request-Id " : request_id
}
last_exc = None
for attempt in range ( 1 , max_retries + 1 ) :
try :
response = await client . post ( endpoint ,
json = req_data ,
headers = headers )
response . raise_for_status ( )
return response
except ( httpx . RequestError , httpx . HTTPStatusError ) as e :
logger . warning (
f " Attempt { attempt } failed for { endpoint } : { str ( e ) } " )
last_exc = e
if attempt < max_retries :
await asyncio . sleep ( base_delay * ( 2 * * ( attempt - 1 ) ) )
else :
logger . error (
f " All { max_retries } attempts failed for { endpoint } . " )
raise last_exc


async def stream_service_response_with_retry(client: httpx.AsyncClient,
                                             endpoint: str,
                                             req_data: dict,
                                             request_id: str,
                                             max_retries: int = 3,
                                             base_delay: float = 0.2):
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        "X-Request-Id": request_id
    }
    for attempt in range(1, max_retries + 1):
        # Track whether any bytes reached the client on this attempt;
        # once they have, a retry would duplicate output.
        first_chunk_sent = False
        try:
            async with client.stream("POST",
                                     endpoint,
                                     json=req_data,
                                     headers=headers) as response:
                response.raise_for_status()
                async for chunk in response.aiter_bytes():
                    first_chunk_sent = True
                    yield chunk
                return  # Success, exit after streaming
        except (httpx.RequestError, httpx.HTTPStatusError) as e:
            if attempt < max_retries:
                logger.warning(
                    f"Attempt {attempt} failed for streaming {endpoint}: "
                    f"{str(e)}")
                await asyncio.sleep(base_delay * (2**(attempt - 1)))
            else:
                logger.error(
                    f"All {max_retries} attempts failed for streaming "
                    f"{endpoint}.")
                raise e
        except Exception as e:
            # If any chunk has been sent, do not retry, just log and drop
            if first_chunk_sent:
                logger.error(
                    f"Streaming to client interrupted after response "
                    f"started: {str(e)}")
                return
            if attempt < max_retries:
                logger.warning(
                    f"Attempt {attempt} failed for streaming {endpoint}: "
                    f"{str(e)}")
                await asyncio.sleep(base_delay * (2**(attempt - 1)))
            else:
                logger.error(
                    f"All {max_retries} attempts failed for streaming "
                    f"{endpoint}.")
                raise e
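

# With the defaults above (max_retries=3, base_delay=0.2), both retry helpers
# sleep 0.2 s after the first failure and 0.4 s after the second before the
# error is re-raised; the call sites below pass --retry-delay as base_delay.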


async def _handle_select_instance(api: str, req_data: Any,
                                  request_length: int):
    prefiller_score = proxy_state.calculate_prefill_scores(request_length)
    logger.debug(f"Request length: {request_length}, "
                 f"Prefiller score: {prefiller_score}")
    request_id = await proxy_state.next_req_id()
    # Select prefiller
    prefiller_idx = proxy_state.select_prefiller(prefiller_score)
    prefiller = proxy_state.prefillers[prefiller_idx]
    # Send request to prefiller
    response = await send_request_to_service(
        prefiller.client,
        prefiller_idx,
        api,
        req_data,
        request_id,
        max_retries=global_args.max_retries,
        base_delay=global_args.retry_delay)
    proxy_state.release_prefiller(prefiller_idx, prefiller_score)
    response_json = response.json()
    kv_transfer_params = response_json.get('kv_transfer_params', {})
    if kv_transfer_params:
        req_data["kv_transfer_params"] = kv_transfer_params
    # Select decoder
    decoder_score = proxy_state.calculate_decode_scores(request_length)
    logger.debug("Decoder score: %f", decoder_score)
    # Forward the prefiller's kv_transfer_params to the chosen decoder
    decoder_idx = proxy_state.select_decoder(decoder_score)
    decoder = proxy_state.decoders[decoder_idx]
    logger.debug("Using %s %s", prefiller.url, decoder.url)
    return InstanceInfo(request_id=request_id,
                        prefiller_idx=prefiller_idx,
                        prefiller_score=prefiller_score,
                        prefiller=prefiller,
                        decoder=decoder,
                        decoder_idx=decoder_idx,
                        decoder_score=decoder_score)


@dataclass
class InstanceInfo:
    request_id: str
    prefiller_idx: int
    prefiller_score: float
    prefiller: ServerState
    decoder_idx: int
    decoder_score: float
    decoder: ServerState


async def _handle_completions(api: str, request: Request):
    try:
        req_data = await request.json()
        req_body = await request.body()
        request_length = len(req_body)
        instance_info = await _handle_select_instance(api, req_data,
                                                      request_length)
        stream_flag = bool(req_data.get("stream", False))
        chat_flag = "messages" in req_data
        if "prompt" in req_data:
            origin_prompt = req_data["prompt"]
        elif chat_flag:
            messages = req_data["messages"]
            origin_prompt = messages[0].get("content", "")
        else:
            origin_prompt = ""
        # Refer to vLLM sampling_params: max_tokens default value
        origin_max_tokens = req_data.get("max_tokens", 16)

        async def generate_stream():
            nonlocal instance_info
            generated_token = ""
            released_kv = False
            retry_count = 0
            retry = True
            completion_tokens = 0
            # Only one await per chunk, minimal logic in loop
            try:
                while retry:
                    retry = False
                    async for chunk in stream_service_response_with_retry(
                            instance_info.decoder.client,
                            api,
                            req_data,
                            request_id=instance_info.request_id,
                            max_retries=global_args.max_retries,
                            base_delay=global_args.retry_delay):
                        if not released_kv and chunk:
                            proxy_state.release_prefiller_kv(
                                instance_info.prefiller_idx,
                                instance_info.prefiller_score)
                            released_kv = True
                        try:
                            chunk_str = chunk.decode("utf-8").strip()
                        except UnicodeDecodeError:
                            logger.debug(f"Skipping chunk: {chunk}")
                            yield chunk
                            continue
                        if not chunk_str:
                            continue
                        if chunk_str.startswith("data: "):
                            chunk_str = chunk_str[len("data: "):]
                        try:
                            chunk_json = json.loads(chunk_str)
                        except json.JSONDecodeError:
                            # If the chunk is [DONE], skip it.
                            logger.debug(f"Skipping chunk: {chunk_str}")
                            yield chunk
                            continue
                        choices = chunk_json.get("choices", [])
                        if not choices:
                            yield chunk
                            continue
                        choice = choices[0]
                        delta = choice.get("delta") or {}
                        message = choice.get("message") or {}
                        content = (delta.get("content")
                                   or message.get("content")
                                   or choice.get("text") or "")
                        generated_token += content
                        stop_reason = choice.get("stop_reason")
                        usage = chunk_json.get("usage", {})
                        completion_tokens = (completion_tokens + 1) if stream_flag else \
                            (completion_tokens + usage.get("completion_tokens", 0))
                        if stop_reason == "recomputed":
                            # The decoder dropped this request and asked for a
                            # recompute: replay the original prompt plus the
                            # tokens generated so far on a freshly selected
                            # prefiller/decoder pair.
                            retry = True
                            retry_count += 1
                            if chat_flag:
                                messages[0]["content"] = \
                                    origin_prompt + generated_token
                            else:
                                req_data["prompt"] = \
                                    origin_prompt + generated_token
                            req_data["max_tokens"] = \
                                origin_max_tokens - completion_tokens + retry_count
                            tmp_request_length = len(
                                json.dumps(req_data).encode("utf-8"))
                            instance_info = await _handle_select_instance(
                                api, req_data, tmp_request_length)
                            break
                        if retry_count > 0 and not stream_flag:
                            # After a recompute, stitch the full generation
                            # back into the single non-streaming response.
                            if chat_flag:
                                choice["message"]["content"] = generated_token
                            else:
                                choice["text"] = generated_token
                            chunk = json.dumps(chunk_json).encode("utf-8")
                        yield chunk
            except Exception as e:
                logger.error(
                    f"Error during streaming from decoder "
                    f"{instance_info.decoder.url}: {str(e)}; the aborted "
                    f"request {instance_info.request_id} will be routed to "
                    f"the target prefiller when a new request is ready to "
                    f"dispatch to it")
                proxy_state.abort_prefiller_request(
                    instance_info.prefiller_idx, instance_info.request_id)
                proxy_state.release_prefiller_kv(
                    instance_info.prefiller_idx, instance_info.prefiller_score)
            # After streaming done, release tokens
            proxy_state.release_decoder(instance_info.decoder_idx,
                                        instance_info.decoder_score)

        return StreamingResponse(generate_stream(),
                                 media_type="application/json")
    except Exception as e:
        import traceback
        exc_info = sys.exc_info()
        print("Error occurred in disagg prefill proxy server"
              f" - {api} endpoint")
        print(e)
        print("".join(traceback.format_exception(*exc_info)))
        raise
@app.post ( " /v1/completions " )
2025-08-29 17:19:23 +08:00
@with_cancellation
2025-08-04 10:35:53 +08:00
async def handle_completions ( request : Request ) :
return await _handle_completions ( " /completions " , request )
@app.post ( " /v1/chat/completions " )
2025-08-29 17:19:23 +08:00
@with_cancellation
2025-08-04 10:35:53 +08:00
async def handle_chat_completions ( request : Request ) :
return await _handle_completions ( " /chat/completions " , request )
@app.get ( " /healthcheck " )
async def healthcheck ( ) :
return {
" status " : " ok " ,
" prefill_instances " : len ( proxy_state . prefillers ) ,
" decode_instances " : len ( proxy_state . decoders )
}
if __name__ == ' __main__ ' :
global global_args
global_args = parse_args ( )
import uvicorn
2025-10-18 15:56:44 +08:00
uvicorn . run ( app , host = global_args . host , port = global_args . port )