@app.get("/get_model_info")
async def get_model_info():
    """Proxy ``/get_model_info`` to the first registered prefill server.

    Returns:
        ORJSONResponse: the JSON body returned by the backend, passed
        through unchanged.

    Raises:
        HTTPException: 503 if no load balancer or no prefill server is
            registered, or if the request to the backend fails with an
            ``aiohttp.ClientError``; 502 if the backend answers with a
            non-200 status.
    """
    # Local import keeps this handler self-contained.
    from http import HTTPStatus

    # Only read access is needed, so no ``global`` statement is required.
    if not load_balancer or not load_balancer.prefill_servers:
        raise HTTPException(
            status_code=HTTPStatus.SERVICE_UNAVAILABLE,
            detail="There is no server registered",
        )

    # The first prefill server is the one queried for model info.
    target_server_url = load_balancer.prefill_servers[0]
    endpoint_url = f"{target_server_url}/get_model_info"

    async with aiohttp.ClientSession() as session:
        try:
            async with session.get(endpoint_url) as response:
                if response.status != HTTPStatus.OK:
                    error_text = await response.text()
                    # The HTTPException raised here is not an
                    # aiohttp.ClientError, so it propagates past the
                    # ``except`` clause below untouched.
                    raise HTTPException(
                        status_code=HTTPStatus.BAD_GATEWAY,
                        detail=(
                            f"Failed to get model info from {target_server_url}. "
                            f"Status: {response.status}, Response: {error_text}"
                        ),
                    )

                model_info_json = await response.json()
                return ORJSONResponse(content=model_info_json)

        except aiohttp.ClientError as e:
            # Chain the original error so the root cause stays visible in
            # server-side tracebacks.
            raise HTTPException(
                status_code=HTTPStatus.SERVICE_UNAVAILABLE,
                detail="Failed to get model info from backend",
            ) from e