"""
Minimal HTTP load balancer for prefill and decode servers for testing.
"""
|
|
|
|
|
|
|
|
|
|
import asyncio
|
2025-05-07 01:12:57 +08:00
|
|
|
import dataclasses
|
|
|
|
|
import logging
|
2025-03-21 14:47:47 -07:00
|
|
|
import random
|
|
|
|
|
import urllib
|
|
|
|
|
from itertools import chain
|
2025-05-07 01:12:57 +08:00
|
|
|
from typing import List, Optional
|
2025-03-21 14:47:47 -07:00
|
|
|
|
|
|
|
|
import aiohttp
|
|
|
|
|
import orjson
|
|
|
|
|
import uvicorn
|
|
|
|
|
from fastapi import FastAPI, HTTPException
|
|
|
|
|
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
|
|
|
|
|
2025-05-07 01:12:57 +08:00
|
|
|
from sglang.srt.disaggregation.utils import PDRegistryRequest
|
2025-03-21 14:47:47 -07:00
|
|
|
|
2025-05-07 01:12:57 +08:00
|
|
|
|
|
|
|
|
def setup_logger():
|
|
|
|
|
logger = logging.getLogger("pdlb")
|
|
|
|
|
logger.setLevel(logging.INFO)
|
|
|
|
|
|
|
|
|
|
formatter = logging.Formatter(
|
|
|
|
|
"[PDLB (Python)] %(asctime)s - %(levelname)s - %(message)s",
|
|
|
|
|
datefmt="%Y-%m-%d %H:%M:%S",
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
handler = logging.StreamHandler()
|
|
|
|
|
handler.setFormatter(formatter)
|
|
|
|
|
logger.addHandler(handler)
|
|
|
|
|
|
|
|
|
|
return logger
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logger = setup_logger()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclasses.dataclass
|
2025-04-26 00:30:47 +08:00
|
|
|
class PrefillConfig:
|
2025-05-07 01:12:57 +08:00
|
|
|
url: str
|
|
|
|
|
bootstrap_port: Optional[int] = None
|
2025-04-26 00:30:47 +08:00
|
|
|
|
|
|
|
|
|
2025-03-21 14:47:47 -07:00
|
|
|
class MiniLoadBalancer:
|
2025-04-26 00:30:47 +08:00
|
|
|
def __init__(self, prefill_configs: List[PrefillConfig], decode_servers: List[str]):
|
|
|
|
|
self.prefill_configs = prefill_configs
|
|
|
|
|
self.prefill_servers = [p.url for p in prefill_configs]
|
2025-03-21 14:47:47 -07:00
|
|
|
self.decode_servers = decode_servers
|
|
|
|
|
|
2025-05-26 10:38:41 +08:00
|
|
|
def add_prefill_server(self, new_prefill_config: PrefillConfig):
|
|
|
|
|
self.prefill_configs.append(new_prefill_config)
|
|
|
|
|
self.prefill_servers.append(new_prefill_config.url)
|
|
|
|
|
|
|
|
|
|
def add_decode_server(self, new_decode_server: str):
|
|
|
|
|
self.decode_servers.append(new_decode_server)
|
|
|
|
|
|
2025-03-21 14:47:47 -07:00
|
|
|
def select_pair(self):
|
2025-05-07 01:12:57 +08:00
|
|
|
# TODO: return some message instead of panic
|
|
|
|
|
assert len(self.prefill_configs) > 0, "No prefill servers available"
|
|
|
|
|
assert len(self.decode_servers) > 0, "No decode servers available"
|
|
|
|
|
|
2025-04-26 00:30:47 +08:00
|
|
|
prefill_config = random.choice(self.prefill_configs)
|
|
|
|
|
decode_server = random.choice(self.decode_servers)
|
|
|
|
|
return prefill_config.url, prefill_config.bootstrap_port, decode_server
|
2025-03-21 14:47:47 -07:00
|
|
|
|
2025-04-08 09:42:34 -07:00
|
|
|
async def generate(
|
2025-04-21 16:10:58 +08:00
|
|
|
self, modified_request, prefill_server, decode_server, endpoint
|
2025-04-08 09:42:34 -07:00
|
|
|
) -> ORJSONResponse:
|
2025-04-21 16:10:58 +08:00
|
|
|
assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"
|
2025-03-21 14:47:47 -07:00
|
|
|
|
2025-04-19 14:42:57 +08:00
|
|
|
async with aiohttp.ClientSession(
|
|
|
|
|
timeout=aiohttp.ClientTimeout(
|
|
|
|
|
total=3600
|
|
|
|
|
) # Add timeout for request reliability
|
|
|
|
|
) as session:
|
2025-03-21 14:47:47 -07:00
|
|
|
tasks = [
|
2025-04-21 16:10:58 +08:00
|
|
|
session.post(f"{prefill_server}/{endpoint}", json=modified_request),
|
|
|
|
|
session.post(f"{decode_server}/{endpoint}", json=modified_request),
|
2025-03-21 14:47:47 -07:00
|
|
|
]
|
2025-05-23 14:29:20 -07:00
|
|
|
|
2025-04-08 09:42:34 -07:00
|
|
|
# Wait for both responses to complete. Prefill should end first.
|
2025-05-23 14:29:20 -07:00
|
|
|
prefill_response, decode_response = await asyncio.gather(*tasks)
|
|
|
|
|
|
|
|
|
|
if "return_logprob" in modified_request:
|
|
|
|
|
|
|
|
|
|
prefill_json = await prefill_response.json()
|
|
|
|
|
ret_json = await decode_response.json()
|
|
|
|
|
|
|
|
|
|
# merge `meta_info.input_token_logprobs` from prefill to decode
|
|
|
|
|
if "meta_info" in ret_json:
|
|
|
|
|
if "input_token_logprobs" in ret_json["meta_info"]:
|
|
|
|
|
ret_json["meta_info"]["input_token_logprobs"] = (
|
|
|
|
|
prefill_json["meta_info"]["input_token_logprobs"]
|
|
|
|
|
+ ret_json["meta_info"]["input_token_logprobs"]
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
ret_json = await decode_response.json()
|
2025-03-21 14:47:47 -07:00
|
|
|
|
2025-04-08 09:42:34 -07:00
|
|
|
return ORJSONResponse(
|
2025-05-23 14:29:20 -07:00
|
|
|
content=ret_json,
|
2025-04-08 09:42:34 -07:00
|
|
|
status_code=decode_response.status,
|
|
|
|
|
)
|
2025-03-21 14:47:47 -07:00
|
|
|
|
2025-04-21 16:10:58 +08:00
|
|
|
async def generate_stream(
|
|
|
|
|
self, modified_request, prefill_server, decode_server, endpoint="generate"
|
|
|
|
|
):
|
|
|
|
|
assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"
|
|
|
|
|
|
2025-04-08 09:42:34 -07:00
|
|
|
async def stream_results():
|
|
|
|
|
async with aiohttp.ClientSession(
|
|
|
|
|
timeout=aiohttp.ClientTimeout(
|
|
|
|
|
total=3600
|
|
|
|
|
) # Add timeout for request reliability
|
|
|
|
|
) as session:
|
2025-05-23 14:29:20 -07:00
|
|
|
# Create the tasks for both prefill and decode requests
|
|
|
|
|
tasks = [
|
2025-05-28 21:33:36 +08:00
|
|
|
session.post(f"{prefill_server}/{endpoint}", json=modified_request),
|
|
|
|
|
session.post(f"{decode_server}/{endpoint}", json=modified_request),
|
2025-05-23 14:29:20 -07:00
|
|
|
]
|
|
|
|
|
# Wait for both responses to complete. Since this is streaming, they return immediately.
|
|
|
|
|
prefill_response, decode_response = await asyncio.gather(*tasks)
|
|
|
|
|
|
|
|
|
|
if modified_request.get("return_logprob", False):
|
|
|
|
|
prefill_chunks = []
|
|
|
|
|
async for chunk in prefill_response.content:
|
|
|
|
|
prefill_chunks.append(chunk)
|
|
|
|
|
|
|
|
|
|
first_prefill_chunk = (
|
|
|
|
|
prefill_chunks[0].decode("utf-8")[5:].strip("\n")
|
|
|
|
|
)
|
|
|
|
|
first_prefill_chunk_json = orjson.loads(first_prefill_chunk)
|
|
|
|
|
|
|
|
|
|
async for chunk in decode_response.content:
|
|
|
|
|
# Note: This is inefficient
|
|
|
|
|
# merge prefill input_token_logprobs, output_token_logprobs to decode
|
|
|
|
|
decoded_chunk = chunk.decode("utf-8")
|
|
|
|
|
if (
|
|
|
|
|
decoded_chunk
|
|
|
|
|
and decoded_chunk.startswith("data:")
|
|
|
|
|
and "[DONE]" not in decoded_chunk
|
|
|
|
|
):
|
|
|
|
|
ret_json = orjson.loads(decoded_chunk[5:].strip("\n"))
|
|
|
|
|
ret_json["meta_info"]["input_token_logprobs"] = (
|
|
|
|
|
first_prefill_chunk_json["meta_info"][
|
|
|
|
|
"input_token_logprobs"
|
|
|
|
|
]
|
|
|
|
|
+ ret_json["meta_info"]["input_token_logprobs"]
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
yield b"data: " + orjson.dumps(ret_json) + b"\n\n"
|
|
|
|
|
else:
|
|
|
|
|
yield chunk
|
|
|
|
|
else:
|
2025-04-08 09:42:34 -07:00
|
|
|
async for chunk in decode_response.content:
|
|
|
|
|
yield chunk
|
2025-03-21 14:47:47 -07:00
|
|
|
|
2025-04-08 09:42:34 -07:00
|
|
|
return StreamingResponse(
|
|
|
|
|
stream_results(),
|
|
|
|
|
media_type="text/event-stream",
|
|
|
|
|
)
|
2025-03-21 14:47:47 -07:00
|
|
|
|
|
|
|
|
|
|
|
|
|
app = FastAPI()
|
2025-05-26 10:38:41 +08:00
|
|
|
load_balancer: Optional[MiniLoadBalancer] = None
|
2025-03-21 14:47:47 -07:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/health")
|
|
|
|
|
async def health_check():
|
|
|
|
|
return Response(status_code=200)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/health_generate")
|
|
|
|
|
async def health_check():
|
|
|
|
|
prefill_servers, decode_servers = (
|
|
|
|
|
load_balancer.prefill_servers,
|
|
|
|
|
load_balancer.decode_servers,
|
|
|
|
|
)
|
|
|
|
|
async with aiohttp.ClientSession() as session:
|
|
|
|
|
# Create the tasks
|
|
|
|
|
tasks = []
|
|
|
|
|
for server in chain(prefill_servers, decode_servers):
|
|
|
|
|
tasks.append(session.post(f"{server}/health_generate"))
|
|
|
|
|
for i, response in enumerate(asyncio.as_completed(tasks)):
|
|
|
|
|
await response
|
|
|
|
|
return Response(status_code=200)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/flush_cache")
|
|
|
|
|
async def flush_cache():
|
|
|
|
|
prefill_servers, decode_servers = (
|
|
|
|
|
load_balancer.prefill_servers,
|
|
|
|
|
load_balancer.decode_servers,
|
|
|
|
|
)
|
|
|
|
|
async with aiohttp.ClientSession() as session:
|
|
|
|
|
# Create the tasks
|
|
|
|
|
tasks = []
|
|
|
|
|
for server in chain(prefill_servers, decode_servers):
|
|
|
|
|
tasks.append(session.post(f"{server}/flush_cache"))
|
|
|
|
|
for i, response in enumerate(asyncio.as_completed(tasks)):
|
|
|
|
|
await response
|
|
|
|
|
return Response(status_code=200)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/get_server_info")
|
|
|
|
|
async def get_server_info():
|
|
|
|
|
prefill_servers, decode_servers = (
|
|
|
|
|
load_balancer.prefill_servers,
|
|
|
|
|
load_balancer.decode_servers,
|
|
|
|
|
)
|
|
|
|
|
prefill_infos = []
|
|
|
|
|
decode_infos = []
|
|
|
|
|
async with aiohttp.ClientSession() as session:
|
|
|
|
|
for server in chain(prefill_servers):
|
|
|
|
|
server_info = await session.get(f"{server}/get_server_info")
|
|
|
|
|
prefill_infos.append(await server_info.json())
|
|
|
|
|
for server in chain(decode_servers):
|
|
|
|
|
server_info = await session.get(f"{server}/get_server_info")
|
|
|
|
|
decode_infos.append(await server_info.json())
|
|
|
|
|
|
|
|
|
|
return {"prefill": prefill_infos, "decode": decode_infos}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/get_model_info")
|
|
|
|
|
async def get_model_info():
|
|
|
|
|
# Dummy model information
|
|
|
|
|
model_info = {
|
|
|
|
|
"model_path": "/path/to/dummy/model",
|
|
|
|
|
"tokenizer_path": "/path/to/dummy/tokenizer",
|
|
|
|
|
"is_generation": True,
|
|
|
|
|
"preferred_sampling_params": {"temperature": 0.7, "max_new_tokens": 128},
|
|
|
|
|
}
|
|
|
|
|
return ORJSONResponse(content=model_info)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/generate")
|
|
|
|
|
async def handle_generate_request(request_data: dict):
|
2025-04-26 00:30:47 +08:00
|
|
|
prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()
|
2025-03-21 14:47:47 -07:00
|
|
|
|
|
|
|
|
# Parse and transform prefill_server for bootstrap data
|
|
|
|
|
parsed_url = urllib.parse.urlparse(prefill_server)
|
|
|
|
|
hostname = parsed_url.hostname
|
|
|
|
|
modified_request = request_data.copy()
|
2025-04-21 07:02:23 +08:00
|
|
|
|
|
|
|
|
batch_size = _get_request_batch_size(modified_request)
|
|
|
|
|
if batch_size is not None:
|
|
|
|
|
modified_request.update(
|
|
|
|
|
{
|
|
|
|
|
"bootstrap_host": [hostname] * batch_size,
|
2025-04-26 00:30:47 +08:00
|
|
|
"bootstrap_port": [bootstrap_port] * batch_size,
|
2025-04-21 07:02:23 +08:00
|
|
|
"bootstrap_room": [
|
|
|
|
|
_generate_bootstrap_room() for _ in range(batch_size)
|
|
|
|
|
],
|
|
|
|
|
}
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
modified_request.update(
|
|
|
|
|
{
|
|
|
|
|
"bootstrap_host": hostname,
|
2025-04-26 00:30:47 +08:00
|
|
|
"bootstrap_port": bootstrap_port,
|
2025-04-21 07:02:23 +08:00
|
|
|
"bootstrap_room": _generate_bootstrap_room(),
|
|
|
|
|
}
|
|
|
|
|
)
|
2025-03-21 14:47:47 -07:00
|
|
|
|
|
|
|
|
if request_data.get("stream", False):
|
2025-04-08 09:42:34 -07:00
|
|
|
return await load_balancer.generate_stream(
|
2025-04-21 21:39:18 +08:00
|
|
|
modified_request, prefill_server, decode_server, "generate"
|
2025-04-08 09:42:34 -07:00
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
return await load_balancer.generate(
|
2025-04-21 21:39:18 +08:00
|
|
|
modified_request, prefill_server, decode_server, "generate"
|
2025-03-21 14:47:47 -07:00
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
2025-05-29 16:26:18 +08:00
|
|
|
async def _forward_to_backend(request_data: dict, endpoint_name: str):
|
2025-04-26 00:30:47 +08:00
|
|
|
prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()
|
2025-04-21 16:10:58 +08:00
|
|
|
|
|
|
|
|
# Parse and transform prefill_server for bootstrap data
|
|
|
|
|
parsed_url = urllib.parse.urlparse(prefill_server)
|
|
|
|
|
hostname = parsed_url.hostname
|
|
|
|
|
modified_request = request_data.copy()
|
|
|
|
|
modified_request.update(
|
|
|
|
|
{
|
|
|
|
|
"bootstrap_host": hostname,
|
2025-04-26 00:30:47 +08:00
|
|
|
"bootstrap_port": bootstrap_port,
|
2025-05-29 16:26:18 +08:00
|
|
|
"bootstrap_room": _generate_bootstrap_room(),
|
2025-04-21 16:10:58 +08:00
|
|
|
}
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if request_data.get("stream", False):
|
|
|
|
|
return await load_balancer.generate_stream(
|
|
|
|
|
modified_request,
|
|
|
|
|
prefill_server,
|
|
|
|
|
decode_server,
|
2025-05-29 16:26:18 +08:00
|
|
|
endpoint=endpoint_name,
|
2025-04-21 16:10:58 +08:00
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
return await load_balancer.generate(
|
|
|
|
|
modified_request,
|
|
|
|
|
prefill_server,
|
|
|
|
|
decode_server,
|
2025-05-29 16:26:18 +08:00
|
|
|
endpoint=endpoint_name,
|
2025-04-21 16:10:58 +08:00
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
2025-05-29 16:26:18 +08:00
|
|
|
@app.post("/v1/chat/completions")
|
|
|
|
|
async def handle_chat_completion_request(request_data: dict):
|
|
|
|
|
return await _forward_to_backend(request_data, "v1/chat/completions")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/v1/completions")
|
|
|
|
|
async def handle_completion_request(request_data: dict):
|
|
|
|
|
return await _forward_to_backend(request_data, "v1/completions")
|
|
|
|
|
|
|
|
|
|
|
2025-04-21 07:02:23 +08:00
|
|
|
def _generate_bootstrap_room():
|
|
|
|
|
return random.randint(0, 2**63 - 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# We may utilize `GenerateReqInput`'s logic later
|
|
|
|
|
def _get_request_batch_size(request):
|
|
|
|
|
if (text := request.get("text")) is not None:
|
|
|
|
|
return None if isinstance(text, str) else len(text)
|
|
|
|
|
if (input_ids := request.get("input_ids")) is not None:
|
|
|
|
|
return None if isinstance(input_ids[0], int) else len(input_ids)
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
2025-03-21 14:47:47 -07:00
|
|
|
@app.get("/v1/models")
|
|
|
|
|
async def get_models():
|
|
|
|
|
prefill_server = load_balancer.prefill_servers[0] # Get the first prefill server
|
|
|
|
|
async with aiohttp.ClientSession() as session:
|
|
|
|
|
try:
|
|
|
|
|
response = await session.get(f"{prefill_server}/v1/models")
|
|
|
|
|
if response.status != 200:
|
|
|
|
|
raise HTTPException(
|
|
|
|
|
status_code=response.status,
|
|
|
|
|
detail=f"Prefill server error: Status {response.status}",
|
|
|
|
|
)
|
|
|
|
|
return ORJSONResponse(content=await response.json())
|
|
|
|
|
except Exception as e:
|
|
|
|
|
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
|
|
|
|
|
2025-05-07 01:12:57 +08:00
|
|
|
@app.post("/register")
|
|
|
|
|
async def register(obj: PDRegistryRequest):
|
|
|
|
|
if obj.mode == "prefill":
|
2025-05-26 10:38:41 +08:00
|
|
|
load_balancer.add_prefill_server(
|
2025-05-07 01:12:57 +08:00
|
|
|
PrefillConfig(obj.registry_url, obj.bootstrap_port)
|
|
|
|
|
)
|
|
|
|
|
logger.info(
|
|
|
|
|
f"Registered prefill server: {obj.registry_url} with bootstrap port: {obj.bootstrap_port}"
|
|
|
|
|
)
|
|
|
|
|
elif obj.mode == "decode":
|
2025-05-26 10:38:41 +08:00
|
|
|
load_balancer.add_decode_server(obj.registry_url)
|
2025-05-07 01:12:57 +08:00
|
|
|
logger.info(f"Registered decode server: {obj.registry_url}")
|
|
|
|
|
else:
|
|
|
|
|
raise HTTPException(
|
|
|
|
|
status_code=400,
|
|
|
|
|
detail="Invalid mode. Must be either PREFILL or DECODE.",
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
logger.info(
|
|
|
|
|
f"#Prefill servers: {len(load_balancer.prefill_configs)}, "
|
|
|
|
|
f"#Decode servers: {len(load_balancer.decode_servers)}"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return Response(status_code=200)
|
|
|
|
|
|
|
|
|
|
|
2025-04-26 00:30:47 +08:00
|
|
|
def run(prefill_configs, decode_addrs, host, port):
|
2025-03-21 14:47:47 -07:00
|
|
|
global load_balancer
|
2025-04-26 00:30:47 +08:00
|
|
|
load_balancer = MiniLoadBalancer(prefill_configs, decode_addrs)
|
2025-03-21 14:47:47 -07:00
|
|
|
uvicorn.run(app, host=host, port=port)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
import argparse
|
|
|
|
|
|
|
|
|
|
parser = argparse.ArgumentParser(description="Mini Load Balancer Server")
|
|
|
|
|
parser.add_argument(
|
2025-05-07 01:12:57 +08:00
|
|
|
"--prefill", type=str, default=[], nargs="+", help="URLs for prefill servers"
|
2025-03-21 14:47:47 -07:00
|
|
|
)
|
2025-04-26 00:30:47 +08:00
|
|
|
parser.add_argument(
|
2025-05-07 01:12:57 +08:00
|
|
|
"--decode", type=str, default=[], nargs="+", help="URLs for decode servers"
|
2025-04-26 00:30:47 +08:00
|
|
|
)
|
2025-03-21 14:47:47 -07:00
|
|
|
parser.add_argument(
|
2025-05-07 01:12:57 +08:00
|
|
|
"--prefill-bootstrap-ports",
|
|
|
|
|
type=int,
|
|
|
|
|
nargs="+",
|
|
|
|
|
help="Bootstrap ports for prefill servers",
|
2025-03-21 14:47:47 -07:00
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--host", default="0.0.0.0", help="Host to bind the server (default: 0.0.0.0)"
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--port", type=int, default=8000, help="Port to bind the server (default: 8000)"
|
|
|
|
|
)
|
|
|
|
|
args = parser.parse_args()
|
2025-04-26 00:30:47 +08:00
|
|
|
|
2025-05-07 01:12:57 +08:00
|
|
|
bootstrap_ports = args.prefill_bootstrap_ports
|
|
|
|
|
if bootstrap_ports is None:
|
|
|
|
|
bootstrap_ports = [None] * len(args.prefill)
|
|
|
|
|
elif len(bootstrap_ports) == 1:
|
|
|
|
|
bootstrap_ports = bootstrap_ports * len(args.prefill)
|
2025-04-26 00:30:47 +08:00
|
|
|
else:
|
2025-05-07 01:12:57 +08:00
|
|
|
if len(bootstrap_ports) != len(args.prefill):
|
2025-04-26 00:30:47 +08:00
|
|
|
raise ValueError(
|
|
|
|
|
"Number of prefill URLs must match number of bootstrap ports"
|
|
|
|
|
)
|
|
|
|
|
|
2025-05-07 01:12:57 +08:00
|
|
|
prefill_configs = [
|
|
|
|
|
PrefillConfig(url, port) for url, port in zip(args.prefill, bootstrap_ports)
|
|
|
|
|
]
|
2025-04-26 00:30:47 +08:00
|
|
|
|
2025-05-07 01:12:57 +08:00
|
|
|
run(prefill_configs, args.decode, args.host, args.port)
|