diff --git a/python/sglang/srt/disaggregation/mini_lb.py b/python/sglang/srt/disaggregation/mini_lb.py
index c591b052f..883734f4e 100644
--- a/python/sglang/srt/disaggregation/mini_lb.py
+++ b/python/sglang/srt/disaggregation/mini_lb.py
@@ -274,8 +274,7 @@ async def handle_generate_request(request_data: dict):
     )
 
 
-@app.post("/v1/chat/completions")
-async def handle_completion_request(request_data: dict):
+async def _forward_to_backend(request_data: dict, endpoint_name: str):
     prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()
 
     # Parse and transform prefill_server for bootstrap data
@@ -286,7 +285,7 @@ async def handle_completion_request(request_data: dict):
         {
             "bootstrap_host": hostname,
             "bootstrap_port": bootstrap_port,
-            "bootstrap_room": random.randint(0, 2**63 - 1),
+            "bootstrap_room": _generate_bootstrap_room(),
         }
     )
 
@@ -295,17 +294,27 @@ async def handle_completion_request(request_data: dict):
             modified_request,
             prefill_server,
             decode_server,
-            endpoint="v1/chat/completions",
+            endpoint=endpoint_name,
         )
     else:
         return await load_balancer.generate(
             modified_request,
             prefill_server,
             decode_server,
-            endpoint="v1/chat/completions",
+            endpoint=endpoint_name,
         )
 
 
+@app.post("/v1/chat/completions")
+async def handle_chat_completion_request(request_data: dict):
+    return await _forward_to_backend(request_data, "v1/chat/completions")
+
+
+@app.post("/v1/completions")
+async def handle_completion_request(request_data: dict):
+    return await _forward_to_backend(request_data, "v1/completions")
+
+
 def _generate_bootstrap_room():
     return random.randint(0, 2**63 - 1)
 
diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py
index 27336dc75..83ff70b39 100644
--- a/python/sglang/srt/openai_api/adapter.py
+++ b/python/sglang/srt/openai_api/adapter.py
@@ -604,6 +604,9 @@ def v1_generate_request(
             stream=all_requests[0].stream,
             rid=request_ids,
             lora_path=lora_paths,
+            bootstrap_host=all_requests[0].bootstrap_host,
+            bootstrap_port=all_requests[0].bootstrap_port,
+            bootstrap_room=all_requests[0].bootstrap_room,
         )
 
     return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py
index e5f228a30..35c04b054 100644
--- a/python/sglang/srt/openai_api/protocol.py
+++ b/python/sglang/srt/openai_api/protocol.py
@@ -183,12 +183,17 @@ class CompletionRequest(BaseModel):
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
     session_params: Optional[Dict] = None
 
+    # For PD disaggregation
+    bootstrap_host: Optional[str] = None
+    bootstrap_port: Optional[int] = None
+    bootstrap_room: Optional[int] = None
+
 
 class CompletionResponseChoice(BaseModel):
     index: int
     text: str
     logprobs: Optional[LogProbs] = None
-    finish_reason: Literal["stop", "length", "content_filter"]
+    finish_reason: Literal["stop", "length", "content_filter", "abort"]
     matched_stop: Union[None, int, str] = None