Route /v1/chat/completions through the PD-disaggregation mini load balancer.

Review notes (free text before the first "diff --git" header; not part of any hunk):
- MiniLoadBalancer.generate now gives `endpoint` the same "generate" default
  as generate_stream, so callers using the previous three-argument form keep
  working.
- The leading-slash check raises ValueError instead of using `assert`, which
  is stripped under `python -O` and raised IndexError on an empty endpoint.
- handle_completion_request reuses _generate_bootstrap_room() instead of
  repeating the randint call.
- The `index` lines are kept from the original patch; the post-image hash for
  mini_lb.py no longer reflects the edited hunks.

diff --git a/python/sglang/srt/disaggregation/mini_lb.py b/python/sglang/srt/disaggregation/mini_lb.py
index 3b8422421..3b4407206 100644
--- a/python/sglang/srt/disaggregation/mini_lb.py
+++ b/python/sglang/srt/disaggregation/mini_lb.py
@@ -23,8 +23,18 @@ class MiniLoadBalancer:
         return random.choice(self.prefill_servers), random.choice(self.decode_servers)
 
     async def generate(
-        self, modified_request, prefill_server, decode_server
+        self, modified_request, prefill_server, decode_server, endpoint="generate"
     ) -> ORJSONResponse:
+        """POST the same request to the prefill and decode servers.
+
+        `endpoint` is the server-relative path without a leading slash. It
+        defaults to "generate" so callers written against the previous
+        three-argument signature keep working.
+        """
+        # Raise instead of assert so the check survives `python -O`;
+        # startswith() also tolerates an empty string, unlike endpoint[0].
+        if endpoint.startswith("/"):
+            raise ValueError(f"Endpoint should not start with '/': {endpoint}")
 
         async with aiohttp.ClientSession(
             timeout=aiohttp.ClientTimeout(
@@ -32,8 +42,8 @@ class MiniLoadBalancer:
                 )  # Add timeout for request reliability
         ) as session:
             tasks = [
-                session.post(f"{prefill_server}/generate", json=modified_request),
-                session.post(f"{decode_server}/generate", json=modified_request),
+                session.post(f"{prefill_server}/{endpoint}", json=modified_request),
+                session.post(f"{decode_server}/{endpoint}", json=modified_request),
             ]
             # Wait for both responses to complete. Prefill should end first.
             prefill_response, decode_response = await asyncio.gather(*tasks)
@@ -43,7 +53,20 @@ class MiniLoadBalancer:
                 status_code=decode_response.status,
             )
 
-    async def generate_stream(self, modified_request, prefill_server, decode_server):
+    async def generate_stream(
+        self, modified_request, prefill_server, decode_server, endpoint="generate"
+    ):
+        """Streaming variant that posts to `endpoint` on both servers.
+
+        `endpoint` is the server-relative path without a leading slash and
+        defaults to "generate".
+        """
+        # Checked here, outside stream_results, so a bad endpoint is rejected
+        # as soon as generate_stream is awaited. Raises instead of asserting so
+        # the check survives `python -O` and an empty string does not IndexError.
+        if endpoint.startswith("/"):
+            raise ValueError(f"Endpoint should not start with '/': {endpoint}")
+
         async def stream_results():
             async with aiohttp.ClientSession(
                 timeout=aiohttp.ClientTimeout(
@@ -54,10 +77,10 @@ class MiniLoadBalancer:
                     # Create the tasks for both prefill and decode requests
                     tasks = [
                         session.post(
-                            f"{prefill_server}/generate", json=modified_request
+                            f"{prefill_server}/{endpoint}", json=modified_request
                         ),
                         session.post(
-                            f"{decode_server}/generate", json=modified_request
+                            f"{decode_server}/{endpoint}", json=modified_request
                         ),
                     ]
                     # Wait for both responses to complete. Since this is streaming, they return immediately.
@@ -190,6 +213,43 @@ async def handle_generate_request(request_data: dict):
         )
 
 
+@app.post("/v1/chat/completions")
+async def handle_completion_request(request_data: dict):
+    """Forward a chat completion request to a prefill/decode server pair.
+
+    The request body is copied and tagged with the prefill server's hostname
+    and a random bootstrap room before being sent to both servers.
+    """
+    prefill_server, decode_server = load_balancer.select_pair()
+
+    # Parse and transform prefill_server for bootstrap data
+    parsed_url = urllib.parse.urlparse(prefill_server)
+    hostname = parsed_url.hostname
+    modified_request = request_data.copy()
+    modified_request.update(
+        {
+            "bootstrap_host": hostname,
+            # Reuse the module helper instead of duplicating its randint call.
+            "bootstrap_room": _generate_bootstrap_room(),
+        }
+    )
+
+    if request_data.get("stream", False):
+        return await load_balancer.generate_stream(
+            modified_request,
+            prefill_server,
+            decode_server,
+            endpoint="v1/chat/completions",
+        )
+    else:
+        return await load_balancer.generate(
+            modified_request,
+            prefill_server,
+            decode_server,
+            endpoint="v1/chat/completions",
+        )
+
+
 def _generate_bootstrap_room():
     return random.randint(0, 2**63 - 1)
 
diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py
index c9a3dbb92..761ccbd2b 100644
--- a/python/sglang/srt/openai_api/adapter.py
+++ b/python/sglang/srt/openai_api/adapter.py
@@ -1174,6 +1174,8 @@ def v1_chat_generate_request(
         rid=request_ids,
         modalities=modalities_list,
         lora_path=lora_paths,
+        bootstrap_host=all_requests[0].bootstrap_host,
+        bootstrap_room=all_requests[0].bootstrap_room,
     )
 
     return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py
index 38b926c82..33644dd11 100644
--- a/python/sglang/srt/openai_api/protocol.py
+++ b/python/sglang/srt/openai_api/protocol.py
@@ -362,6 +362,10 @@ class ChatCompletionRequest(BaseModel):
     separate_reasoning: bool = True
     stream_reasoning: bool = True
 
+    # For PD disaggregation
+    bootstrap_host: Optional[str] = None
+    bootstrap_room: Optional[int] = None
+
 
 class FunctionResponse(BaseModel):
     """Function response."""