diff --git a/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py b/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py index 3af8fbe..727233e 100644 --- a/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py +++ b/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py @@ -87,9 +87,11 @@ import argparse import asyncio +import functools import heapq import os import sys +import uuid from contextlib import asynccontextmanager from typing import List @@ -137,7 +139,6 @@ class ProxyState: ] self.req_to_prefiller = {} self.req_id_lock = asyncio.Lock() - self.req_id_counter = 0 # Removed selection locks - no longer needed for synchronous methods # Initialize priority queues for efficient server selection @@ -193,8 +194,7 @@ class ProxyState: async def next_req_id(self): async with self.req_id_lock: - self.req_id_counter += 1 - return str(self.req_id_counter) + return str(uuid.uuid4()) def select_prefiller(self, token_count): # Changed to synchronous # No lock needed - entire function is atomic @@ -313,6 +313,32 @@ async def lifespan(app: FastAPI): await d.client.aclose() +async def listen_for_disconnect(request: Request) -> None: + """Return if a disconnect message is received""" + while True: + message = await request.receive() + if message["type"] == "http.disconnect": + break + + +def with_cancellation(handler_func): + + @functools.wraps(handler_func) + async def wrapper(*args, **kwargs): + request = kwargs["request"] + handler_task = asyncio.create_task(handler_func(*args, **kwargs)) + cancellation_task = asyncio.create_task(listen_for_disconnect(request)) + done, pending = await asyncio.wait([handler_task, cancellation_task], + return_when=asyncio.FIRST_COMPLETED) + for task in pending: + task.cancel() + if handler_task in done: + return handler_task.result() + return None + + return wrapper + + app = FastAPI(lifespan=lifespan) @@ -493,11 +519,13 @@ async def _handle_completions(api: str, request: Request): @app.post("/v1/completions") +@with_cancellation async def handle_completions(request: Request): return await _handle_completions("/completions", request) @app.post("/v1/chat/completions") +@with_cancellation async def handle_chat_completions(request: Request): return await _handle_completions("/chat/completions", request)