[PD] Add simple unit test for disaggregation feature (#5654)

Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
This commit is contained in:
shangmingc
2025-05-11 13:35:27 +08:00
committed by GitHub
parent a823c6e834
commit 31d1f6e7f4
5 changed files with 241 additions and 0 deletions

View File

@@ -478,6 +478,81 @@ def popen_launch_server(
raise TimeoutError("Server failed to start within the timeout period.")
def popen_launch_pd_server(
model: str,
base_url: str,
timeout: float,
api_key: Optional[str] = None,
other_args: list[str] = (),
env: Optional[dict] = None,
return_stdout_stderr: Optional[tuple] = None,
):
_, host, port = base_url.split(":")
host = host[2:]
command = "sglang.launch_server"
command = [
"python3",
"-m",
command,
"--model-path",
model,
*[str(x) for x in other_args],
]
command.extend(
[
"--host",
host,
"--port",
port,
]
)
if api_key:
command += ["--api-key", api_key]
print(f"command={' '.join(command)}")
if return_stdout_stderr:
process = subprocess.Popen(
command,
stdout=return_stdout_stderr[0],
stderr=return_stdout_stderr[1],
env=env,
text=True,
)
else:
process = subprocess.Popen(command, stdout=None, stderr=None, env=env)
start_time = time.time()
with requests.Session() as session:
while time.time() - start_time < timeout:
try:
headers = {
"Content-Type": "application/json; charset=utf-8",
"Authorization": f"Bearer {api_key}",
}
response = session.get(
f"{base_url}/health",
headers=headers,
)
if response.status_code == 200:
return process
except requests.RequestException:
pass
return_code = process.poll()
if return_code is not None:
raise Exception(f"Server unexpectedly exits ({return_code=}).")
time.sleep(10)
kill_process_tree(process.pid)
raise TimeoutError("Server failed to start within the timeout period.")
def run_with_timeout(
func: Callable,
args: tuple = (),