[PD] Fix prefill_servers in mini_lb (#6527)

This commit is contained in:
wangxiyu191
2025-05-26 10:38:41 +08:00
committed by GitHub
parent e235be16fe
commit 8b33d8df90

View File

@@ -50,6 +50,13 @@ class MiniLoadBalancer:
self.prefill_servers = [p.url for p in prefill_configs]
self.decode_servers = decode_servers
def add_prefill_server(self, new_prefill_config: PrefillConfig):
    """Register an additional prefill server with the load balancer.

    The config object is appended to ``self.prefill_configs`` and its
    ``url`` is appended to ``self.prefill_servers``, so the URL list stays
    aligned with the config list.

    Args:
        new_prefill_config: Config of the prefill server to add; its
            ``url`` attribute is read.
    """
    # Record the config first, then its URL, so both lists grow in the
    # same order.
    self.prefill_configs.append(new_prefill_config)
    server_url = new_prefill_config.url
    self.prefill_servers.append(server_url)
def add_decode_server(self, new_decode_server: str):
    """Register an additional decode server with the load balancer.

    Args:
        new_decode_server: Decode server entry to append to
            ``self.decode_servers``.
    """
    decode_pool = self.decode_servers
    decode_pool.append(new_decode_server)
def select_pair(self):
# TODO: return some message instead of panic
assert len(self.prefill_configs) > 0, "No prefill servers available"
@@ -157,7 +164,7 @@ class MiniLoadBalancer:
# FastAPI application object; the route decorators in this module
# (e.g. @app.get / @app.post) attach their handlers to it.
app = FastAPI()
# Module-level handle to the load balancer used by the route handlers;
# it is None at import time.
# NOTE(review): both an untyped and an Optional-annotated assignment appear
# here — presumably the old and new side of one replaced line in this diff
# excerpt; confirm only the annotated form remains in the real file.
load_balancer = None
load_balancer: Optional[MiniLoadBalancer] = None
@app.get("/health")
@@ -331,14 +338,14 @@ async def get_models():
@app.post("/register")
async def register(obj: PDRegistryRequest):
if obj.mode == "prefill":
load_balancer.prefill_configs.append(
load_balancer.add_prefill_server(
PrefillConfig(obj.registry_url, obj.bootstrap_port)
)
logger.info(
f"Registered prefill server: {obj.registry_url} with bootstrap port: {obj.bootstrap_port}"
)
elif obj.mode == "decode":
load_balancer.decode_servers.append(obj.registry_url)
load_balancer.add_decode_server(obj.registry_url)
logger.info(f"Registered decode server: {obj.registry_url}")
else:
raise HTTPException(