# Adapted from https://github.com/vllm-project/vllm/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py # SPDX-License-Identifier: Apache-2.0 import argparse import itertools import os import uuid from contextlib import asynccontextmanager import httpx from fastapi import FastAPI, Request from fastapi.responses import StreamingResponse from vllm.logger import init_logger logger = init_logger(__name__) @asynccontextmanager async def lifespan(app: FastAPI): """ Lifespan context manager to handle startup and shutdown events. """ # Startup: Initialize client pools for prefiller and decoder services app.state.prefill_clients = [] app.state.decode_clients = [] limit = httpx.Limits(max_connections=100000, max_keepalive_connections=100000) # Create prefill clients for i, (host, port) in enumerate(global_args.prefiller_instances): prefiller_base_url = f'http://{host}:{port}/v1' app.state.prefill_clients.append({ 'client': httpx.AsyncClient(timeout=None, base_url=prefiller_base_url, limits=limit), 'host': host, 'port': port, 'id': i }) # Create decode clients for i, (host, port) in enumerate(global_args.decoder_instances): decoder_base_url = f'http://{host}:{port}/v1' app.state.decode_clients.append({ 'client': httpx.AsyncClient(timeout=None, base_url=decoder_base_url, limits=limit), 'host': host, 'port': port, 'id': i }) # Initialize round-robin iterators app.state.prefill_iterator = itertools.cycle( range(len(app.state.prefill_clients))) app.state.decode_iterator = itertools.cycle( range(len(app.state.decode_clients))) print(f"Initialized {len(app.state.prefill_clients)} prefill clients " f"and {len(app.state.decode_clients)} decode clients.") yield # Shutdown: Close all clients for client_info in app.state.prefill_clients: await client_info['client'].aclose() for client_info in app.state.decode_clients: await client_info['client'].aclose() # Update FastAPI app initialization to use lifespan app = FastAPI(lifespan=lifespan) def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--port", type=int, default=8000) parser.add_argument("--host", type=str, default="localhost") # For prefiller instances parser.add_argument("--prefiller-hosts", "--prefiller-host", type=str, nargs="+", default=["localhost"]) parser.add_argument("--prefiller-ports", "--prefiller-port", type=int, nargs="+", default=[8100]) # For decoder instances parser.add_argument("--decoder-hosts", "--decoder-host", type=str, nargs="+", default=["localhost"]) parser.add_argument("--decoder-ports", "--decoder-port", type=int, nargs="+", default=[8200]) args = parser.parse_args() # Validate and pair hosts with ports if len(args.prefiller_hosts) != len(args.prefiller_ports): raise ValueError( "Number of prefiller hosts must match number of prefiller ports") if len(args.decoder_hosts) != len(args.decoder_ports): raise ValueError( "Number of decoder hosts must match number of decoder ports") # Create tuples of (host, port) for each service type args.prefiller_instances = list( zip(args.prefiller_hosts, args.prefiller_ports)) args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports)) return args def get_next_client(app, service_type: str): """ Get the next client in round-robin fashion. Args: app: The FastAPI app instance service_type: Either 'prefill' or 'decode' Returns: The next client to use """ if service_type == 'prefill': client_idx = next(app.state.prefill_iterator) return app.state.prefill_clients[client_idx] elif service_type == 'decode': client_idx = next(app.state.decode_iterator) return app.state.decode_clients[client_idx] else: raise ValueError(f"Unknown service type: {service_type}") async def send_request_to_service(client_info: dict, endpoint: str, req_data: dict, request_id: str): """ Send a request to a service using a client from the pool. """ req_data = req_data.copy() req_data['kv_transfer_params'] = { "do_remote_decode": True, "do_remote_prefill": False, "remote_engine_id": None, "remote_block_ids": None, "remote_host": None, "remote_port": None } req_data["stream"] = False req_data["max_tokens"] = 1 if "stream_options" in req_data: del req_data["stream_options"] headers = { "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", "X-Request-Id": request_id } response = await client_info['client'].post(endpoint, json=req_data, headers=headers) response.raise_for_status() return response async def stream_service_response(client_info: dict, endpoint: str, req_data: dict, request_id: str): """ Asynchronously stream response from a service using a client from the pool. """ headers = { "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", "X-Request-Id": request_id } async with client_info['client'].stream("POST", endpoint, json=req_data, headers=headers) as response: response.raise_for_status() async for chunk in response.aiter_bytes(): yield chunk async def _handle_completions(api: str, request: Request): try: req_data = await request.json() request_id = str(uuid.uuid4()) # Get the next prefill client in round-robin fashion prefill_client_info = get_next_client(request.app, 'prefill') # Send request to prefill service response = await send_request_to_service(prefill_client_info, api, req_data, request_id) # Extract the needed fields response_json = response.json() kv_transfer_params = response_json.get('kv_transfer_params', {}) if kv_transfer_params: req_data["kv_transfer_params"] = kv_transfer_params # Get the next decode client in round-robin fashion decode_client_info = get_next_client(request.app, 'decode') logger.debug("Using %s %s", prefill_client_info, decode_client_info) # Stream response from decode service async def generate_stream(): async for chunk in stream_service_response(decode_client_info, api, req_data, request_id=request_id): yield chunk return StreamingResponse(generate_stream(), media_type="application/json") except Exception as e: import sys import traceback exc_info = sys.exc_info() print("Error occurred in disagg prefill proxy server" f" - {api} endpoint") print(e) print("".join(traceback.format_exception(*exc_info))) raise @app.post("/v1/completions") async def handle_completions(request: Request): return await _handle_completions("/completions", request) @app.post("/v1/chat/completions") async def handle_chat_completions(request: Request): return await _handle_completions("/chat/completions", request) @app.get("/healthcheck") async def healthcheck(): """Simple endpoint to check if the server is running.""" return { "status": "ok", "prefill_instances": len(app.state.prefill_clients), "decode_instances": len(app.state.decode_clients) } if __name__ == '__main__': global global_args global_args = parse_args() import uvicorn uvicorn.run(app, host=global_args.host, port=global_args.port)