diff --git a/python/sglang/srt/disaggregation/mini_lb.py b/python/sglang/srt/disaggregation/mini_lb.py
index f6de3a884..2624e5939 100644
--- a/python/sglang/srt/disaggregation/mini_lb.py
+++ b/python/sglang/srt/disaggregation/mini_lb.py
@@ -50,6 +50,13 @@ class MiniLoadBalancer:
         self.prefill_servers = [p.url for p in prefill_configs]
         self.decode_servers = decode_servers
 
+    def add_prefill_server(self, new_prefill_config: PrefillConfig):
+        self.prefill_configs.append(new_prefill_config)
+        self.prefill_servers.append(new_prefill_config.url)
+
+    def add_decode_server(self, new_decode_server: str):
+        self.decode_servers.append(new_decode_server)
+
     def select_pair(self):
         # TODO: return some message instead of panic
         assert len(self.prefill_configs) > 0, "No prefill servers available"
@@ -157,7 +164,7 @@ class MiniLoadBalancer:
 
 
 app = FastAPI()
-load_balancer = None
+load_balancer: Optional[MiniLoadBalancer] = None
 
 
 @app.get("/health")
@@ -331,14 +338,14 @@ async def get_models():
 @app.post("/register")
 async def register(obj: PDRegistryRequest):
     if obj.mode == "prefill":
-        load_balancer.prefill_configs.append(
+        load_balancer.add_prefill_server(
             PrefillConfig(obj.registry_url, obj.bootstrap_port)
         )
         logger.info(
             f"Registered prefill server: {obj.registry_url} with bootstrap port: {obj.bootstrap_port}"
         )
     elif obj.mode == "decode":
-        load_balancer.decode_servers.append(obj.registry_url)
+        load_balancer.add_decode_server(obj.registry_url)
         logger.info(f"Registered decode server: {obj.registry_url}")
     else:
         raise HTTPException(