#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/tests/utils.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import functools
import os
import signal
import subprocess
import sys
import time
from typing import Callable, Optional

import openai
import requests
from typing_extensions import ParamSpec
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.model_executor.model_loader import get_model_loader
from vllm.utils import FlexibleArgumentParser, get_open_port

_P = ParamSpec("_P")


class RemoteOpenAIServer:
    DUMMY_API_KEY = "token-abc123"  # vLLM's OpenAI server does not need API key

    def __init__(self,
                 model: str,
                 vllm_serve_args: list[str],
                 *,
                 env_dict: Optional[dict[str, str]] = None,
                 seed: Optional[int] = 0,
                 auto_port: bool = True,
                 max_wait_seconds: Optional[float] = None) -> None:
        if auto_port:
            if "-p" in vllm_serve_args or "--port" in vllm_serve_args:
                raise ValueError("You have manually specified the port "
                                 "when `auto_port=True`.")

            # Don't mutate the input args
            vllm_serve_args = vllm_serve_args + [
                "--port", str(get_open_port())
            ]
        if seed is not None:
            if "--seed" in vllm_serve_args:
                raise ValueError("You have manually specified the seed "
                                 f"when `seed={seed}`.")

            vllm_serve_args = vllm_serve_args + ["--seed", str(seed)]

        parser = FlexibleArgumentParser(
            description="vLLM's remote OpenAI server.")
        parser = make_arg_parser(parser)
        args = parser.parse_args(["--model", model, *vllm_serve_args])
        self.host = str(args.host or 'localhost')
        self.port = int(args.port)

        self.show_hidden_metrics = \
            args.show_hidden_metrics_for_version is not None

        # download the model before starting the server to avoid timeout
        is_local = os.path.isdir(model)
        if not is_local:
            engine_args = AsyncEngineArgs.from_cli_args(args)
            model_config = engine_args.create_model_config()
            load_config = engine_args.create_load_config()
            model_loader = get_model_loader(load_config)
            model_loader.download_model(model_config)

        env = os.environ.copy()
        # the current process might initialize cuda,
        # to be safe, we should use spawn method
        env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
        if env_dict is not None:
            env.update(env_dict)
        self.proc = subprocess.Popen(
            ["vllm", "serve", model, *vllm_serve_args],
            env=env,
            stdout=sys.stdout,
            stderr=sys.stderr,
        )
        max_wait_seconds = max_wait_seconds or 240
        self._wait_for_server(url=self.url_for("health"),
                              timeout=max_wait_seconds)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.proc.terminate()
        try:
            self.proc.wait(8)
        except subprocess.TimeoutExpired:
            # force kill if needed
            self.proc.kill()

    def _wait_for_server(self, *, url: str, timeout: float):
        # run health check
        start = time.time()
        while True:
            try:
                if requests.get(url).status_code == 200:
                    break
            except Exception:
                # this exception can only be raised by requests.get,
                # which means the server is not ready yet.
                # the stack trace is not useful, so we suppress it
                # by using `raise from None`.
                result = self.proc.poll()
                if result is not None and result != 0:
                    raise RuntimeError(
                        "Server exited unexpectedly.") from None

                time.sleep(0.5)
                if time.time() - start > timeout:
                    raise RuntimeError(
                        "Server failed to start in time.") from None

    @property
    def url_root(self) -> str:
        return f"http://{self.host}:{self.port}"

    def url_for(self, *parts: str) -> str:
        return self.url_root + "/" + "/".join(parts)

    def get_client(self, **kwargs):
        if "timeout" not in kwargs:
            kwargs["timeout"] = 600
        return openai.OpenAI(
            base_url=self.url_for("v1"),
            api_key=self.DUMMY_API_KEY,
            max_retries=0,
            **kwargs,
        )

    def get_async_client(self, **kwargs):
        if "timeout" not in kwargs:
            kwargs["timeout"] = 600
        return openai.AsyncOpenAI(base_url=self.url_for("v1"),
                                  api_key=self.DUMMY_API_KEY,
                                  max_retries=0,
                                  **kwargs)


def fork_new_process_for_each_test(
        f: Callable[_P, None]) -> Callable[_P, None]:
    """Decorator to fork a new process for each test function.
    See https://github.com/vllm-project/vllm/issues/7053 for more details.
    """

    @functools.wraps(f)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
        # Make the process the leader of its own process group
        # to avoid sending SIGTERM to the parent process
        os.setpgrp()
        from _pytest.outcomes import Skipped
        pid = os.fork()
        print(f"Forked a new process {pid} to run the test")
        if pid == 0:
            try:
                f(*args, **kwargs)
            except Skipped as e:
                # convert Skipped to exit code 0
                print(str(e))
                os._exit(0)
            except Exception:
                import traceback
                traceback.print_exc()
                os._exit(1)
            else:
                os._exit(0)
        else:
            pgid = os.getpgid(pid)
            _pid, _exitcode = os.waitpid(pid, 0)
            # ignore SIGTERM signal itself
            old_signal_handler = signal.signal(signal.SIGTERM, signal.SIG_IGN)
            # kill all child processes
            os.killpg(pgid, signal.SIGTERM)
            # restore the signal handler
            signal.signal(signal.SIGTERM, old_signal_handler)
            assert _exitcode == 0, (f"function {f} failed when called with"
                                    f" args {args} and kwargs {kwargs}")

    return wrapper
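

# Example usage (a minimal sketch, kept as a comment; the test name, model
# name, and serve args below are illustrative assumptions, not part of this
# module):
#
#     @fork_new_process_for_each_test
#     def test_server_smoke():
#         with RemoteOpenAIServer("facebook/opt-125m",
#                                 ["--max-model-len", "1024"]) as server:
#             client = server.get_client()
#             completion = client.completions.create(
#                 model="facebook/opt-125m", prompt="Hello, my name is")
#             assert completion.choices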