Support PD bootstrap fields on /v1/chat/completions endpoint (#5488)
This commit is contained in:
@@ -23,8 +23,9 @@ class MiniLoadBalancer:
|
|||||||
return random.choice(self.prefill_servers), random.choice(self.decode_servers)
|
return random.choice(self.prefill_servers), random.choice(self.decode_servers)
|
||||||
|
|
||||||
async def generate(
|
async def generate(
|
||||||
self, modified_request, prefill_server, decode_server
|
self, modified_request, prefill_server, decode_server, endpoint
|
||||||
) -> ORJSONResponse:
|
) -> ORJSONResponse:
|
||||||
|
assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"
|
||||||
|
|
||||||
async with aiohttp.ClientSession(
|
async with aiohttp.ClientSession(
|
||||||
timeout=aiohttp.ClientTimeout(
|
timeout=aiohttp.ClientTimeout(
|
||||||
@@ -32,8 +33,8 @@ class MiniLoadBalancer:
|
|||||||
) # Add timeout for request reliability
|
) # Add timeout for request reliability
|
||||||
) as session:
|
) as session:
|
||||||
tasks = [
|
tasks = [
|
||||||
session.post(f"{prefill_server}/generate", json=modified_request),
|
session.post(f"{prefill_server}/{endpoint}", json=modified_request),
|
||||||
session.post(f"{decode_server}/generate", json=modified_request),
|
session.post(f"{decode_server}/{endpoint}", json=modified_request),
|
||||||
]
|
]
|
||||||
# Wait for both responses to complete. Prefill should end first.
|
# Wait for both responses to complete. Prefill should end first.
|
||||||
prefill_response, decode_response = await asyncio.gather(*tasks)
|
prefill_response, decode_response = await asyncio.gather(*tasks)
|
||||||
@@ -43,7 +44,11 @@ class MiniLoadBalancer:
|
|||||||
status_code=decode_response.status,
|
status_code=decode_response.status,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def generate_stream(self, modified_request, prefill_server, decode_server):
|
async def generate_stream(
|
||||||
|
self, modified_request, prefill_server, decode_server, endpoint="generate"
|
||||||
|
):
|
||||||
|
assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"
|
||||||
|
|
||||||
async def stream_results():
|
async def stream_results():
|
||||||
async with aiohttp.ClientSession(
|
async with aiohttp.ClientSession(
|
||||||
timeout=aiohttp.ClientTimeout(
|
timeout=aiohttp.ClientTimeout(
|
||||||
@@ -54,10 +59,10 @@ class MiniLoadBalancer:
|
|||||||
# Create the tasks for both prefill and decode requests
|
# Create the tasks for both prefill and decode requests
|
||||||
tasks = [
|
tasks = [
|
||||||
session.post(
|
session.post(
|
||||||
f"{prefill_server}/generate", json=modified_request
|
f"{prefill_server}/{endpoint}", json=modified_request
|
||||||
),
|
),
|
||||||
session.post(
|
session.post(
|
||||||
f"{decode_server}/generate", json=modified_request
|
f"{decode_server}/{endpoint}", json=modified_request
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
# Wait for both responses to complete. Since this is streaming, they return immediately.
|
# Wait for both responses to complete. Since this is streaming, they return immediately.
|
||||||
@@ -190,6 +195,37 @@ async def handle_generate_request(request_data: dict):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/chat/completions")
async def handle_completion_request(request_data: dict):
    """Forward a chat-completion request to a prefill/decode server pair.

    A prefill/decode pair is chosen with ``load_balancer.select_pair()``.
    The incoming body is shallow-copied and two PD-disaggregation fields
    are added to the copy before it is forwarded:

    * ``bootstrap_host`` -- the hostname component of the prefill server
      URL as returned by ``urllib.parse.urlparse`` (no scheme, no port).
    * ``bootstrap_room`` -- a fresh random integer obtained from
      ``_generate_bootstrap_room()``.

    Args:
        request_data: Parsed JSON body of the incoming request. It is not
            mutated; only its ``"stream"`` key is inspected here.

    Returns:
        The result of ``load_balancer.generate_stream`` when ``"stream"``
        is truthy, otherwise the result of ``load_balancer.generate``.
    """
    prefill_server, decode_server = load_balancer.select_pair()

    # Build the forwarded body without touching the caller's dict. The
    # bootstrap keys are listed last so they override any same-named keys
    # already present in the request, exactly like copy() + update().
    modified_request = {
        **request_data,
        "bootstrap_host": urllib.parse.urlparse(prefill_server).hostname,
        # Use the module's shared helper rather than re-inlining the
        # random range, so every endpoint draws room ids the same way.
        "bootstrap_room": _generate_bootstrap_room(),
    }

    # Streaming and non-streaming requests differ only in which load
    # balancer method forwards them; the arguments are identical.
    if request_data.get("stream", False):
        forward = load_balancer.generate_stream
    else:
        forward = load_balancer.generate

    return await forward(
        modified_request,
        prefill_server,
        decode_server,
        # No leading slash: both forwarders assert on that and build the
        # target URL as f"{server}/{endpoint}".
        endpoint="v1/chat/completions",
    )
|
||||||
|
|
||||||
|
|
||||||
def _generate_bootstrap_room():
    """Return a random bootstrap room id in the inclusive range [0, 2**63 - 1]."""
    # Largest value representable as a signed 64-bit integer.
    largest_room_id = (1 << 63) - 1
    return random.randint(0, largest_room_id)
|
||||||
|
|
||||||
|
|||||||
@@ -1174,6 +1174,8 @@ def v1_chat_generate_request(
|
|||||||
rid=request_ids,
|
rid=request_ids,
|
||||||
modalities=modalities_list,
|
modalities=modalities_list,
|
||||||
lora_path=lora_paths,
|
lora_path=lora_paths,
|
||||||
|
bootstrap_host=all_requests[0].bootstrap_host,
|
||||||
|
bootstrap_room=all_requests[0].bootstrap_room,
|
||||||
)
|
)
|
||||||
|
|
||||||
return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
|
return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
|
||||||
|
|||||||
@@ -362,6 +362,10 @@ class ChatCompletionRequest(BaseModel):
|
|||||||
separate_reasoning: bool = True
|
separate_reasoning: bool = True
|
||||||
stream_reasoning: bool = True
|
stream_reasoning: bool = True
|
||||||
|
|
||||||
|
# For PD disaggregation
|
||||||
|
bootstrap_host: Optional[str] = None
|
||||||
|
bootstrap_room: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
class FunctionResponse(BaseModel):
|
class FunctionResponse(BaseModel):
|
||||||
"""Function response."""
|
"""Function response."""
|
||||||
|
|||||||
Reference in New Issue
Block a user