[ModelRunner] Support embedding inputs (#916)
### What this PR does / why we need it?
- Adds support for passing `prompt_embeds` to `LLM.generate`, as in
```python
llm.generate({"prompt_embeds": input_embeds}, sampling_params)
```
or
```python
llm.generate(
    [{"prompt_embeds": input_embeds} for input_embeds in inputs_embeds],
    sampling_params,
)
```
- Adds a `prompt_embeds` example script (see the sketch after this list for one way to obtain `input_embeds`)
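
For context, `input_embeds` above is a tensor of per-token embeddings. A minimal sketch of producing one with Hugging Face Transformers; the model name here is a placeholder for illustration, not something this PR prescribes:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder model, assumption
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

token_ids = tokenizer("Please tell me about the capital of France.",
                      return_tensors="pt").input_ids

# Look up token embeddings from the model's input embedding table; squeezing
# the batch dimension leaves a (num_tokens, hidden_size) tensor.
with torch.no_grad():
    input_embeds = model.get_input_embeddings()(token_ids).squeeze(0)
```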
### How was this patch tested?
CI passed with newly added and existing tests.
I have also tested with the example script in this PR, and the output
looks good:
```bash
[Single Inference Output]
------------------------------
The capital of France is Paris. Paris is the largest city in France and is
------------------------------
Adding requests: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3966.87it/s]
Processed prompts: 100%|█████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.99it/s, est. speed input: 177.08 toks/s, output: 63.91 toks/s]
[Batch Inference Outputs]
------------------------------
Q1: Please tell me about the capital of France.
A1: The capital of France is Paris. It is located in the northern part of the
Q2: When is the day longest during the year?
A2: The day is longest during the year at the summer solstice. This typically occurs
Q3: Where is bigger, the moon or the sun?
A3: The sun is significantly bigger than the moon.
The sun has a diameter of
------------------------------
```
---------
Signed-off-by: wangli <wangli858794774@gmail.com>
tests/utils.py
@@ -20,13 +20,143 @@
```python
import functools
import os
import signal
import subprocess
import sys
import time
from typing import Callable, Optional

import openai
import requests
from typing_extensions import ParamSpec
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.model_executor.model_loader import get_model_loader
from vllm.utils import FlexibleArgumentParser, get_open_port

_P = ParamSpec("_P")


class RemoteOpenAIServer:
    DUMMY_API_KEY = "token-abc123"  # vLLM's OpenAI server does not need API key

    def __init__(self,
                 model: str,
                 vllm_serve_args: list[str],
                 *,
                 env_dict: Optional[dict[str, str]] = None,
                 seed: Optional[int] = 0,
                 auto_port: bool = True,
                 max_wait_seconds: Optional[float] = None) -> None:
        if auto_port:
            if "-p" in vllm_serve_args or "--port" in vllm_serve_args:
                raise ValueError("You have manually specified the port "
                                 "when `auto_port=True`.")

            # Don't mutate the input args
            vllm_serve_args = vllm_serve_args + [
                "--port", str(get_open_port())
            ]
        if seed is not None:
            if "--seed" in vllm_serve_args:
                raise ValueError("You have manually specified the seed "
                                 f"when `seed={seed}`.")

            vllm_serve_args = vllm_serve_args + ["--seed", str(seed)]

        parser = FlexibleArgumentParser(
            description="vLLM's remote OpenAI server.")
        parser = make_arg_parser(parser)
        args = parser.parse_args(["--model", model, *vllm_serve_args])
        self.host = str(args.host or 'localhost')
        self.port = int(args.port)

        self.show_hidden_metrics = \
            args.show_hidden_metrics_for_version is not None

        # download the model before starting the server to avoid timeout
        is_local = os.path.isdir(model)
        if not is_local:
            engine_args = AsyncEngineArgs.from_cli_args(args)
            model_config = engine_args.create_model_config()
            load_config = engine_args.create_load_config()

            model_loader = get_model_loader(load_config)
            model_loader.download_model(model_config)

        env = os.environ.copy()
        # the current process might initialize cuda,
        # to be safe, we should use spawn method
        env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
        if env_dict is not None:
            env.update(env_dict)
        self.proc = subprocess.Popen(
            ["vllm", "serve", model, *vllm_serve_args],
            env=env,
            stdout=sys.stdout,
            stderr=sys.stderr,
        )
        max_wait_seconds = max_wait_seconds or 240
        self._wait_for_server(url=self.url_for("health"),
                              timeout=max_wait_seconds)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.proc.terminate()
        try:
            self.proc.wait(8)
        except subprocess.TimeoutExpired:
            # force kill if needed
            self.proc.kill()

    def _wait_for_server(self, *, url: str, timeout: float):
        # run health check
        start = time.time()
        while True:
            try:
                if requests.get(url).status_code == 200:
                    break
            except Exception:
                # this exception can only be raised by requests.get,
                # which means the server is not ready yet.
                # the stack trace is not useful, so we suppress it
                # by using `raise from None`.
                result = self.proc.poll()
                if result is not None and result != 0:
                    raise RuntimeError("Server exited unexpectedly.") from None

                time.sleep(0.5)
                if time.time() - start > timeout:
                    raise RuntimeError(
                        "Server failed to start in time.") from None

    @property
    def url_root(self) -> str:
        return f"http://{self.host}:{self.port}"

    def url_for(self, *parts: str) -> str:
        return self.url_root + "/" + "/".join(parts)

    def get_client(self, **kwargs):
        if "timeout" not in kwargs:
            kwargs["timeout"] = 600
        return openai.OpenAI(
            base_url=self.url_for("v1"),
            api_key=self.DUMMY_API_KEY,
            max_retries=0,
            **kwargs,
        )

    def get_async_client(self, **kwargs):
        if "timeout" not in kwargs:
            kwargs["timeout"] = 600
        return openai.AsyncOpenAI(base_url=self.url_for("v1"),
                                  api_key=self.DUMMY_API_KEY,
                                  max_retries=0,
                                  **kwargs)


def fork_new_process_for_each_test(
        f: Callable[_P, None]) -> Callable[_P, None]:
    """Decorator to fork a new process for each test function.
```
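
For reference, a minimal usage sketch of this helper in a test; the import path, model name, and serve args below are placeholder assumptions, not values taken from this PR:

```python
# Hedged sketch: the model name and serve args are illustrative placeholders.
from tests.utils import RemoteOpenAIServer

MODEL = "Qwen/Qwen2.5-0.5B-Instruct"


def test_completions_smoke():
    # The context manager launches `vllm serve`, waits on /health,
    # and terminates the subprocess on exit.
    with RemoteOpenAIServer(MODEL, ["--max-model-len", "2048"]) as server:
        client = server.get_client()
        completion = client.completions.create(
            model=MODEL,
            prompt="The capital of France is",
            max_tokens=8,
        )
        assert completion.choices[0].text
```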