[PD] Simplify mini LB (#4911)
Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
Minimal HTTP load balancer for prefill and decode servers for testing purpose.
|
||||
Minimal HTTP load balancer for prefill and decode servers for testing.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -22,64 +22,59 @@ class MiniLoadBalancer:
|
||||
def select_pair(self):
|
||||
return random.choice(self.prefill_servers), random.choice(self.decode_servers)
|
||||
|
||||
async def generate_request(self, request_data):
|
||||
prefill_server, decode_server = self.select_pair()
|
||||
|
||||
# Parse and transform prefill_server
|
||||
parsed_url = urllib.parse.urlparse(prefill_server)
|
||||
hostname = parsed_url.hostname
|
||||
bootstrap_host = f"{hostname}"
|
||||
|
||||
modified_request = request_data.copy()
|
||||
modified_request.update(
|
||||
{
|
||||
"bootstrap_host": bootstrap_host,
|
||||
"bootstrap_room": random.randint(0, 2**63 - 1),
|
||||
}
|
||||
)
|
||||
async def generate(
|
||||
self, modified_request, prefill_server, decode_server
|
||||
) -> ORJSONResponse:
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Create the tasks
|
||||
tasks = [
|
||||
session.post(f"{prefill_server}/generate", json=modified_request),
|
||||
session.post(f"{decode_server}/generate", json=modified_request),
|
||||
]
|
||||
# Wait for both responses to complete. Prefill should end first.
|
||||
prefill_response, decode_response = await asyncio.gather(*tasks)
|
||||
|
||||
prefill_response = None
|
||||
decode_response = None
|
||||
return ORJSONResponse(
|
||||
content=await decode_response.json(),
|
||||
status_code=decode_response.status,
|
||||
)
|
||||
|
||||
# Process responses as they arrive
|
||||
for i, response in enumerate(asyncio.as_completed(tasks)):
|
||||
response = await response
|
||||
# Check if this is the prefill or decode response based on order created
|
||||
if i == 0: # First completed task
|
||||
if str(response.url).startswith(prefill_server):
|
||||
prefill_response = response
|
||||
if response.status != 200:
|
||||
raise HTTPException(
|
||||
status_code=response.status,
|
||||
detail=f"Prefill server error: Status {response.status} Details: {await response.text()}",
|
||||
)
|
||||
else:
|
||||
decode_response = response
|
||||
if response.status != 200:
|
||||
raise HTTPException(
|
||||
status_code=response.status,
|
||||
detail=f"Decode server error: Status {response.status} Details: {await response.text()}",
|
||||
)
|
||||
else: # Second completed task
|
||||
if str(response.url).startswith(prefill_server):
|
||||
prefill_response = response
|
||||
else:
|
||||
decode_response = response
|
||||
async def generate_stream(self, modified_request, prefill_server, decode_server):
|
||||
async def stream_results():
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(
|
||||
total=3600
|
||||
) # Add timeout for request reliability
|
||||
) as session:
|
||||
try:
|
||||
# Create the tasks for both prefill and decode requests
|
||||
tasks = [
|
||||
session.post(
|
||||
f"{prefill_server}/generate", json=modified_request
|
||||
),
|
||||
session.post(
|
||||
f"{decode_server}/generate", json=modified_request
|
||||
),
|
||||
]
|
||||
# Wait for both responses to complete. Since this is streaming, they return immediately.
|
||||
prefill_response, decode_response = await asyncio.gather(*tasks)
|
||||
async for chunk in decode_response.content:
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
error_msg = {
|
||||
"error": {"message": f"Stream processing error: {str(e)}"}
|
||||
}
|
||||
yield b"data: " + orjson.dumps(
|
||||
error_msg, option=orjson.OPT_NON_STR_KEYS
|
||||
) + b"\n\n"
|
||||
finally:
|
||||
if prefill_response is not None:
|
||||
await prefill_response.release()
|
||||
|
||||
if response.status != 200:
|
||||
raise HTTPException(
|
||||
status_code=response.status,
|
||||
detail=f"{'Prefill' if str(response.url).startswith(prefill_server) else 'Decode'} server error: Status {response.status} Details: {await response.text()}",
|
||||
)
|
||||
|
||||
return await decode_response.json()
|
||||
return StreamingResponse(
|
||||
stream_results(),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
@@ -169,81 +164,14 @@ async def handle_generate_request(request_data: dict):
|
||||
}
|
||||
)
|
||||
|
||||
# Check if streaming is requested
|
||||
if request_data.get("stream", False):
|
||||
|
||||
async def stream_results():
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(total=3600)
|
||||
) as session:
|
||||
try:
|
||||
# Create the tasks
|
||||
tasks = [
|
||||
session.post(
|
||||
f"{prefill_server}/generate", json=modified_request
|
||||
),
|
||||
session.post(
|
||||
f"{decode_server}/generate", json=modified_request
|
||||
),
|
||||
]
|
||||
|
||||
prefill_response = None
|
||||
decode_response = None
|
||||
|
||||
# Process responses as they arrive
|
||||
for i, response_task in enumerate(asyncio.as_completed(tasks)):
|
||||
response = await response_task
|
||||
|
||||
# Check the response immediately
|
||||
if str(response.url).startswith(prefill_server):
|
||||
prefill_response = response
|
||||
if response.status != 200:
|
||||
error_msg = {
|
||||
"error": {
|
||||
"message": f"Prefill server error: Status {response.status}, Details: {await response.text()}"
|
||||
}
|
||||
}
|
||||
yield b"data: " + orjson.dumps(
|
||||
error_msg, option=orjson.OPT_NON_STR_KEYS
|
||||
) + b"\n\n"
|
||||
return
|
||||
else:
|
||||
decode_response = response
|
||||
if response.status != 200:
|
||||
error_msg = {
|
||||
"error": {
|
||||
"message": f"Decode server error: Status {response.status}"
|
||||
}
|
||||
}
|
||||
yield b"data: " + orjson.dumps(
|
||||
error_msg, option=orjson.OPT_NON_STR_KEYS
|
||||
) + b"\n\n"
|
||||
return
|
||||
|
||||
# Stream successful decode server response
|
||||
async for line in decode_response.content:
|
||||
yield line
|
||||
yield b"data: [DONE]\n\n"
|
||||
|
||||
except Exception as e:
|
||||
error_msg = {
|
||||
"error": {"message": f"Stream processing error: {str(e)}"}
|
||||
}
|
||||
yield b"data: " + orjson.dumps(
|
||||
error_msg, option=orjson.OPT_NON_STR_KEYS
|
||||
) + b"\n\n"
|
||||
finally:
|
||||
if prefill_response is not None:
|
||||
await prefill_response.release()
|
||||
|
||||
return StreamingResponse(
|
||||
stream_results(),
|
||||
media_type="text/event-stream",
|
||||
return await load_balancer.generate_stream(
|
||||
modified_request, prefill_server, decode_server
|
||||
)
|
||||
else:
|
||||
return await load_balancer.generate(
|
||||
modified_request, prefill_server, decode_server
|
||||
)
|
||||
|
||||
# Non-streaming case
|
||||
result = await load_balancer.generate_request(request_data)
|
||||
return ORJSONResponse(content=result)
|
||||
|
||||
|
||||
@app.get("/v1/models")
|
||||
|
||||
Reference in New Issue
Block a user