#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
disagg_encoder_proxy.py

Proxy that routes OpenAI-compatible “/v1/chat/completions” requests to two
clusters:
• encode (multimodal feature extraction)
• decode (language-model inference)
An optional prefill cluster can sit in between (E -> P -> D); see
--prefill-servers-urls below.

For MM input we:
1. Extract *every* image/audio item.
2. Fire N concurrent requests to the encoder cluster
   (one request per item, with **all text removed**).
3. Wait for all of them to succeed.
4. Forward the *original* request to a decode server.
"""

from __future__ import annotations

import argparse
import asyncio
import logging
import os
import random
import uuid
from collections.abc import AsyncIterator

import aiohttp
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, StreamingResponse

###############################################################################
# FastAPI app & global state
###############################################################################

logging.basicConfig(
    level=logging.DEBUG, format="%(asctime)s %(levelname)s: %(message)s"
)
logger = logging.getLogger("proxy")

app = FastAPI()
encode_session: aiohttp.ClientSession | None = None
prefill_session: aiohttp.ClientSession | None = None
decode_session: aiohttp.ClientSession | None = None

###############################################################################
# Utils
###############################################################################


MM_TYPES = {"image_url", "audio_url", "input_audio"}


def extract_mm_items(request_data: dict) -> list[dict]:
    """
    Return *all* image/audio items that appear anywhere in `messages`.

    Each returned dict looks like:
        { "type": "image_url", "image_url": {...} }
    """
    items: list[dict] = []
    for msg in request_data.get("messages", []):
        content = msg.get("content")
        if not isinstance(content, list):
            continue

        for item in content:
            if item.get("type") in MM_TYPES:
                items.append(item)
    return items
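
# For illustration (hedged example data): given a user message whose content
# list holds a text part and an image part, e.g.
#   [{"type": "text", "text": "What is in this image?"},
#    {"type": "image_url", "image_url": {"url": "https://example.com/a.png"}}]
# extract_mm_items() returns only the image_url dict. Messages whose content
# is a plain string are skipped, since a string cannot carry MM items.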


async def fanout_encoder_primer(
    orig_request: dict,
    e_urls: list[str],
    req_id: str,
) -> None:
    """
    1. Build one request *per MM item* with all text removed.
    2. Send them concurrently to the encode cluster.
    3. Raise if any of them fails.
    """
    logger.info("[%s] Processing multimodal items...", req_id)

    mm_items = extract_mm_items(orig_request)
    if not mm_items:
        logger.info("[%s] No multimodal items, skipping encoder", req_id)
        return  # nothing to do

    logger.info("[%s] got %d multimodal items...", req_id, len(mm_items))

    tasks = []

    # Round-robin over encode servers to distribute load a bit
    url_cycle = (e_urls[i % len(e_urls)] for i in range(len(mm_items)))

    for idx, (item, target_url) in enumerate(zip(mm_items, url_cycle)):
        # Derive a *child* request id: <parent>:<index>:<random-short>
        child_req_id = f"{req_id}:{idx}:{uuid.uuid4().hex[:6]}"
        headers = {"x-request-id": child_req_id}

        encoder_req = {
            # You *may* need to keep additional fields
            "model": orig_request.get("model"),
            "messages": [
                {"role": "user", "content": [item]},
            ],
            # Only need 1 token so the server actually runs the encoder path
            "max_tokens": 1,
            "stream": False,
        }
        tasks.append(
            encode_session.post(
                f"{target_url}/v1/chat/completions",
                json=encoder_req,
                headers=headers,
            )
        )

    results = await asyncio.gather(*tasks, return_exceptions=True)

    # Fail fast if any sub-request failed
    for idx, r in enumerate(results):
        if isinstance(r, Exception):
            logger.error(
                "[%s] Encoder request #%d raised exception: %s",
                req_id,
                idx,
                r,
                exc_info=r,
            )
            raise HTTPException(
                status_code=502, detail=f"Encoder request failed: {str(r)}"
            )
        if r.status != 200:
            try:
                detail = await r.text()
            except Exception:
                detail = "<unable to read body>"
            logger.error(
                "[%s] Encoder request #%d returned status %s: %s",
                req_id,
                idx,
                r.status,
                detail,
            )
            raise HTTPException(
                status_code=r.status,
                detail=f"Encoder request failed: {detail}",
            )

    logger.info(
        "[%s] All %d encoder requests completed successfully", req_id, len(mm_items)
    )
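
# For illustration (hedged): a single user turn carrying one text part plus
# two images fans out as two child requests, one holding only the first image
# and one holding only the second; the text part is never forwarded to the
# encode cluster.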


async def maybe_prefill(
    req_data: dict,
    p_url: str,
    req_id: str,
) -> dict:
    """
    - Run the prefill-only stage if `p_url` exists, and return the request
      data augmented with kv transfer params (for the NIXL connector).
    - Otherwise, skip prefill and return the original request data for decode.
    """
    if p_url:
        logger.info("[%s] Processing through prefill: %s", req_id, p_url)

        prefill_response = await process_prefill_stage(req_data, p_url, req_id)
        # for the NIXL connector to facilitate kv transfer...
        prefill_response_json = await prefill_response.json()
        kv_transfer_params = prefill_response_json.get("kv_transfer_params", {})
        if kv_transfer_params:
            req_data["kv_transfer_params"] = kv_transfer_params

    return req_data


async def process_prefill_stage(
    req_data: dict,
    p_url: str,
    req_id: str,
) -> aiohttp.ClientResponse:
    """Process the request through the prefill stage and return the prefill
    response (its JSON body carries the kv_transfer_params)."""
    logger.info("[%s] Sending prefill request to: %s", req_id, p_url)

    prefill_request = req_data.copy()
    prefill_request["kv_transfer_params"] = {
        "do_remote_decode": True,
        "do_remote_prefill": False,
        "remote_engine_id": None,
        "remote_block_ids": None,
        "remote_host": None,
        "remote_port": None,
    }
    prefill_request["stream"] = False
    prefill_request["max_tokens"] = 1
    if "max_completion_tokens" in prefill_request:
        prefill_request["max_completion_tokens"] = 1
    if "stream_options" in prefill_request:
        del prefill_request["stream_options"]

    headers = {"x-request-id": req_id}
    try:
        prefill_response = await prefill_session.post(
            f"{p_url}/v1/chat/completions", json=prefill_request, headers=headers
        )
        prefill_response.raise_for_status()

        if prefill_response.status != 200:
            error_text = await prefill_response.text()
            logger.error(
                "[%s] Prefill request failed with status %d: %s",
                req_id,
                prefill_response.status,
                error_text,
            )
            raise HTTPException(
                status_code=prefill_response.status,
                detail={"error": "Prefill request failed", "message": error_text},
            )
        logger.info("[%s] Prefill request completed successfully", req_id)

        return prefill_response

    except HTTPException:
        # Don't re-wrap HTTP errors raised just above as generic 500s
        raise
    except Exception as e:
        logger.error("Prefill processing failed: %s", str(e))
        raise HTTPException(
            status_code=500,
            detail={"error": "Prefill processing error", "message": str(e)},
        ) from e
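
# A hedged sketch of the kv-transfer round-trip above: the proxy posts a
# 1-token prefill with do_remote_decode=True, and the prefill server is
# expected to reply with populated kv_transfer_params, roughly of the form
#   {"kv_transfer_params": {"do_remote_prefill": true, "remote_host": "...",
#    "remote_port": ..., "remote_engine_id": "...", "remote_block_ids": [...]}}
# maybe_prefill() then copies that object onto the decode request so the
# decode side's NIXL connector can pull the KV cache from the prefill server.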


###############################################################################
# Middleware for request/response logging
###############################################################################


@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Middleware to log all incoming requests and responses"""
    req_id = request.headers.get("x-request-id", str(uuid.uuid4()))

    # Log incoming request
    logger.info(
        ">>> [%s] %s %s from %s",
        req_id,
        request.method,
        request.url.path,
        request.client.host if request.client else "unknown",
    )

    try:
        # Process request
        response = await call_next(request)

        # Log response
        logger.info(
            "<<< [%s] %s %s completed with status %d",
            req_id,
            request.method,
            request.url.path,
            response.status_code,
        )

        return response
    except Exception as e:
        # Log errors
        logger.exception(
            "!!! [%s] %s %s failed with error: %s",
            req_id,
            request.method,
            request.url.path,
            str(e),
        )
        raise


###############################################################################
# FastAPI lifecycle
###############################################################################


@app.on_event("startup")
async def on_startup() -> None:
    global encode_session, prefill_session, decode_session
    timeout = aiohttp.ClientTimeout(total=100_000)
    connector = aiohttp.TCPConnector(limit=0, force_close=False)
    encode_session = aiohttp.ClientSession(timeout=timeout, connector=connector)
    if app.state.p_urls:
        # only set up if prefill instance(s) exist
        prefill_session = aiohttp.ClientSession(timeout=timeout, connector=connector)
    decode_session = aiohttp.ClientSession(timeout=timeout, connector=connector)


@app.on_event("shutdown")
async def on_shutdown() -> None:
    global encode_session, prefill_session, decode_session
    if encode_session:
        await encode_session.close()
    if prefill_session:
        await prefill_session.close()
    if decode_session:
        await decode_session.close()


###############################################################################
# Core forwarding
###############################################################################


async def forward_non_stream(
    req_data: dict, req_id: str, e_urls: list[str], p_url: str, d_url: str
) -> dict:
    try:
        # Step 1: Process through Encoder instance (if the request has MM input)
        await fanout_encoder_primer(req_data, e_urls, req_id)

        # Step 2: Process through Prefill instance
        req_data = await maybe_prefill(req_data, p_url, req_id)

        # Step 3: Process through Decode instance
        logger.info("[%s] Forwarding to decode: %s", req_id, d_url)
        headers = {"x-request-id": req_id}

        # Non-streaming response
        async with decode_session.post(
            f"{d_url}/v1/chat/completions", json=req_data, headers=headers
        ) as resp:
            resp.raise_for_status()
            return await resp.json()

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("[%s] Error in forward_non_stream: %s", req_id, str(e))
        raise HTTPException(status_code=500, detail=f"Proxy error: {str(e)}") from e


async def forward_stream(
    req_data: dict, req_id: str, e_urls: list[str], p_url: str, d_url: str
) -> AsyncIterator[str]:
    try:
        # Step 1: Process through Encoder instance (if the request has MM input)
        await fanout_encoder_primer(req_data, e_urls, req_id)

        # Step 2: Process through Prefill instance
        req_data = await maybe_prefill(req_data, p_url, req_id)

        # Step 3: Process through Decode instance
        logger.info("[%s] Starting streaming from decode: %s", req_id, d_url)
        headers = {"x-request-id": req_id}

        # Streaming response
        async with decode_session.post(
            f"{d_url}/v1/chat/completions",
            json=req_data,
            headers=headers,
        ) as resp:
            resp.raise_for_status()
            async for chunk in resp.content.iter_chunked(1024):
                if chunk:
                    yield chunk.decode("utf-8", errors="ignore")

        logger.info("[%s] Streaming completed", req_id)

    except HTTPException:
        logger.exception("[%s] HTTPException in forward_stream", req_id)
        raise
    except Exception as e:
        logger.exception("[%s] Error in forward_stream: %s", req_id, str(e))
        raise HTTPException(
            status_code=500, detail=f"Proxy streaming error: {str(e)}"
        ) from e


###############################################################################
# Public routes
###############################################################################


@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    try:
        req_data = await request.json()
        req_id = request.headers.get("x-request-id", str(uuid.uuid4()))

        e_urls = app.state.e_urls  # we want the full list for fan-out
        p_url = random.choice(app.state.p_urls) if app.state.p_urls else None
        d_url = random.choice(app.state.d_urls)

        is_streaming = req_data.get("stream", False)

        if is_streaming:
            return StreamingResponse(
                forward_stream(req_data, req_id, e_urls, p_url, d_url),
                media_type="text/event-stream",
            )
        result = await forward_non_stream(req_data, req_id, e_urls, p_url, d_url)
        return JSONResponse(content=result)

    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error in chat_completions endpoint: %s", str(e))
        raise HTTPException(
            status_code=500, detail=f"Request processing error: {str(e)}"
        ) from e
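
# For illustration (hedged): a request to the proxy looks exactly like one to
# a vanilla OpenAI-compatible server; model name and image URL below are
# placeholders:
#
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "my-model",
#          "messages": [{"role": "user", "content": [
#            {"type": "text", "text": "Describe this image"},
#            {"type": "image_url",
#             "image_url": {"url": "https://example.com/a.png"}}]}]}'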


@app.get("/v1/models")
async def list_models():
    async with decode_session.get(f"{app.state.d_urls[0]}/v1/models") as resp:
        resp.raise_for_status()
        return await resp.json()


@app.get("/health")
async def health_check():
    async def healthy(urls):
        if not urls:
            return "empty"
        for u in urls:
            try:
                async with encode_session.get(f"{u}/health") as resp:
                    resp.raise_for_status()
            except Exception:
                return "unhealthy"
        return "healthy"

    e_status, p_status, d_status = await asyncio.gather(
        healthy(app.state.e_urls), healthy(app.state.p_urls), healthy(app.state.d_urls)
    )

    overall_healthy = all(
        status != "unhealthy" for status in (e_status, p_status, d_status)
    )

    status_code = 200 if overall_healthy else 503

    return JSONResponse(
        {
            "proxy": "healthy",
            "encode_cluster": e_status,
            "prefill_cluster": p_status,
            "decode_cluster": d_status,
        },
        status_code=status_code,
    )
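
# For illustration (hedged): with prefill disabled, a fully healthy deployment
# would answer HTTP 200 with a body like
#   {"proxy": "healthy", "encode_cluster": "healthy",
#    "prefill_cluster": "empty", "decode_cluster": "healthy"}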


###############################################################################
# Simple profiler fan-out (unchanged except for sessions)
###############################################################################


async def _post_if_available(
    session: aiohttp.ClientSession,
    url: str,
    payload: dict,
    headers: dict,
) -> dict | None:
    """
    POST `payload` to `url`.

    Returns
    -------
    • The decoded JSON body on success (2xx)
    • None if the endpoint does not exist (404)
    • Raises for anything else.
    """
    try:
        resp = await session.post(url, json=payload, headers=headers)
        if resp.status == 404:  # profiling disabled on that server
            logger.warning("Profiling endpoint missing on %s", url)
            return None
        resp.raise_for_status()
        return await resp.json(content_type=None)
    except aiohttp.ClientResponseError as exc:
        # Pass 404 through the branch above, re-raise everything else
        if exc.status == 404:
            logger.warning("Profiling endpoint missing on %s", url)
            return None
        raise
    except Exception:
        # Network errors etc.: propagate
        raise


async def _profile_cmd(cmd: str, payload: dict, e_url: str, p_url: str, d_url: str):
    """
    Fan the profiling command out to all three clusters, tolerating 404s.
    """
    headers = {"Authorization": f"Bearer {os.getenv('OPENAI_API_KEY', '')}"}

    encode_task = _post_if_available(
        encode_session, f"{e_url}/{cmd}_profile", payload, headers
    )
    prefill_task = (
        _post_if_available(prefill_session, f"{p_url}/{cmd}_profile", payload, headers)
        if p_url is not None
        else asyncio.sleep(0)
    )
    decode_task = _post_if_available(
        decode_session, f"{d_url}/{cmd}_profile", payload, headers
    )

    encode_res, prefill_res, decode_res = await asyncio.gather(
        encode_task, prefill_task, decode_task
    )

    # If *all* clusters said “I don’t have that route”, surface an error
    if encode_res is prefill_res is decode_res is None:
        raise HTTPException(
            status_code=503,
            detail="Profiling endpoints are disabled on all clusters",
        )

    return {
        "encode": encode_res,  # may be None
        "prefill": prefill_res,  # may be None
        "decode": decode_res,  # may be None
    }


@app.post("/start_profile")
async def start_profile(request: Request):
    body = await request.json()
    # TODO: handle multi urls properly
    e_url = random.choice(app.state.e_urls)
    p_url = random.choice(app.state.p_urls) if app.state.p_urls else None
    d_url = random.choice(app.state.d_urls)
    return await _profile_cmd("start", body, e_url, p_url, d_url)


@app.post("/stop_profile")
async def stop_profile(request: Request):
    body = await request.json()
    # TODO: handle multi urls properly
    e_url = random.choice(app.state.e_urls)
    p_url = random.choice(app.state.p_urls) if app.state.p_urls else None
    d_url = random.choice(app.state.d_urls)
    return await _profile_cmd("stop", body, e_url, p_url, d_url)
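
# For illustration (hedged): profiling can be driven end-to-end as
#   curl -X POST http://localhost:8000/start_profile \
#        -H "Content-Type: application/json" -d '{}'
# followed by the mirror call to /stop_profile; servers that answer 404 are
# reported as None in the aggregated response.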


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument(
        "--encode-servers-urls",
        required=True,
        help='Comma-separated encode URLs ("http://e1:8001,http://e2:8001")',
    )
    parser.add_argument(
        "--prefill-servers-urls",
        required=True,
        help=(
            'Comma-separated prefill URLs ("http://p1:8003,http://p2:8004") '
            'to enable E->P->D; set "disable" or "none" to enable E->PD'
        ),
    )
    parser.add_argument(
        "--decode-servers-urls",
        required=True,
        help='Comma-separated decode URLs ("http://d1:8005,http://d2:8006")',
    )

    args = parser.parse_args()
    app.state.e_urls = [
        u.strip() for u in args.encode_servers_urls.split(",") if u.strip()
    ]
    app.state.d_urls = [
        u.strip() for u in args.decode_servers_urls.split(",") if u.strip()
    ]
    # handle prefill instances
    if args.prefill_servers_urls.lower() in ("disable", "none", ""):
        app.state.p_urls = []
        logger.info(
            "Disaggregated prefill phase explicitly disabled by user. Running E + PD..."
        )
    else:
        app.state.p_urls = [
            u.strip() for u in args.prefill_servers_urls.split(",") if u.strip()
        ]
        logger.info("Disaggregated prefill phase is enabled. Running E + P + D...")

    logger.info("Proxy listening on %s:%s", args.host, args.port)
    logger.info("Encode servers: %s", app.state.e_urls)
    logger.info("Prefill instances: %s", app.state.p_urls)
    logger.info("Decode servers: %s", app.state.d_urls)

    uvicorn.run(
        app,
        host=args.host,
        port=args.port,
        log_level="info",
        loop="uvloop",
        access_log=True,
    )