### What this PR does / why we need it?
- Adds support for passing prompt_embeds to LLM.generate as
```python
llm.generate({"prompt_embeds": input_embeds}, sampling_params)
```
or
```python
llm.generate(
[{"prompt_embeds": input_embeds} for input_embeds in inputs_embeds], sampling_params
)
```
- Add `prompt_embeds` to examples
### How was this patch tested?
CI passed with the newly added and existing tests.
I have also tested with the example script in this PR, and the output
looks good:
```bash
[Single Inference Output]
------------------------------
The capital of France is Paris. Paris is the largest city in France and is
------------------------------
Adding requests: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3966.87it/s]
Processed prompts: 100%|█████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.99it/s, est. speed input: 177.08 toks/s, output: 63.91 toks/s]
[Batch Inference Outputs]
------------------------------
Q1: Please tell me about the capital of France.
A1: The capital of France is Paris. It is located in the northern part of the
Q2: When is the day longest during the year?
A2: The day is longest during the year at the summer solstice. This typically occurs
Q3: Where is bigger, the moon or the sun?
A3: The sun is significantly bigger than the moon.
The sun has a diameter of
------------------------------
```
---------
Signed-off-by: wangli <wangli858794774@gmail.com>
260 lines
9.5 KiB
Python
260 lines
9.5 KiB
Python
#
|
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
# Copyright 2023 The vLLM team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# This file is a part of the vllm-ascend project.
|
|
# Adapted from vllm/tests/entrypoints/openai/test_completion_with_prompt_embeds.py
|
|
#
|
|
import base64
|
|
import io
|
|
import os
|
|
|
|
import openai # use the official client for correctness check
|
|
import pytest
|
|
import pytest_asyncio
|
|
import torch
|
|
from modelscope import snapshot_download # type: ignore
|
|
from openai import BadRequestError
|
|
from transformers import AutoConfig
|
|
from vllm.engine.arg_utils import EngineArgs
|
|
|
|
from tests.utils import RemoteOpenAIServer
|
|
|
|
# Prompt-embeds support only exists on vLLM versions that expose the
# `enable_prompt_embeds` engine argument; skip the whole module otherwise.
if not hasattr(EngineArgs, "enable_prompt_embeds"):
    pytest.skip("Not supported vllm version", allow_module_level=True)

# any model with a chat template should work here
MODEL_NAME = snapshot_download("LLM-Research/Llama-3.2-1B-Instruct")

# Model config is loaded so the dummy embeddings below can match the
# model's hidden size.
CONFIG = AutoConfig.from_pretrained(MODEL_NAME)
|
|
|
|
|
|
@pytest.fixture(scope="module")
def default_server_args() -> list[str]:
    """Baseline CLI arguments for the prompt-embeds test server."""
    # Half precision keeps CI runs fast and memory-friendly.
    base_args = [
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "8192",
        "--max-num-seqs",
        "128",
        "--enforce-eager",
    ]
    # Flags specific to the prompt-embeds feature under test.
    prompt_embeds_args = [
        "--enable-prompt-embeds",
        "--no-enable-chunked-prefill",
    ]
    return base_args + prompt_embeds_args
|
|
|
|
|
|
@pytest.fixture(scope="module",
                params=["", "--disable-frontend-multiprocessing"])
def server_with_prompt_embeds(default_server_args, request):
    """Launch the test server, parameterized over frontend multiprocessing."""
    extra_flag = request.param
    if extra_flag:
        default_server_args.append(extra_flag)
    with RemoteOpenAIServer(MODEL_NAME, default_server_args) as server:
        yield server
|
|
|
|
|
|
@pytest_asyncio.fixture
async def client_with_prompt_embeds(server_with_prompt_embeds):
    """Yield an async OpenAI client bound to the prompt-embeds server."""
    async with server_with_prompt_embeds.get_async_client() as client:
        yield client
|
|
|
|
|
|
def create_dummy_embeds(num_tokens: int = 5, hidden_size=None) -> str:
    """Create dummy embeddings and return them as base64 encoded string.

    Args:
        num_tokens: Number of token positions in the dummy embedding.
        hidden_size: Embedding dimension. Defaults to the model's
            configured ``hidden_size`` (from the module-level ``CONFIG``)
            when not given, preserving the original behavior.

    Returns:
        The torch-serialized ``(num_tokens, hidden_size)`` float tensor,
        base64-encoded as expected by the ``prompt_embeds`` request field.
    """
    # Generalized: the dimension used to be hard-wired to CONFIG.hidden_size,
    # which made the helper untestable without downloading the model.
    if hidden_size is None:
        hidden_size = CONFIG.hidden_size
    dummy_embeds = torch.randn(num_tokens, hidden_size)
    buffer = io.BytesIO()
    torch.save(dummy_embeds, buffer)
    return base64.b64encode(buffer.getvalue()).decode('utf-8')
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.skipif(
    os.getenv("VLLM_USE_V1") == "1",
    reason="Enable embedding input will fallback to v0, skip it")
async def test_completions_with_prompt_embeds(
        client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str):
    """End-to-end /completions checks for base64 prompt embeddings:
    single request, batch, streaming, batch streaming, and a mixed
    text + embeds request.
    """
    # Test case: Single prompt embeds input
    encoded_embeds = create_dummy_embeds()
    completion = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        extra_body={"prompt_embeds": encoded_embeds})
    assert len(completion.choices[0].text) >= 1
    assert completion.choices[0].prompt_logprobs is None

    # Test case: batch completion with prompt_embeds
    encoded_embeds2 = create_dummy_embeds()
    completion = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]})
    assert len(completion.choices) == 2
    assert len(completion.choices[0].text) >= 1
    assert len(completion.choices[1].text) >= 1

    # Test case: streaming with prompt_embeds
    encoded_embeds = create_dummy_embeds()
    single_completion = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        extra_body={"prompt_embeds": encoded_embeds})
    single_output = single_completion.choices[0].text

    stream = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        stream=True,
        extra_body={"prompt_embeds": encoded_embeds})
    chunks = []
    finish_reason_count = 0
    async for chunk in stream:
        chunks.append(chunk.choices[0].text)
        if chunk.choices[0].finish_reason is not None:
            finish_reason_count += 1
    assert finish_reason_count == 1
    assert chunk.choices[0].finish_reason == "length"
    assert chunk.choices[0].text
    # Greedy decoding (temperature=0.0) makes the streamed output
    # reproducible, so it must match the non-streaming baseline above.
    assert "".join(chunks) == single_output

    # Test case: batch streaming with prompt_embeds
    encoded_embeds2 = create_dummy_embeds()
    stream = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        stream=True,
        extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]})
    chunks_stream_embeds: list[list[str]] = [[], []]
    finish_reason_count = 0
    async for chunk in stream:
        chunks_stream_embeds[chunk.choices[0].index].append(
            chunk.choices[0].text)
        if chunk.choices[0].finish_reason is not None:
            finish_reason_count += 1
    # One finish_reason per request in the batch of two.
    assert finish_reason_count == 2
    assert chunk.choices[0].finish_reason == "length"
    assert chunk.choices[0].text
    assert len(chunks_stream_embeds[0]) > 0
    assert len(chunks_stream_embeds[1]) > 0

    # Test case: mixed text and prompt_embeds
    encoded_embeds = create_dummy_embeds()
    completion_mixed = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="This is a prompt",
        max_tokens=5,
        temperature=0.0,
        extra_body={"prompt_embeds": encoded_embeds})
    # BUGFIX: this previously asserted on the stale `completion` object
    # from the earlier batch test (always 2 choices, so it could never
    # fail). The mixed request is what must produce two choices.
    assert len(completion_mixed.choices) == 2
    completion_text_only = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="This is a prompt",
        max_tokens=5,
        temperature=0.0,
    )
    completion_embeds_only = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",
        max_tokens=5,
        temperature=0.0,
        extra_body={"prompt_embeds": encoded_embeds})
    # Embeddings responses should be handled first
    assert completion_mixed.choices[0].text == completion_embeds_only.choices[
        0].text
    assert completion_mixed.choices[1].text == completion_text_only.choices[
        0].text
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.skipif(
    os.getenv("VLLM_USE_V1") == "1",
    reason="Enable embedding input will fallback to v0, skip it")
async def test_completions_errors_with_prompt_embeds(
        client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str):
    """A malformed base64 prompt_embeds payload must be rejected (400)."""
    bad_payload = {"prompt_embeds": "invalid_base64"}
    with pytest.raises(BadRequestError):
        await client_with_prompt_embeds.completions.create(
            model=model_name,
            prompt="",
            max_tokens=5,
            temperature=0.0,
            extra_body=bad_payload)
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize("logprobs_arg", [1, 0])
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.skipif(
    os.getenv("VLLM_USE_V1") == "1",
    reason="Enable embedding input will fallback to v0, skip it")
async def test_completions_with_logprobs_and_prompt_embeds(
        client_with_prompt_embeds: openai.AsyncOpenAI, logprobs_arg: int,
        model_name: str):
    """Logprobs must be populated for single and batched prompt_embeds."""

    def check_choice_logprobs(choice):
        # Five generated tokens (max_tokens=5) => five logprob entries.
        logprobs = choice.logprobs
        assert logprobs is not None
        assert len(logprobs.text_offset) == 5
        assert len(logprobs.token_logprobs) == 5
        assert len(logprobs.top_logprobs) == 5
        # The first position is excluded from the bound check, as in the
        # upstream test this was adapted from.
        for top_logprobs in logprobs.top_logprobs[1:]:
            assert max(logprobs_arg,
                       1) <= len(top_logprobs) <= logprobs_arg + 1
        assert len(logprobs.tokens) == 5

    # Test case: Logprobs using prompt_embeds
    encoded_embeds = create_dummy_embeds()
    completion = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        echo=False,
        logprobs=logprobs_arg,
        extra_body={"prompt_embeds": encoded_embeds})
    check_choice_logprobs(completion.choices[0])

    # Test case: Log probs with batch completion and prompt_embeds
    encoded_embeds2 = create_dummy_embeds()
    completion = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        echo=False,
        logprobs=logprobs_arg,
        extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]})
    assert len(completion.choices) == 2
    for choice in completion.choices:
        check_choice_logprobs(choice)
|