[PD] Simplify mini LB (#4911)
Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Minimal HTTP load balancer for prefill and decode servers for testing purpose.
|
Minimal HTTP load balancer for prefill and decode servers for testing.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
@@ -22,64 +22,59 @@ class MiniLoadBalancer:
|
|||||||
def select_pair(self):
|
def select_pair(self):
|
||||||
return random.choice(self.prefill_servers), random.choice(self.decode_servers)
|
return random.choice(self.prefill_servers), random.choice(self.decode_servers)
|
||||||
|
|
||||||
async def generate_request(self, request_data):
|
async def generate(
|
||||||
prefill_server, decode_server = self.select_pair()
|
self, modified_request, prefill_server, decode_server
|
||||||
|
) -> ORJSONResponse:
|
||||||
# Parse and transform prefill_server
|
|
||||||
parsed_url = urllib.parse.urlparse(prefill_server)
|
|
||||||
hostname = parsed_url.hostname
|
|
||||||
bootstrap_host = f"{hostname}"
|
|
||||||
|
|
||||||
modified_request = request_data.copy()
|
|
||||||
modified_request.update(
|
|
||||||
{
|
|
||||||
"bootstrap_host": bootstrap_host,
|
|
||||||
"bootstrap_room": random.randint(0, 2**63 - 1),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
# Create the tasks
|
|
||||||
tasks = [
|
tasks = [
|
||||||
session.post(f"{prefill_server}/generate", json=modified_request),
|
session.post(f"{prefill_server}/generate", json=modified_request),
|
||||||
session.post(f"{decode_server}/generate", json=modified_request),
|
session.post(f"{decode_server}/generate", json=modified_request),
|
||||||
]
|
]
|
||||||
|
# Wait for both responses to complete. Prefill should end first.
|
||||||
|
prefill_response, decode_response = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
prefill_response = None
|
return ORJSONResponse(
|
||||||
decode_response = None
|
content=await decode_response.json(),
|
||||||
|
status_code=decode_response.status,
|
||||||
|
)
|
||||||
|
|
||||||
# Process responses as they arrive
|
async def generate_stream(self, modified_request, prefill_server, decode_server):
|
||||||
for i, response in enumerate(asyncio.as_completed(tasks)):
|
async def stream_results():
|
||||||
response = await response
|
async with aiohttp.ClientSession(
|
||||||
# Check if this is the prefill or decode response based on order created
|
timeout=aiohttp.ClientTimeout(
|
||||||
if i == 0: # First completed task
|
total=3600
|
||||||
if str(response.url).startswith(prefill_server):
|
) # Add timeout for request reliability
|
||||||
prefill_response = response
|
) as session:
|
||||||
if response.status != 200:
|
try:
|
||||||
raise HTTPException(
|
# Create the tasks for both prefill and decode requests
|
||||||
status_code=response.status,
|
tasks = [
|
||||||
detail=f"Prefill server error: Status {response.status} Details: {await response.text()}",
|
session.post(
|
||||||
)
|
f"{prefill_server}/generate", json=modified_request
|
||||||
else:
|
),
|
||||||
decode_response = response
|
session.post(
|
||||||
if response.status != 200:
|
f"{decode_server}/generate", json=modified_request
|
||||||
raise HTTPException(
|
),
|
||||||
status_code=response.status,
|
]
|
||||||
detail=f"Decode server error: Status {response.status} Details: {await response.text()}",
|
# Wait for both responses to complete. Since this is streaming, they return immediately.
|
||||||
)
|
prefill_response, decode_response = await asyncio.gather(*tasks)
|
||||||
else: # Second completed task
|
async for chunk in decode_response.content:
|
||||||
if str(response.url).startswith(prefill_server):
|
yield chunk
|
||||||
prefill_response = response
|
except Exception as e:
|
||||||
else:
|
error_msg = {
|
||||||
decode_response = response
|
"error": {"message": f"Stream processing error: {str(e)}"}
|
||||||
|
}
|
||||||
|
yield b"data: " + orjson.dumps(
|
||||||
|
error_msg, option=orjson.OPT_NON_STR_KEYS
|
||||||
|
) + b"\n\n"
|
||||||
|
finally:
|
||||||
|
if prefill_response is not None:
|
||||||
|
await prefill_response.release()
|
||||||
|
|
||||||
if response.status != 200:
|
return StreamingResponse(
|
||||||
raise HTTPException(
|
stream_results(),
|
||||||
status_code=response.status,
|
media_type="text/event-stream",
|
||||||
detail=f"{'Prefill' if str(response.url).startswith(prefill_server) else 'Decode'} server error: Status {response.status} Details: {await response.text()}",
|
)
|
||||||
)
|
|
||||||
|
|
||||||
return await decode_response.json()
|
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
@@ -169,81 +164,14 @@ async def handle_generate_request(request_data: dict):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if streaming is requested
|
|
||||||
if request_data.get("stream", False):
|
if request_data.get("stream", False):
|
||||||
|
return await load_balancer.generate_stream(
|
||||||
async def stream_results():
|
modified_request, prefill_server, decode_server
|
||||||
async with aiohttp.ClientSession(
|
)
|
||||||
timeout=aiohttp.ClientTimeout(total=3600)
|
else:
|
||||||
) as session:
|
return await load_balancer.generate(
|
||||||
try:
|
modified_request, prefill_server, decode_server
|
||||||
# Create the tasks
|
|
||||||
tasks = [
|
|
||||||
session.post(
|
|
||||||
f"{prefill_server}/generate", json=modified_request
|
|
||||||
),
|
|
||||||
session.post(
|
|
||||||
f"{decode_server}/generate", json=modified_request
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
prefill_response = None
|
|
||||||
decode_response = None
|
|
||||||
|
|
||||||
# Process responses as they arrive
|
|
||||||
for i, response_task in enumerate(asyncio.as_completed(tasks)):
|
|
||||||
response = await response_task
|
|
||||||
|
|
||||||
# Check the response immediately
|
|
||||||
if str(response.url).startswith(prefill_server):
|
|
||||||
prefill_response = response
|
|
||||||
if response.status != 200:
|
|
||||||
error_msg = {
|
|
||||||
"error": {
|
|
||||||
"message": f"Prefill server error: Status {response.status}, Details: {await response.text()}"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
yield b"data: " + orjson.dumps(
|
|
||||||
error_msg, option=orjson.OPT_NON_STR_KEYS
|
|
||||||
) + b"\n\n"
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
decode_response = response
|
|
||||||
if response.status != 200:
|
|
||||||
error_msg = {
|
|
||||||
"error": {
|
|
||||||
"message": f"Decode server error: Status {response.status}"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
yield b"data: " + orjson.dumps(
|
|
||||||
error_msg, option=orjson.OPT_NON_STR_KEYS
|
|
||||||
) + b"\n\n"
|
|
||||||
return
|
|
||||||
|
|
||||||
# Stream successful decode server response
|
|
||||||
async for line in decode_response.content:
|
|
||||||
yield line
|
|
||||||
yield b"data: [DONE]\n\n"
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
error_msg = {
|
|
||||||
"error": {"message": f"Stream processing error: {str(e)}"}
|
|
||||||
}
|
|
||||||
yield b"data: " + orjson.dumps(
|
|
||||||
error_msg, option=orjson.OPT_NON_STR_KEYS
|
|
||||||
) + b"\n\n"
|
|
||||||
finally:
|
|
||||||
if prefill_response is not None:
|
|
||||||
await prefill_response.release()
|
|
||||||
|
|
||||||
return StreamingResponse(
|
|
||||||
stream_results(),
|
|
||||||
media_type="text/event-stream",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Non-streaming case
|
|
||||||
result = await load_balancer.generate_request(request_data)
|
|
||||||
return ORJSONResponse(content=result)
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/v1/models")
|
@app.get("/v1/models")
|
||||||
|
|||||||
Reference in New Issue
Block a user