[ModelRunner] Support embedding inputs (#916)
### What this PR does / why we need it?
- Adds support for passing `prompt_embeds` to `LLM.generate`, as in
```python
llm.generate({"prompt_embeds": input_embeds}, sampling_params)
```
or
```python
llm.generate(
    [{"prompt_embeds": input_embeds} for input_embeds in inputs_embeds],
    sampling_params,
)
```
- Adds a `prompt_embeds` example script; a minimal end-to-end sketch follows.
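
A minimal sketch of the new input path, assuming a HuggingFace checkpoint whose input embedding layer produces the embeds; the model name and sampling settings here are illustrative, and the example script added by this PR does the same thing in full:
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from vllm import LLM, SamplingParams

model_name = "meta-llama/Llama-3.2-1B-Instruct"  # illustrative choice
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Reuse the HF model's input embedding layer to map token ids to embeddings.
embedding_layer = AutoModelForCausalLM.from_pretrained(
    model_name).get_input_embeddings()

token_ids = tokenizer("The capital of France is",
                      return_tensors="pt").input_ids
with torch.no_grad():
    input_embeds = embedding_layer(token_ids).squeeze(0)  # (seq_len, hidden)

llm = LLM(model=model_name, enable_prompt_embeds=True)
sampling_params = SamplingParams(temperature=0.0, max_tokens=16)
outputs = llm.generate({"prompt_embeds": input_embeds}, sampling_params)
print(outputs[0].outputs[0].text)
```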
### How was this patch tested?
CI passed with newly added and existing tests.
I have also tested with the example script in this PR, and the output looks good:
```bash
[Single Inference Output]
------------------------------
The capital of France is Paris. Paris is the largest city in France and is
------------------------------
Adding requests: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3966.87it/s]
Processed prompts: 100%|█████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.99it/s, est. speed input: 177.08 toks/s, output: 63.91 toks/s]
[Batch Inference Outputs]
------------------------------
Q1: Please tell me about the capital of France.
A1: The capital of France is Paris. It is located in the northern part of the
Q2: When is the day longest during the year?
A2: The day is longest during the year at the summer solstice. This typically occurs
Q3: Where is bigger, the moon or the sun?
A3: The sun is significantly bigger than the moon.
The sun has a diameter of
------------------------------
```
---------
Signed-off-by: wangli <wangli858794774@gmail.com>
.github/workflows/vllm_ascend_test.yaml (vendored, 4 changes)
```diff
@@ -136,12 +136,14 @@ jobs:
           pytest -sv tests/singlecard/test_camem.py
           # test_ascend_config.py should be ran separately because it will regenerate the global config many times.
           pytest -sv tests/singlecard/test_ascend_config.py
+          pytest -sv tests/singlecard/test_prompt_embedding.py
           pytest -sv tests/singlecard/ \
           --ignore=tests/singlecard/test_offline_inference.py \
           --ignore=tests/singlecard/test_scheduler.py \
           --ignore=tests/singlecard/test_guided_decoding.py \
           --ignore=tests/singlecard/test_camem.py \
-          --ignore=tests/singlecard/test_ascend_config.py
+          --ignore=tests/singlecard/test_ascend_config.py \
+          --ignore=tests/singlecard/test_prompt_embedding.py
         else
           pytest -sv tests/multicard/test_ilama_lora_tp2.py
           # Fixme: run VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py will raise error.
```
examples/prompt_embedding_inference.py (new file, 83 lines)

```python
import torch
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          PreTrainedTokenizer)

from vllm import LLM


def init_tokenizer_and_llm(model_name: str):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    transformers_model = AutoModelForCausalLM.from_pretrained(model_name)
    embedding_layer = transformers_model.get_input_embeddings()
    llm = LLM(model=model_name, enable_prompt_embeds=True)
    return tokenizer, embedding_layer, llm


def get_prompt_embeds(chat: list[dict[str, str]],
                      tokenizer: PreTrainedTokenizer,
                      embedding_layer: torch.nn.Module):
    token_ids = tokenizer.apply_chat_template(chat,
                                              add_generation_prompt=True,
                                              return_tensors='pt')
    prompt_embeds = embedding_layer(token_ids).squeeze(0)
    return prompt_embeds


def single_prompt_inference(llm: LLM, tokenizer: PreTrainedTokenizer,
                            embedding_layer: torch.nn.Module):
    chat = [{
        "role": "user",
        "content": "Please tell me about the capital of France."
    }]
    prompt_embeds = get_prompt_embeds(chat, tokenizer, embedding_layer)

    outputs = llm.generate({
        "prompt_embeds": prompt_embeds,
    })

    print("\n[Single Inference Output]")
    print("-" * 30)
    for o in outputs:
        print(o.outputs[0].text)
    print("-" * 30)


def batch_prompt_inference(llm: LLM, tokenizer: PreTrainedTokenizer,
                           embedding_layer: torch.nn.Module):
    chats = [[{
        "role": "user",
        "content": "Please tell me about the capital of France."
    }],
             [{
                 "role": "user",
                 "content": "When is the day longest during the year?"
             }],
             [{
                 "role": "user",
                 "content": "Where is bigger, the moon or the sun?"
             }]]

    prompt_embeds_list = [
        get_prompt_embeds(chat, tokenizer, embedding_layer) for chat in chats
    ]

    outputs = llm.generate([{
        "prompt_embeds": embeds
    } for embeds in prompt_embeds_list])

    print("\n[Batch Inference Outputs]")
    print("-" * 30)
    for i, o in enumerate(outputs):
        print(f"Q{i+1}: {chats[i][0]['content']}")
        print(f"A{i+1}: {o.outputs[0].text}\n")
    print("-" * 30)


def main():
    model_name = "meta-llama/Llama-3.2-1B-Instruct"
    tokenizer, embedding_layer, llm = init_tokenizer_and_llm(model_name)
    single_prompt_inference(llm, tokenizer, embedding_layer)
    batch_prompt_inference(llm, tokenizer, embedding_layer)


if __name__ == "__main__":
    main()
```
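
The script above relies on default sampling for brevity. `SamplingParams` can be passed alongside the embeds exactly as with text prompts; a sketch reusing `llm` and `prompt_embeds_list` from the script above (not part of the committed file):
```python
from vllm import SamplingParams

# Greedy decoding with a fixed token budget; values are illustrative.
sampling_params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(
    [{"prompt_embeds": embeds} for embeds in prompt_embeds_list],
    sampling_params,
)
```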
tests/singlecard/test_prompt_embedding.py (new file, 259 lines)

```python
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/entrypoints/openai/test_completion_with_prompt_embeds.py
#
import base64
import io
import os

import openai  # use the official client for correctness check
import pytest
import pytest_asyncio
import torch
from modelscope import snapshot_download  # type: ignore
from openai import BadRequestError
from transformers import AutoConfig
from vllm.engine.arg_utils import EngineArgs

from tests.utils import RemoteOpenAIServer

if not hasattr(EngineArgs, "enable_prompt_embeds"):
    pytest.skip("Not supported vllm version", allow_module_level=True)

# any model with a chat template should work here
MODEL_NAME = snapshot_download("LLM-Research/Llama-3.2-1B-Instruct")

CONFIG = AutoConfig.from_pretrained(MODEL_NAME)


@pytest.fixture(scope="module")
def default_server_args() -> list[str]:
    return [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "8192",
        "--max-num-seqs",
        "128",
        "--enforce-eager",
        # Prompt Embeds server args
        "--enable-prompt-embeds",
        "--no-enable-chunked-prefill",
    ]


@pytest.fixture(scope="module",
                params=["", "--disable-frontend-multiprocessing"])
def server_with_prompt_embeds(default_server_args, request):
    if request.param:
        default_server_args.append(request.param)

    with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client_with_prompt_embeds(server_with_prompt_embeds):
    async with server_with_prompt_embeds.get_async_client() as async_client:
        yield async_client


def create_dummy_embeds(num_tokens: int = 5) -> str:
    """Create dummy embeddings and return them as base64 encoded string."""
    dummy_embeds = torch.randn(num_tokens, CONFIG.hidden_size)
    buffer = io.BytesIO()
    torch.save(dummy_embeds, buffer)
    return base64.b64encode(buffer.getvalue()).decode('utf-8')


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.skipif(
    os.getenv("VLLM_USE_V1") == "1",
    reason="Enable embedding input will fallback to v0, skip it")
async def test_completions_with_prompt_embeds(
        client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str):
    # Test case: Single prompt embeds input
    encoded_embeds = create_dummy_embeds()
    completion = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        extra_body={"prompt_embeds": encoded_embeds})
    assert len(completion.choices[0].text) >= 1
    assert completion.choices[0].prompt_logprobs is None

    # Test case: batch completion with prompt_embeds
    encoded_embeds2 = create_dummy_embeds()
    completion = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]})
    assert len(completion.choices) == 2
    assert len(completion.choices[0].text) >= 1
    assert len(completion.choices[1].text) >= 1

    # Test case: streaming with prompt_embeds
    encoded_embeds = create_dummy_embeds()
    single_completion = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        extra_body={"prompt_embeds": encoded_embeds})
    single_output = single_completion.choices[0].text

    stream = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        stream=True,
        extra_body={"prompt_embeds": encoded_embeds})
    chunks = []
    finish_reason_count = 0
    async for chunk in stream:
        chunks.append(chunk.choices[0].text)
        if chunk.choices[0].finish_reason is not None:
            finish_reason_count += 1
    assert finish_reason_count == 1
    assert chunk.choices[0].finish_reason == "length"
    assert chunk.choices[0].text
    assert "".join(chunks) == single_output

    # Test case: batch streaming with prompt_embeds
    encoded_embeds2 = create_dummy_embeds()
    stream = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        stream=True,
        extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]})
    chunks_stream_embeds: list[list[str]] = [[], []]
    finish_reason_count = 0
    async for chunk in stream:
        chunks_stream_embeds[chunk.choices[0].index].append(
            chunk.choices[0].text)
        if chunk.choices[0].finish_reason is not None:
            finish_reason_count += 1
    assert finish_reason_count == 2
    assert chunk.choices[0].finish_reason == "length"
    assert chunk.choices[0].text
    assert len(chunks_stream_embeds[0]) > 0
    assert len(chunks_stream_embeds[1]) > 0

    # Test case: mixed text and prompt_embeds
    encoded_embeds = create_dummy_embeds()
    completion_mixed = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="This is a prompt",
        max_tokens=5,
        temperature=0.0,
        extra_body={"prompt_embeds": encoded_embeds})
    assert len(completion_mixed.choices) == 2
    completion_text_only = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="This is a prompt",
        max_tokens=5,
        temperature=0.0,
    )
    completion_embeds_only = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",
        max_tokens=5,
        temperature=0.0,
        extra_body={"prompt_embeds": encoded_embeds})
    # Embeddings responses should be handled first
    assert completion_mixed.choices[0].text == completion_embeds_only.choices[
        0].text
    assert completion_mixed.choices[1].text == completion_text_only.choices[
        0].text


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.skipif(
    os.getenv("VLLM_USE_V1") == "1",
    reason="Enable embedding input will fallback to v0, skip it")
async def test_completions_errors_with_prompt_embeds(
        client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str):
    # Test error case: invalid prompt_embeds
    with pytest.raises(BadRequestError):
        await client_with_prompt_embeds.completions.create(
            prompt="",
            model=model_name,
            max_tokens=5,
            temperature=0.0,
            extra_body={"prompt_embeds": "invalid_base64"})


@pytest.mark.asyncio
@pytest.mark.parametrize("logprobs_arg", [1, 0])
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.skipif(
    os.getenv("VLLM_USE_V1") == "1",
    reason="Enable embedding input will fallback to v0, skip it")
async def test_completions_with_logprobs_and_prompt_embeds(
        client_with_prompt_embeds: openai.AsyncOpenAI, logprobs_arg: int,
        model_name: str):
    # Test case: Logprobs using prompt_embeds
    encoded_embeds = create_dummy_embeds()
    completion = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        echo=False,
        logprobs=logprobs_arg,
        extra_body={"prompt_embeds": encoded_embeds})

    logprobs = completion.choices[0].logprobs
    assert logprobs is not None
    assert len(logprobs.text_offset) == 5
    assert len(logprobs.token_logprobs) == 5
    assert len(logprobs.top_logprobs) == 5
    for top_logprobs in logprobs.top_logprobs[1:]:
        assert max(logprobs_arg, 1) <= len(top_logprobs) <= logprobs_arg + 1
    assert len(logprobs.tokens) == 5

    # Test case: Log probs with batch completion and prompt_embeds
    encoded_embeds2 = create_dummy_embeds()
    completion = await client_with_prompt_embeds.completions.create(
        model=model_name,
        prompt="",  # Add empty prompt as required parameter
        max_tokens=5,
        temperature=0.0,
        echo=False,
        logprobs=logprobs_arg,
        extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]})

    assert len(completion.choices) == 2
    for choice in completion.choices:
        logprobs = choice.logprobs
        assert logprobs is not None
        assert len(logprobs.text_offset) == 5
        assert len(logprobs.token_logprobs) == 5
        assert len(logprobs.top_logprobs) == 5
        for top_logprobs in logprobs.top_logprobs[1:]:
            assert max(logprobs_arg,
                       1) <= len(top_logprobs) <= logprobs_arg + 1
        assert len(logprobs.tokens) == 5
```
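
The wire format these tests exercise is a tensor serialized with `torch.save` and base64-encoded, passed through the completions API's `extra_body`. A client-side sketch of that encoding, assuming a server started with `--enable-prompt-embeds` is reachable at `localhost:8000` (the port and hidden size here are assumptions):
```python
import base64
import io

import openai
import torch


def encode_prompt_embeds(embeds: torch.Tensor) -> str:
    # Serialize the (num_tokens, hidden_size) tensor with torch.save and
    # base64-encode it, mirroring create_dummy_embeds() above.
    buffer = io.BytesIO()
    torch.save(embeds, buffer)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


client = openai.OpenAI(base_url="http://localhost:8000/v1",
                       api_key="token-abc123")
hidden_size = 2048  # must match the served model's hidden size
completion = client.completions.create(
    model="LLM-Research/Llama-3.2-1B-Instruct",
    prompt="",  # empty prompt is still required by the endpoint
    max_tokens=5,
    extra_body={"prompt_embeds": encode_prompt_embeds(
        torch.randn(5, hidden_size))})
print(completion.choices[0].text)
```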
tests/utils.py (132 changes)

```diff
@@ -20,13 +20,143 @@
 import functools
 import os
 import signal
-from typing import Callable
+import subprocess
+import sys
+import time
+from typing import Callable, Optional
 
+import openai
+import requests
 from typing_extensions import ParamSpec
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.entrypoints.openai.cli_args import make_arg_parser
+from vllm.model_executor.model_loader import get_model_loader
+from vllm.utils import FlexibleArgumentParser, get_open_port
 
 _P = ParamSpec("_P")
 
 
+class RemoteOpenAIServer:
+    DUMMY_API_KEY = "token-abc123"  # vLLM's OpenAI server does not need API key
+
+    def __init__(self,
+                 model: str,
+                 vllm_serve_args: list[str],
+                 *,
+                 env_dict: Optional[dict[str, str]] = None,
+                 seed: Optional[int] = 0,
+                 auto_port: bool = True,
+                 max_wait_seconds: Optional[float] = None) -> None:
+        if auto_port:
+            if "-p" in vllm_serve_args or "--port" in vllm_serve_args:
+                raise ValueError("You have manually specified the port "
+                                 "when `auto_port=True`.")
+
+            # Don't mutate the input args
+            vllm_serve_args = vllm_serve_args + [
+                "--port", str(get_open_port())
+            ]
+        if seed is not None:
+            if "--seed" in vllm_serve_args:
+                raise ValueError("You have manually specified the seed "
+                                 f"when `seed={seed}`.")
+
+            vllm_serve_args = vllm_serve_args + ["--seed", str(seed)]
+
+        parser = FlexibleArgumentParser(
+            description="vLLM's remote OpenAI server.")
+        parser = make_arg_parser(parser)
+        args = parser.parse_args(["--model", model, *vllm_serve_args])
+        self.host = str(args.host or 'localhost')
+        self.port = int(args.port)
+
+        self.show_hidden_metrics = \
+            args.show_hidden_metrics_for_version is not None
+
+        # download the model before starting the server to avoid timeout
+        is_local = os.path.isdir(model)
+        if not is_local:
+            engine_args = AsyncEngineArgs.from_cli_args(args)
+            model_config = engine_args.create_model_config()
+            load_config = engine_args.create_load_config()
+
+            model_loader = get_model_loader(load_config)
+            model_loader.download_model(model_config)
+
+        env = os.environ.copy()
+        # the current process might initialize cuda,
+        # to be safe, we should use spawn method
+        env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
+        if env_dict is not None:
+            env.update(env_dict)
+        self.proc = subprocess.Popen(
+            ["vllm", "serve", model, *vllm_serve_args],
+            env=env,
+            stdout=sys.stdout,
+            stderr=sys.stderr,
+        )
+        max_wait_seconds = max_wait_seconds or 240
+        self._wait_for_server(url=self.url_for("health"),
+                              timeout=max_wait_seconds)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.proc.terminate()
+        try:
+            self.proc.wait(8)
+        except subprocess.TimeoutExpired:
+            # force kill if needed
+            self.proc.kill()
+
+    def _wait_for_server(self, *, url: str, timeout: float):
+        # run health check
+        start = time.time()
+        while True:
+            try:
+                if requests.get(url).status_code == 200:
+                    break
+            except Exception:
+                # this exception can only be raised by requests.get,
+                # which means the server is not ready yet.
+                # the stack trace is not useful, so we suppress it
+                # by using `raise from None`.
+                result = self.proc.poll()
+                if result is not None and result != 0:
+                    raise RuntimeError("Server exited unexpectedly.") from None
+
+                time.sleep(0.5)
+                if time.time() - start > timeout:
+                    raise RuntimeError(
+                        "Server failed to start in time.") from None
+
+    @property
+    def url_root(self) -> str:
+        return f"http://{self.host}:{self.port}"
+
+    def url_for(self, *parts: str) -> str:
+        return self.url_root + "/" + "/".join(parts)
+
+    def get_client(self, **kwargs):
+        if "timeout" not in kwargs:
+            kwargs["timeout"] = 600
+        return openai.OpenAI(
+            base_url=self.url_for("v1"),
+            api_key=self.DUMMY_API_KEY,
+            max_retries=0,
+            **kwargs,
+        )
+
+    def get_async_client(self, **kwargs):
+        if "timeout" not in kwargs:
+            kwargs["timeout"] = 600
+        return openai.AsyncOpenAI(base_url=self.url_for("v1"),
+                                  api_key=self.DUMMY_API_KEY,
+                                  max_retries=0,
+                                  **kwargs)
+
+
 def fork_new_process_for_each_test(
         f: Callable[_P, None]) -> Callable[_P, None]:
     """Decorator to fork a new process for each test function.
```
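
For reference, `RemoteOpenAIServer` is meant to be used as a context manager, as the fixtures above do; a minimal standalone sketch (model and flags are illustrative):
```python
from tests.utils import RemoteOpenAIServer

# Spin up a vLLM OpenAI-compatible server on a free port, wait for /health,
# then talk to it with the synchronous client.
with RemoteOpenAIServer("meta-llama/Llama-3.2-1B-Instruct",
                        ["--enforce-eager",
                         "--enable-prompt-embeds"]) as server:
    client = server.get_client()
    print(client.models.list().data[0].id)
```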
@@ -33,7 +33,7 @@ from vllm.attention import AttentionMetadata, get_attn_backend
|
||||
from vllm.attention.backends.utils import CommonAttentionState
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.core.scheduler import SchedulerOutputs
|
||||
from vllm.distributed import get_dp_group, get_pp_group
|
||||
from vllm.distributed import broadcast_tensor_dict, get_dp_group, get_pp_group
|
||||
from vllm.distributed.kv_transfer import get_kv_transfer_group
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.inputs import INPUT_REGISTRY, InputRegistry
|
||||
@@ -43,7 +43,8 @@ from vllm.lora.request import LoRARequest
|
||||
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
|
||||
from vllm.model_executor import SamplingMetadata, SamplingMetadataCache
|
||||
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.layers.sampler import (Sampler, SamplerOutput,
|
||||
get_sampler)
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
||||
from vllm.model_executor.models import supports_lora, supports_multimodal
|
||||
@@ -84,6 +85,7 @@ class ModelInputForNPU(ModelRunnerInputBase):
|
||||
additional fields.
|
||||
"""
|
||||
input_tokens: Optional[torch.Tensor] = None
|
||||
inputs_embeds: Optional[torch.Tensor] = None
|
||||
input_positions: Optional[torch.Tensor] = None
|
||||
token_types: Optional[torch.Tensor] = None
|
||||
seq_lens: Optional[List[int]] = None
|
||||
@@ -103,6 +105,7 @@ class ModelInputForNPU(ModelRunnerInputBase):
|
||||
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
|
||||
tensor_dict = {
|
||||
"input_tokens": self.input_tokens,
|
||||
"inputs_embeds": self.inputs_embeds,
|
||||
"input_positions": self.input_positions,
|
||||
"lora_requests": self.lora_requests,
|
||||
"lora_mapping": self.lora_mapping,
|
||||
@@ -151,6 +154,7 @@ class ModelInputForNPUWithSamplingMetadata(ModelInputForNPU):
|
||||
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
|
||||
tensor_dict = {
|
||||
"input_tokens": self.input_tokens,
|
||||
"inputs_embeds": self.inputs_embeds,
|
||||
"input_positions": self.input_positions,
|
||||
"lora_requests": self.lora_requests,
|
||||
"lora_mapping": self.lora_mapping,
|
||||
@@ -188,6 +192,7 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
|
||||
|
||||
def simple_reinit(self):
|
||||
self.input_tokens[0].clear() # type: ignore
|
||||
self.inputs_embeds = None # type: ignore
|
||||
self.input_positions[0].clear() # type: ignore
|
||||
self.token_types[0].clear() # type: ignore
|
||||
self.mrope_input_positions = None # type: ignore
|
||||
@@ -213,6 +218,7 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
|
||||
|
||||
# Input tokens and positions.
|
||||
input_tokens: Optional[List[List[int]]] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
input_positions: Optional[List[List[int]]] = None,
|
||||
token_types: Optional[List[List[int]]] = None,
|
||||
mrope_input_positions: Optional[List[List[List[int]]]] = None,
|
||||
@@ -268,6 +274,7 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
|
||||
else:
|
||||
for seq_id in range(len(self.seq_ids)):
|
||||
self.input_tokens[seq_id].clear()
|
||||
self.inputs_embeds = inputs_embeds
|
||||
|
||||
if input_positions:
|
||||
self.input_positions = input_positions
|
||||
@@ -329,6 +336,7 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
|
||||
|
||||
else:
|
||||
self.input_tokens = input_tokens or []
|
||||
self.inputs_embeds = inputs_embeds
|
||||
self.input_positions = input_positions or []
|
||||
self.token_types = token_types or []
|
||||
self.mrope_input_positions = mrope_input_positions or None
|
||||
@@ -368,6 +376,26 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
|
||||
self.lora_index_mapping = []
|
||||
self.lora_prompt_mapping = []
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"InterDataForSeqGroup("
|
||||
f"request_id={self.request_id}, "
|
||||
f"seq_ids={self.seq_ids}, "
|
||||
f"is_prompt={self.is_prompt}, "
|
||||
f"block_tables={self.block_tables}, "
|
||||
f"computed_block_nums={self.computed_block_nums}, "
|
||||
f"n_seqs={self.n_seqs}, "
|
||||
f"input_tokens={self.input_tokens}, "
|
||||
f"inputs_embeds.shape="
|
||||
f"{getattr(self.inputs_embeds, 'shape', None)}, "
|
||||
f"input_positions={self.input_positions}, "
|
||||
f"token_types={self.token_types}, "
|
||||
f"mrope_input_positions={self.mrope_input_positions}, "
|
||||
f"seq_lens={self.seq_lens}, "
|
||||
f"orig_seq_lens={self.orig_seq_lens}, "
|
||||
f"query_lens={self.query_lens}, "
|
||||
f"context_lens={self.context_lens}, "
|
||||
f"multi_modal_kwargs={self.multi_modal_kwargs}")
|
||||
|
||||
def __init__(self,
|
||||
runner,
|
||||
finished_requests_ids: Optional[List[str]] = None):
|
||||
@@ -492,11 +520,30 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
|
||||
create on-device tensors.
|
||||
"""
|
||||
# Combine and flatten intermediate data.
|
||||
input_tokens = [
|
||||
flatten_2d_lists(inter_data.input_tokens)
|
||||
for inter_data in self.inter_data_list
|
||||
]
|
||||
if not input_tokens:
|
||||
input_tokens = list[int]()
|
||||
inputs_embeds_list = list[torch.Tensor]()
|
||||
token_types = list[int]()
|
||||
for inter_data in self.inter_data_list:
|
||||
for cur_input_tokens in inter_data.input_tokens:
|
||||
input_tokens.extend(cur_input_tokens)
|
||||
for cur_token_types in inter_data.token_types:
|
||||
token_types.extend(cur_token_types)
|
||||
if inter_data.inputs_embeds is not None:
|
||||
inputs_embeds_list.append(
|
||||
inter_data.inputs_embeds.to(
|
||||
dtype=self.runner.model_config.dtype,
|
||||
device=self.runner.device))
|
||||
|
||||
inputs_embeds: Optional[torch.Tensor]
|
||||
if len(inputs_embeds_list) == 0:
|
||||
inputs_embeds = None
|
||||
else:
|
||||
inputs_embeds = torch.cat(inputs_embeds_list, dim=0).to(
|
||||
dtype=self.runner.model_config.dtype,
|
||||
device=self.runner.device)
|
||||
assert len(inputs_embeds) == len(input_tokens)
|
||||
|
||||
if not input_tokens and inputs_embeds is None:
|
||||
# This may happen when all prefill requests hit
|
||||
# prefix caching and there is no decode request.
|
||||
return self.model_input_cls()
|
||||
@@ -548,10 +595,6 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
|
||||
else:
|
||||
graph_pad_size = -1
|
||||
|
||||
#print(f"before tensor input_tokens: {input_tokens}")
|
||||
#print(f"before tensor input_positions: {input_positions}")
|
||||
#print(f"before list seq_lens: {seq_lens}")
|
||||
input_tokens = flatten_2d_lists(input_tokens)
|
||||
if input_positions:
|
||||
input_positions = flatten_2d_lists(input_positions)
|
||||
if graph_pad_size != -1 and not is_prompt:
|
||||
@@ -563,6 +606,10 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
|
||||
input_tokens_tensor = torch.tensor(input_tokens,
|
||||
dtype=torch.long,
|
||||
device=self.runner.device)
|
||||
token_types_tensor = torch.tensor(token_types,
|
||||
dtype=torch.long,
|
||||
device=self.runner.device) \
|
||||
if token_types else None
|
||||
if mrope_input_positions is not None:
|
||||
input_positions_tensor = torch.tensor(mrope_input_positions,
|
||||
dtype=torch.long,
|
||||
@@ -613,6 +660,8 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
|
||||
|
||||
return self.model_input_cls(
|
||||
input_tokens=input_tokens_tensor,
|
||||
inputs_embeds=inputs_embeds,
|
||||
token_types=token_types_tensor,
|
||||
input_positions=input_positions_tensor,
|
||||
attn_metadata=attn_metadata,
|
||||
seq_lens=seq_lens,
|
||||
@@ -645,13 +694,23 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
|
||||
context_len = seq_data.get_num_computed_tokens()
|
||||
|
||||
# Compute tokens.
|
||||
tokens = seq_data.get_token_ids()[context_len:seq_len]
|
||||
# Fixme: this is for the version compatibility, remove this once vllm v0.8.5 does not be supported.
|
||||
if not hasattr(seq_data,
|
||||
"prompt_embeds") or seq_data.prompt_embeds is None:
|
||||
tokens = seq_data.get_token_ids()[context_len:seq_len]
|
||||
prompt_embeds = None
|
||||
else:
|
||||
tokens = [0] * (seq_len - context_len)
|
||||
prompt_embeds = seq_data.get_token_embeddings(
|
||||
)[context_len:seq_len]
|
||||
|
||||
token_types = seq_group_metadata.token_type_ids
|
||||
|
||||
inter_data.seq_lens[seq_idx] = seq_len
|
||||
inter_data.orig_seq_lens[seq_idx] = seq_len
|
||||
inter_data.context_lens[seq_idx] = context_len
|
||||
inter_data.input_tokens[seq_idx].extend(tokens)
|
||||
inter_data.inputs_embeds = prompt_embeds
|
||||
inter_data.input_positions[seq_idx].extend(range(context_len, seq_len))
|
||||
inter_data.token_types[seq_idx].extend(
|
||||
token_types if token_types else [])
|
||||
@@ -1379,6 +1438,7 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
|
||||
model_kwargs["attn_metadata"] = model_input.attn_metadata
|
||||
hidden_or_intermediate_states = model_executable(
|
||||
input_ids=model_input.input_tokens,
|
||||
inputs_embeds=model_input.inputs_embeds,
|
||||
positions=model_input.input_positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
|
||||
@@ -1422,34 +1482,61 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
|
||||
hidden_or_intermediate_states,
|
||||
)
|
||||
|
||||
if self.is_driver_worker:
|
||||
if model_input.async_callback is not None:
|
||||
model_input.async_callback()
|
||||
|
||||
# Sample the next token.
|
||||
assert isinstance(self.sampler, Sampler)
|
||||
orig_include_gpu_probs = self.sampler.include_gpu_probs_tensor
|
||||
if model_input.inputs_embeds is not None:
|
||||
self.sampler.include_gpu_probs_tensor = True
|
||||
|
||||
output: SamplerOutput = self.sampler(
|
||||
logits=logits,
|
||||
sampling_metadata=model_input.sampling_metadata,
|
||||
)
|
||||
if (self.observability_config is not None
|
||||
and self.observability_config.collect_model_forward_time
|
||||
and output is not None):
|
||||
model_forward_end.synchronize()
|
||||
model_forward_time = model_forward_start.elapsed_time(
|
||||
model_forward_end)
|
||||
orig_model_forward_time = 0.0
|
||||
if intermediate_tensors is not None:
|
||||
orig_model_forward_time = intermediate_tensors.tensors.get(
|
||||
"model_forward_time", torch.tensor(0.0)).item()
|
||||
# If there are multiple workers, we are still tracking the
|
||||
# latency from the start time of the driver worker to the end
|
||||
# time of the driver worker. The model forward time will then
|
||||
# end up covering the communication time as well.
|
||||
output.model_forward_time = (orig_model_forward_time +
|
||||
model_forward_time)
|
||||
|
||||
if model_input.inputs_embeds is not None:
|
||||
if self.is_driver_worker:
|
||||
sampled = broadcast_tensor_dict(
|
||||
{"token_ids": output.sampled_token_ids})
|
||||
else:
|
||||
sampled = broadcast_tensor_dict()
|
||||
if sampled["token_ids"] is not None:
|
||||
sampled_token_embeds = self.model.get_input_embeddings(
|
||||
sampled["token_ids"].squeeze(1))
|
||||
if self.is_driver_worker:
|
||||
self.sampler.include_gpu_probs_tensor = \
|
||||
orig_include_gpu_probs
|
||||
|
||||
output.sampled_token_embeds = sampled_token_embeds
|
||||
|
||||
for token_embed, sequence_group_output in zip(
|
||||
output.sampled_token_embeds, output.outputs):
|
||||
assert len(sequence_group_output.samples) == 1
|
||||
sequence_group_output.samples[
|
||||
0].output_embed = token_embed
|
||||
|
||||
if not self.is_driver_worker:
|
||||
return []
|
||||
|
||||
if model_input.async_callback is not None:
|
||||
model_input.async_callback()
|
||||
|
||||
# Sample the next token.
|
||||
output = self.sampler(
|
||||
logits=logits,
|
||||
sampling_metadata=model_input.sampling_metadata,
|
||||
)
|
||||
if (self.observability_config is not None
|
||||
and self.observability_config.collect_model_forward_time
|
||||
and output is not None):
|
||||
model_forward_end.synchronize()
|
||||
model_forward_time = model_forward_start.elapsed_time(
|
||||
model_forward_end)
|
||||
orig_model_forward_time = 0.0
|
||||
if intermediate_tensors is not None:
|
||||
orig_model_forward_time = intermediate_tensors.tensors.get(
|
||||
"model_forward_time", torch.tensor(0.0)).item()
|
||||
# If there are multiple workers, we are still tracking the latency
|
||||
# from the start time of the driver worker to the end time of the
|
||||
# driver worker. The model forward time will then end up covering
|
||||
# the communication time as well.
|
||||
output.model_forward_time = (orig_model_forward_time +
|
||||
model_forward_time)
|
||||
|
||||
if self.return_hidden_states:
|
||||
# we only need to pass hidden states of most recent token
|
||||
assert model_input.sampling_metadata is not None
|
||||
|
||||
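
Two details in the hunks above are worth calling out: a sequence backed by `prompt_embeds` gets zero-filled placeholder token ids of matching length so the existing length bookkeeping stays untouched, and at decode time the driver broadcasts the sampled token ids so every rank can embed them for the next step. A toy, single-process illustration of the placeholder invariant (sizes are made up):
```python
import torch

hidden_size, context_len, seq_len = 8, 2, 6

# Embedding-backed sequence: real embeddings, zero placeholder token ids.
prompt_embeds = torch.randn(seq_len, hidden_size)
tokens = [0] * (seq_len - context_len)       # placeholders, never embedded
embeds = prompt_embeds[context_len:seq_len]  # the slice actually fed in

# The invariant the builder relies on: one embedding row per placeholder.
assert len(embeds) == len(tokens)
```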