[PD] Fix prefill_servers in mini_lb (#6527)
This commit is contained in:
@@ -50,6 +50,13 @@ class MiniLoadBalancer:
|
|||||||
self.prefill_servers = [p.url for p in prefill_configs]
|
self.prefill_servers = [p.url for p in prefill_configs]
|
||||||
self.decode_servers = decode_servers
|
self.decode_servers = decode_servers
|
||||||
|
|
||||||
|
def add_prefill_server(self, new_prefill_config: PrefillConfig):
|
||||||
|
self.prefill_configs.append(new_prefill_config)
|
||||||
|
self.prefill_servers.append(new_prefill_config.url)
|
||||||
|
|
||||||
|
def add_decode_server(self, new_decode_server: str):
|
||||||
|
self.decode_servers.append(new_decode_server)
|
||||||
|
|
||||||
def select_pair(self):
|
def select_pair(self):
|
||||||
# TODO: return some message instead of panic
|
# TODO: return some message instead of panic
|
||||||
assert len(self.prefill_configs) > 0, "No prefill servers available"
|
assert len(self.prefill_configs) > 0, "No prefill servers available"
|
||||||
@@ -157,7 +164,7 @@ class MiniLoadBalancer:
|
|||||||
|
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
load_balancer = None
|
load_balancer: Optional[MiniLoadBalancer] = None
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
@@ -331,14 +338,14 @@ async def get_models():
|
|||||||
@app.post("/register")
|
@app.post("/register")
|
||||||
async def register(obj: PDRegistryRequest):
|
async def register(obj: PDRegistryRequest):
|
||||||
if obj.mode == "prefill":
|
if obj.mode == "prefill":
|
||||||
load_balancer.prefill_configs.append(
|
load_balancer.add_prefill_server(
|
||||||
PrefillConfig(obj.registry_url, obj.bootstrap_port)
|
PrefillConfig(obj.registry_url, obj.bootstrap_port)
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Registered prefill server: {obj.registry_url} with bootstrap port: {obj.bootstrap_port}"
|
f"Registered prefill server: {obj.registry_url} with bootstrap port: {obj.bootstrap_port}"
|
||||||
)
|
)
|
||||||
elif obj.mode == "decode":
|
elif obj.mode == "decode":
|
||||||
load_balancer.decode_servers.append(obj.registry_url)
|
load_balancer.add_decode_server(obj.registry_url)
|
||||||
logger.info(f"Registered decode server: {obj.registry_url}")
|
logger.info(f"Registered decode server: {obj.registry_url}")
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|||||||
Reference in New Issue
Block a user