From 6d3b35fae953bc2f8dcebc3b917f267b8c86f261 Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Tue, 8 Apr 2025 09:42:34 -0700 Subject: [PATCH] [PD] Simplify mini LB (#4911) Co-authored-by: Liangsheng Yin --- python/sglang/srt/disaggregation/mini_lb.py | 174 ++++++-------------- 1 file changed, 51 insertions(+), 123 deletions(-) diff --git a/python/sglang/srt/disaggregation/mini_lb.py b/python/sglang/srt/disaggregation/mini_lb.py index 55db8fa03..d90277a77 100644 --- a/python/sglang/srt/disaggregation/mini_lb.py +++ b/python/sglang/srt/disaggregation/mini_lb.py @@ -1,5 +1,5 @@ """ -Minimal HTTP load balancer for prefill and decode servers for testing purpose. +Minimal HTTP load balancer for prefill and decode servers for testing. """ import asyncio @@ -22,64 +22,59 @@ class MiniLoadBalancer: def select_pair(self): return random.choice(self.prefill_servers), random.choice(self.decode_servers) - async def generate_request(self, request_data): - prefill_server, decode_server = self.select_pair() - - # Parse and transform prefill_server - parsed_url = urllib.parse.urlparse(prefill_server) - hostname = parsed_url.hostname - bootstrap_host = f"{hostname}" - - modified_request = request_data.copy() - modified_request.update( - { - "bootstrap_host": bootstrap_host, - "bootstrap_room": random.randint(0, 2**63 - 1), - } - ) + async def generate( + self, modified_request, prefill_server, decode_server + ) -> ORJSONResponse: async with aiohttp.ClientSession() as session: - # Create the tasks tasks = [ session.post(f"{prefill_server}/generate", json=modified_request), session.post(f"{decode_server}/generate", json=modified_request), ] + # Wait for both responses to complete. Prefill should end first. + prefill_response, decode_response = await asyncio.gather(*tasks) - prefill_response = None - decode_response = None + return ORJSONResponse( + content=await decode_response.json(), + status_code=decode_response.status, + ) - # Process responses as they arrive - for i, response in enumerate(asyncio.as_completed(tasks)): - response = await response - # Check if this is the prefill or decode response based on order created - if i == 0: # First completed task - if str(response.url).startswith(prefill_server): - prefill_response = response - if response.status != 200: - raise HTTPException( - status_code=response.status, - detail=f"Prefill server error: Status {response.status} Details: {await response.text()}", - ) - else: - decode_response = response - if response.status != 200: - raise HTTPException( - status_code=response.status, - detail=f"Decode server error: Status {response.status} Details: {await response.text()}", - ) - else: # Second completed task - if str(response.url).startswith(prefill_server): - prefill_response = response - else: - decode_response = response + async def generate_stream(self, modified_request, prefill_server, decode_server): + async def stream_results(): + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout( + total=3600 + ) # Add timeout for request reliability + ) as session: + try: + # Create the tasks for both prefill and decode requests + tasks = [ + session.post( + f"{prefill_server}/generate", json=modified_request + ), + session.post( + f"{decode_server}/generate", json=modified_request + ), + ] + # Wait for both responses to complete. Since this is streaming, they return immediately. + prefill_response, decode_response = await asyncio.gather(*tasks) + async for chunk in decode_response.content: + yield chunk + except Exception as e: + error_msg = { + "error": {"message": f"Stream processing error: {str(e)}"} + } + yield b"data: " + orjson.dumps( + error_msg, option=orjson.OPT_NON_STR_KEYS + ) + b"\n\n" + finally: + if prefill_response is not None: + await prefill_response.release() - if response.status != 200: - raise HTTPException( - status_code=response.status, - detail=f"{'Prefill' if str(response.url).startswith(prefill_server) else 'Decode'} server error: Status {response.status} Details: {await response.text()}", - ) - - return await decode_response.json() + return StreamingResponse( + stream_results(), + media_type="text/event-stream", + ) app = FastAPI() @@ -169,81 +164,14 @@ async def handle_generate_request(request_data: dict): } ) - # Check if streaming is requested if request_data.get("stream", False): - - async def stream_results(): - async with aiohttp.ClientSession( - timeout=aiohttp.ClientTimeout(total=3600) - ) as session: - try: - # Create the tasks - tasks = [ - session.post( - f"{prefill_server}/generate", json=modified_request - ), - session.post( - f"{decode_server}/generate", json=modified_request - ), - ] - - prefill_response = None - decode_response = None - - # Process responses as they arrive - for i, response_task in enumerate(asyncio.as_completed(tasks)): - response = await response_task - - # Check the response immediately - if str(response.url).startswith(prefill_server): - prefill_response = response - if response.status != 200: - error_msg = { - "error": { - "message": f"Prefill server error: Status {response.status}, Details: {await response.text()}" - } - } - yield b"data: " + orjson.dumps( - error_msg, option=orjson.OPT_NON_STR_KEYS - ) + b"\n\n" - return - else: - decode_response = response - if response.status != 200: - error_msg = { - "error": { - "message": f"Decode server error: Status {response.status}" - } - } - yield b"data: " + orjson.dumps( - error_msg, option=orjson.OPT_NON_STR_KEYS - ) + b"\n\n" - return - - # Stream successful decode server response - async for line in decode_response.content: - yield line - yield b"data: [DONE]\n\n" - - except Exception as e: - error_msg = { - "error": {"message": f"Stream processing error: {str(e)}"} - } - yield b"data: " + orjson.dumps( - error_msg, option=orjson.OPT_NON_STR_KEYS - ) + b"\n\n" - finally: - if prefill_response is not None: - await prefill_response.release() - - return StreamingResponse( - stream_results(), - media_type="text/event-stream", + return await load_balancer.generate_stream( + modified_request, prefill_server, decode_server + ) + else: + return await load_balancer.generate( + modified_request, prefill_server, decode_server ) - - # Non-streaming case - result = await load_balancer.generate_request(request_data) - return ORJSONResponse(content=result) @app.get("/v1/models")