[PD] Simplify mini LB (#4911)

Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
This commit is contained in:
Byron Hsu
2025-04-08 09:42:34 -07:00
committed by GitHub
parent a73c4df438
commit 6d3b35fae9

View File

@@ -1,5 +1,5 @@
""" """
Minimal HTTP load balancer for prefill and decode servers for testing purpose. Minimal HTTP load balancer for prefill and decode servers for testing.
""" """
import asyncio import asyncio
@@ -22,64 +22,59 @@ class MiniLoadBalancer:
def select_pair(self): def select_pair(self):
return random.choice(self.prefill_servers), random.choice(self.decode_servers) return random.choice(self.prefill_servers), random.choice(self.decode_servers)
async def generate_request(self, request_data): async def generate(
prefill_server, decode_server = self.select_pair() self, modified_request, prefill_server, decode_server
) -> ORJSONResponse:
# Parse and transform prefill_server
parsed_url = urllib.parse.urlparse(prefill_server)
hostname = parsed_url.hostname
bootstrap_host = f"{hostname}"
modified_request = request_data.copy()
modified_request.update(
{
"bootstrap_host": bootstrap_host,
"bootstrap_room": random.randint(0, 2**63 - 1),
}
)
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
# Create the tasks
tasks = [ tasks = [
session.post(f"{prefill_server}/generate", json=modified_request), session.post(f"{prefill_server}/generate", json=modified_request),
session.post(f"{decode_server}/generate", json=modified_request), session.post(f"{decode_server}/generate", json=modified_request),
] ]
# Wait for both responses to complete. Prefill should end first.
prefill_response, decode_response = await asyncio.gather(*tasks)
prefill_response = None return ORJSONResponse(
decode_response = None content=await decode_response.json(),
status_code=decode_response.status,
)
# Process responses as they arrive async def generate_stream(self, modified_request, prefill_server, decode_server):
for i, response in enumerate(asyncio.as_completed(tasks)): async def stream_results():
response = await response async with aiohttp.ClientSession(
# Check if this is the prefill or decode response based on order created timeout=aiohttp.ClientTimeout(
if i == 0: # First completed task total=3600
if str(response.url).startswith(prefill_server): ) # Add timeout for request reliability
prefill_response = response ) as session:
if response.status != 200: try:
raise HTTPException( # Create the tasks for both prefill and decode requests
status_code=response.status, tasks = [
detail=f"Prefill server error: Status {response.status} Details: {await response.text()}", session.post(
) f"{prefill_server}/generate", json=modified_request
else: ),
decode_response = response session.post(
if response.status != 200: f"{decode_server}/generate", json=modified_request
raise HTTPException( ),
status_code=response.status, ]
detail=f"Decode server error: Status {response.status} Details: {await response.text()}", # Wait for both responses to complete. Since this is streaming, they return immediately.
) prefill_response, decode_response = await asyncio.gather(*tasks)
else: # Second completed task async for chunk in decode_response.content:
if str(response.url).startswith(prefill_server): yield chunk
prefill_response = response except Exception as e:
else: error_msg = {
decode_response = response "error": {"message": f"Stream processing error: {str(e)}"}
}
yield b"data: " + orjson.dumps(
error_msg, option=orjson.OPT_NON_STR_KEYS
) + b"\n\n"
finally:
if prefill_response is not None:
await prefill_response.release()
if response.status != 200: return StreamingResponse(
raise HTTPException( stream_results(),
status_code=response.status, media_type="text/event-stream",
detail=f"{'Prefill' if str(response.url).startswith(prefill_server) else 'Decode'} server error: Status {response.status} Details: {await response.text()}", )
)
return await decode_response.json()
app = FastAPI() app = FastAPI()
@@ -169,81 +164,14 @@ async def handle_generate_request(request_data: dict):
} }
) )
# Check if streaming is requested
if request_data.get("stream", False): if request_data.get("stream", False):
return await load_balancer.generate_stream(
async def stream_results(): modified_request, prefill_server, decode_server
async with aiohttp.ClientSession( )
timeout=aiohttp.ClientTimeout(total=3600) else:
) as session: return await load_balancer.generate(
try: modified_request, prefill_server, decode_server
# Create the tasks
tasks = [
session.post(
f"{prefill_server}/generate", json=modified_request
),
session.post(
f"{decode_server}/generate", json=modified_request
),
]
prefill_response = None
decode_response = None
# Process responses as they arrive
for i, response_task in enumerate(asyncio.as_completed(tasks)):
response = await response_task
# Check the response immediately
if str(response.url).startswith(prefill_server):
prefill_response = response
if response.status != 200:
error_msg = {
"error": {
"message": f"Prefill server error: Status {response.status}, Details: {await response.text()}"
}
}
yield b"data: " + orjson.dumps(
error_msg, option=orjson.OPT_NON_STR_KEYS
) + b"\n\n"
return
else:
decode_response = response
if response.status != 200:
error_msg = {
"error": {
"message": f"Decode server error: Status {response.status}"
}
}
yield b"data: " + orjson.dumps(
error_msg, option=orjson.OPT_NON_STR_KEYS
) + b"\n\n"
return
# Stream successful decode server response
async for line in decode_response.content:
yield line
yield b"data: [DONE]\n\n"
except Exception as e:
error_msg = {
"error": {"message": f"Stream processing error: {str(e)}"}
}
yield b"data: " + orjson.dumps(
error_msg, option=orjson.OPT_NON_STR_KEYS
) + b"\n\n"
finally:
if prefill_response is not None:
await prefill_response.release()
return StreamingResponse(
stream_results(),
media_type="text/event-stream",
) )
# Non-streaming case
result = await load_balancer.generate_request(request_data)
return ORJSONResponse(content=result)
@app.get("/v1/models") @app.get("/v1/models")