### What this PR does / why we need it?
This PR adopts `LLMDataDist` for KV cache registration and a `pull_blocks`-style
disaggregated prefill implementation. The interface implementation
mainly follows the design of the NIXL PR
https://github.com/vllm-project/vllm/pull/17751/files#diff-7eaad0b7dee0626bf29d10081b0f0c5e3ea15a4af97e7b182a4e0d35f8346953
.
This PR can be tested with the following steps:
- Generate the rank table for all machines.
- Execute `toy_proxy.py` to launch the disaggregated prefill proxy server,
specifying the prefill ip and port and the decode ip and port.
- Run the prefill server and the decode server.
- Send a request to the disaggregated prefill proxy.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.9.2
- vLLM main:
8d0a01a5f2
---------
Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
Signed-off-by: machenglong <machenglong_yewu@cmss.chinamobile.com>
Signed-off-by: liziyu179 <3475441767@qq.com>
Signed-off-by: underfitc <hucong24@huawei.com>
Signed-off-by: zouyida2052 <zouyida@huawei.com>
Signed-off-by: liziyu <liziyu16@huawei.com>
Signed-off-by: underfituu <hzhucong@163.com>
Co-authored-by: machenglong <machenglong_yewu@cmss.chinamobile.com>
Co-authored-by: liziyu179 <3475441767@qq.com>
Co-authored-by: underfitc <hucong24@huawei.com>
Co-authored-by: zouyida2052 <zouyida@huawei.com>
Co-authored-by: liziyu <liziyu16@huawei.com>
Co-authored-by: underfituu <hzhucong@163.com>
275 lines
8.7 KiB
Python
275 lines
8.7 KiB
Python
# Adapted from https://github.com/vllm-project/vllm/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py
|
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
import argparse
|
|
import itertools
|
|
import os
|
|
import uuid
|
|
from contextlib import asynccontextmanager
|
|
|
|
import httpx
|
|
from fastapi import FastAPI, Request
|
|
from fastapi.responses import StreamingResponse
|
|
from vllm.logger import init_logger
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
|
|
def _build_client_pool(instances, limits):
    """Build one pooled-client entry per (host, port) backend instance.

    Args:
        instances: iterable of (host, port) pairs.
        limits: shared httpx.Limits applied to every client.

    Returns:
        A list of dicts with keys 'client' (httpx.AsyncClient bound to the
        instance's /v1 base URL), 'host', 'port' and 'id' (enumeration index).
    """
    return [{
        'client': httpx.AsyncClient(timeout=None,
                                    base_url=f'http://{host}:{port}/v1',
                                    limits=limits),
        'host': host,
        'port': port,
        'id': i,
    } for i, (host, port) in enumerate(instances)]


@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Lifespan context manager to handle startup and shutdown events.

    On startup, builds client pools and round-robin iterators for the
    prefiller and decoder instances parsed from the command line
    (module-level ``global_args``); on shutdown, closes every client.
    """
    # Startup: Initialize client pools for prefiller and decoder services.
    # Very high connection limits: the proxy fans out to a small, fixed set
    # of backends, so we never want httpx to queue requests on our side.
    limit = httpx.Limits(max_connections=100000,
                         max_keepalive_connections=100000)

    app.state.prefill_clients = _build_client_pool(
        global_args.prefiller_instances, limit)
    app.state.decode_clients = _build_client_pool(
        global_args.decoder_instances, limit)

    # Initialize round-robin iterators over pool indices
    app.state.prefill_iterator = itertools.cycle(
        range(len(app.state.prefill_clients)))
    app.state.decode_iterator = itertools.cycle(
        range(len(app.state.decode_clients)))

    logger.info("Initialized %d prefill clients and %d decode clients.",
                len(app.state.prefill_clients),
                len(app.state.decode_clients))

    yield

    # Shutdown: Close all clients
    for client_info in app.state.prefill_clients:
        await client_info['client'].aclose()

    for client_info in app.state.decode_clients:
        await client_info['client'].aclose()
|
|
|
|
|
|
# Update FastAPI app initialization to use lifespan: client pools are
# created on startup and closed on shutdown.
app = FastAPI(lifespan=lifespan)
|
|
|
|
|
|
def parse_args():
    """Parse the proxy's command-line options.

    Returns:
        The parsed namespace, augmented with ``prefiller_instances`` and
        ``decoder_instances``: lists of (host, port) tuples, one per
        backend instance.

    Raises:
        ValueError: if the number of hosts and ports given for a service
            do not match.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--host", type=str, default="localhost")

    # Both backend services take the same host/port list options, so
    # register them in one loop (plural and singular spellings accepted).
    for name, default_port in (("prefiller", 8100), ("decoder", 8200)):
        parser.add_argument(f"--{name}-hosts",
                            f"--{name}-host",
                            type=str,
                            nargs="+",
                            default=["localhost"])
        parser.add_argument(f"--{name}-ports",
                            f"--{name}-port",
                            type=int,
                            nargs="+",
                            default=[default_port])

    args = parser.parse_args()

    # Validate and pair hosts with ports for each service type.
    if len(args.prefiller_hosts) != len(args.prefiller_ports):
        raise ValueError(
            "Number of prefiller hosts must match number of prefiller ports")

    if len(args.decoder_hosts) != len(args.decoder_ports):
        raise ValueError(
            "Number of decoder hosts must match number of decoder ports")

    args.prefiller_instances = list(
        zip(args.prefiller_hosts, args.prefiller_ports))
    args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports))

    return args
|
|
|
|
|
|
def get_next_client(app, service_type: str):
    """
    Get the next client in round-robin fashion.

    Args:
        app: The FastAPI app instance
        service_type: Either 'prefill' or 'decode'

    Returns:
        The next client to use

    Raises:
        ValueError: if ``service_type`` is neither 'prefill' nor 'decode'.
    """
    # Reject unknown service types up front.
    if service_type not in ('prefill', 'decode'):
        raise ValueError(f"Unknown service type: {service_type}")

    if service_type == 'prefill':
        pool = app.state.prefill_clients
        rotation = app.state.prefill_iterator
    else:
        pool = app.state.decode_clients
        rotation = app.state.decode_iterator

    # Advance the shared cycle iterator and pick that pool entry.
    return pool[next(rotation)]
|
|
|
|
|
|
async def send_request_to_service(client_info: dict, endpoint: str,
                                  req_data: dict, request_id: str):
    """
    Send a request to a service using a client from the pool.

    A shallow copy of ``req_data`` is sent (the caller's dict is not
    mutated) with kv_transfer_params primed for a prefill-only pass:
    streaming disabled and max_tokens forced to 1.
    """
    payload = dict(req_data)
    payload['kv_transfer_params'] = {
        "do_remote_decode": True,
        "do_remote_prefill": False,
        "remote_engine_id": None,
        "remote_block_ids": None,
        "remote_host": None,
        "remote_port": None
    }
    payload["stream"] = False
    payload["max_tokens"] = 1
    # stream_options is only valid alongside stream=True; drop it if present.
    payload.pop("stream_options", None)

    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        "X-Request-Id": request_id
    }

    response = await client_info['client'].post(endpoint,
                                                json=payload,
                                                headers=headers)
    response.raise_for_status()

    return response
|
|
|
|
|
|
async def stream_service_response(client_info: dict, endpoint: str,
                                  req_data: dict, request_id: str):
    """
    Asynchronously stream response from a service using a client from the pool.

    Yields the raw response body bytes chunk by chunk as they arrive.
    """
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        "X-Request-Id": request_id
    }

    pooled_client = client_info['client']
    async with pooled_client.stream("POST",
                                    endpoint,
                                    json=req_data,
                                    headers=headers) as response:
        # Surface HTTP errors before relaying any bytes downstream.
        response.raise_for_status()
        async for chunk in response.aiter_bytes():
            yield chunk
|
|
|
|
|
|
async def _handle_completions(api: str, request: Request):
    """
    Proxy a completions-style request through disaggregated prefill/decode.

    The request is first sent to a prefill instance (stream off,
    max_tokens=1) to produce the KV cache; the kv_transfer_params returned
    by the prefiller are attached to the request, which is then streamed
    from a decode instance back to the caller.

    Args:
        api: Backend endpoint path ('/completions' or '/chat/completions').
        request: The incoming FastAPI request.

    Returns:
        A StreamingResponse relaying the decode service's output.
    """
    try:
        req_data = await request.json()
        request_id = str(uuid.uuid4())

        # Get the next prefill client in round-robin fashion
        prefill_client_info = get_next_client(request.app, 'prefill')

        # Send request to prefill service
        response = await send_request_to_service(prefill_client_info, api,
                                                 req_data, request_id)

        # Extract the fields the decoder needs to pull KV blocks
        response_json = response.json()
        kv_transfer_params = response_json.get('kv_transfer_params', {})
        if kv_transfer_params:
            req_data["kv_transfer_params"] = kv_transfer_params

        # Get the next decode client in round-robin fashion
        decode_client_info = get_next_client(request.app, 'decode')

        logger.debug("Using %s %s", prefill_client_info, decode_client_info)

        # Stream response from decode service
        async def generate_stream():
            async for chunk in stream_service_response(decode_client_info,
                                                       api,
                                                       req_data,
                                                       request_id=request_id):
                yield chunk

        return StreamingResponse(generate_stream(),
                                 media_type="application/json")

    except Exception:
        # Log the full traceback via the module logger (instead of ad-hoc
        # print + sys.exc_info/traceback plumbing) and re-raise so FastAPI
        # turns it into a 500 for the client.
        logger.exception(
            "Error occurred in disagg prefill proxy server - %s endpoint",
            api)
        raise
|
|
|
|
|
|
@app.post("/v1/completions")
async def handle_completions(request: Request):
    """Proxy /v1/completions through the prefill-then-decode pipeline."""
    return await _handle_completions("/completions", request)
|
|
|
|
|
|
@app.post("/v1/chat/completions")
async def handle_chat_completions(request: Request):
    """Proxy /v1/chat/completions through the prefill-then-decode pipeline."""
    return await _handle_completions("/chat/completions", request)
|
|
|
|
|
|
@app.get("/healthcheck")
async def healthcheck():
    """Simple endpoint to check if the server is running."""
    # Report pool sizes so operators can confirm backend wiring at a glance.
    report = {"status": "ok"}
    report["prefill_instances"] = len(app.state.prefill_clients)
    report["decode_instances"] = len(app.state.decode_clients)
    return report
|
|
|
|
|
|
if __name__ == '__main__':
    # NOTE: the `global` statement that used to precede this assignment was a
    # no-op — `global` has no effect at module scope. global_args is simply a
    # module-level name, read by lifespan() when the server starts.
    global_args = parse_args()

    import uvicorn
    uvicorn.run(app, host=global_args.host, port=global_args.port)