Fix mini lb timeout issue (#9369)
This commit is contained in:
@@ -118,7 +118,13 @@ def main():
|
|||||||
lb_args = LBArgs.from_cli_args(args)
|
lb_args = LBArgs.from_cli_args(args)
|
||||||
|
|
||||||
prefill_configs = [PrefillConfig(url, port) for url, port in lb_args.prefill_infos]
|
prefill_configs = [PrefillConfig(url, port) for url, port in lb_args.prefill_infos]
|
||||||
run(prefill_configs, lb_args.decode_infos, lb_args.host, lb_args.port)
|
run(
|
||||||
|
prefill_configs,
|
||||||
|
lb_args.decode_infos,
|
||||||
|
lb_args.host,
|
||||||
|
lb_args.port,
|
||||||
|
lb_args.timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -50,10 +50,16 @@ class PrefillConfig:
|
|||||||
|
|
||||||
|
|
||||||
class MiniLoadBalancer:
|
class MiniLoadBalancer:
|
||||||
def __init__(self, prefill_configs: List[PrefillConfig], decode_servers: List[str]):
|
def __init__(
|
||||||
|
self,
|
||||||
|
prefill_configs: List[PrefillConfig],
|
||||||
|
decode_servers: List[str],
|
||||||
|
timeout: int,
|
||||||
|
):
|
||||||
self.prefill_configs = prefill_configs
|
self.prefill_configs = prefill_configs
|
||||||
self.prefill_servers = [p.url for p in prefill_configs]
|
self.prefill_servers = [p.url for p in prefill_configs]
|
||||||
self.decode_servers = decode_servers
|
self.decode_servers = decode_servers
|
||||||
|
self.timeout = timeout
|
||||||
|
|
||||||
def add_prefill_server(self, new_prefill_config: PrefillConfig):
|
def add_prefill_server(self, new_prefill_config: PrefillConfig):
|
||||||
self.prefill_configs.append(new_prefill_config)
|
self.prefill_configs.append(new_prefill_config)
|
||||||
@@ -78,7 +84,7 @@ class MiniLoadBalancer:
|
|||||||
|
|
||||||
async with aiohttp.ClientSession(
|
async with aiohttp.ClientSession(
|
||||||
timeout=aiohttp.ClientTimeout(
|
timeout=aiohttp.ClientTimeout(
|
||||||
total=3600
|
total=self.timeout
|
||||||
) # Add timeout for request reliability
|
) # Add timeout for request reliability
|
||||||
) as session:
|
) as session:
|
||||||
tasks = [
|
tasks = [
|
||||||
@@ -117,7 +123,7 @@ class MiniLoadBalancer:
|
|||||||
async def stream_results():
|
async def stream_results():
|
||||||
async with aiohttp.ClientSession(
|
async with aiohttp.ClientSession(
|
||||||
timeout=aiohttp.ClientTimeout(
|
timeout=aiohttp.ClientTimeout(
|
||||||
total=3600
|
total=self.timeout
|
||||||
) # Add timeout for request reliability
|
) # Add timeout for request reliability
|
||||||
) as session:
|
) as session:
|
||||||
# Create the tasks for both prefill and decode requests
|
# Create the tasks for both prefill and decode requests
|
||||||
@@ -401,9 +407,9 @@ async def register(obj: PDRegistryRequest):
|
|||||||
return Response(status_code=200)
|
return Response(status_code=200)
|
||||||
|
|
||||||
|
|
||||||
def run(prefill_configs, decode_addrs, host, port):
|
def run(prefill_configs, decode_addrs, host, port, timeout):
|
||||||
global load_balancer
|
global load_balancer
|
||||||
load_balancer = MiniLoadBalancer(prefill_configs, decode_addrs)
|
load_balancer = MiniLoadBalancer(prefill_configs, decode_addrs, timeout=timeout)
|
||||||
uvicorn.run(app, host=host, port=port)
|
uvicorn.run(app, host=host, port=port)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user