[feature] Prompt Embeddings Support for v1 Engine (#3026)
### What this PR does / why we need it?
This PR is based on [19746](https://github.com/vllm-project/vllm/issues/19746) and adds support for prompt embeddings to the v1 engine on NPU.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
```python
python examples/prompt_embed_inference.py
```

- vLLM version: v0.11.0
- vLLM main: https://github.com/vllm-project/vllm/commit/releases/v0.11.1

---------

Signed-off-by: jesse <szxfml@gmail.com>
This commit is contained in:
1
.github/workflows/_e2e_test.yaml
vendored
1
.github/workflows/_e2e_test.yaml
vendored
@@ -88,6 +88,7 @@ jobs:
|
|||||||
# We found that if running aclgraph tests in batch, it will cause AclmdlRICaptureBegin error. So we run
|
# We found that if running aclgraph tests in batch, it will cause AclmdlRICaptureBegin error. So we run
|
||||||
# the test separately.
|
# the test separately.
|
||||||
|
|
||||||
|
pytest -sv tests/e2e/singlecard/test_completion_with_prompt_embeds.py
|
||||||
pytest -sv tests/e2e/singlecard/test_aclgraph.py
|
pytest -sv tests/e2e/singlecard/test_aclgraph.py
|
||||||
pytest -sv tests/e2e/singlecard/test_ascend_scheduler.py
|
pytest -sv tests/e2e/singlecard/test_ascend_scheduler.py
|
||||||
pytest -sv tests/e2e/singlecard/test_bge_model.py
|
pytest -sv tests/e2e/singlecard/test_bge_model.py
|
||||||
|
|||||||
97
examples/prompt_embed_inference.py
Normal file
97
examples/prompt_embed_inference.py
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""
|
||||||
|
Demonstrates how to generate prompt embeddings using
|
||||||
|
Hugging Face Transformers and use them as input to vLLM
|
||||||
|
for both single and batch inference.
|
||||||
|
|
||||||
|
Model: meta-llama/Llama-3.2-1B-Instruct
|
||||||
|
Note: This model is gated on Hugging Face Hub.
|
||||||
|
You must request access to use it:
|
||||||
|
https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct
|
||||||
|
|
||||||
|
Requirements:
|
||||||
|
- vLLM
|
||||||
|
- transformers
|
||||||
|
|
||||||
|
Run:
|
||||||
|
python examples/prompt_embed_inference.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer
|
||||||
|
|
||||||
|
from vllm import LLM
|
||||||
|
|
||||||
|
|
||||||
|
def init_tokenizer_and_llm(model_name: str):
    """Build everything the demo needs for ``model_name``.

    The vLLM engine is created with ``enable_prompt_embeds=True`` so that it
    accepts ``prompt_embeds`` inputs. The Hugging Face model is loaded only
    to obtain its input-embedding layer.

    Returns:
        A ``(tokenizer, embedding_layer, llm)`` tuple.
    """
    # The engine is constructed first, exactly as in the original ordering.
    engine = LLM(model=model_name, enable_prompt_embeds=True)
    hf_tokenizer = AutoTokenizer.from_pretrained(model_name)
    hf_model = AutoModelForCausalLM.from_pretrained(model_name)
    return hf_tokenizer, hf_model.get_input_embeddings(), engine
|
||||||
|
|
||||||
|
|
||||||
|
def get_prompt_embeds(
    chat: list[dict[str, str]],
    tokenizer: PreTrainedTokenizer,
    embedding_layer: torch.nn.Module,
):
    """Turn a chat conversation into prompt embeddings.

    Args:
        chat: Chat messages passed to ``tokenizer.apply_chat_template``.
        tokenizer: Tokenizer whose chat template is applied (with the
            generation prompt appended) to produce PyTorch token ids.
        embedding_layer: Module that maps the token ids to embeddings.

    Returns:
        The embedding layer's output with dimension 0 squeezed away
        (presumably the singleton batch dimension -- TODO confirm).
    """
    input_ids = tokenizer.apply_chat_template(
        chat,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    batched_embeds = embedding_layer(input_ids)
    return batched_embeds.squeeze(0)
|
||||||
|
|
||||||
|
|
||||||
|
def single_prompt_inference(
    llm: LLM, tokenizer: PreTrainedTokenizer, embedding_layer: torch.nn.Module
):
    """Generate from one embedded chat prompt and print the result.

    Args:
        llm: vLLM engine used for generation.
        tokenizer: Tokenizer handed to ``get_prompt_embeds``.
        embedding_layer: Embedding module handed to ``get_prompt_embeds``.
    """
    conversation = [
        {"role": "user", "content": "Please tell me about the capital of France."}
    ]
    request = {
        "prompt_embeds": get_prompt_embeds(conversation, tokenizer, embedding_layer),
    }
    results = llm.generate(request)

    separator = "-" * 30
    print("\n[Single Inference Output]")
    print(separator)
    for request_output in results:
        # Only the first completion of each request is shown.
        print(request_output.outputs[0].text)
    print(separator)
|
||||||
|
|
||||||
|
|
||||||
|
def batch_prompt_inference(
    llm: LLM, tokenizer: PreTrainedTokenizer, embedding_layer: torch.nn.Module
):
    """Generate from several embedded chat prompts in one call and print them.

    Args:
        llm: vLLM engine used for generation.
        tokenizer: Tokenizer handed to ``get_prompt_embeds``.
        embedding_layer: Embedding module handed to ``get_prompt_embeds``.
    """
    questions = (
        "Please tell me about the capital of France.",
        "When is the day longest during the year?",
        "Where is bigger, the moon or the sun?",
    )
    # One single-turn conversation per question.
    chats = [[{"role": "user", "content": question}] for question in questions]

    # Embed every conversation first, then wrap each tensor as a request.
    embedded_prompts = [
        get_prompt_embeds(chat, tokenizer, embedding_layer) for chat in chats
    ]
    requests = [{"prompt_embeds": embeds} for embeds in embedded_prompts]
    outputs = llm.generate(requests)

    separator = "-" * 30
    print("\n[Batch Inference Outputs]")
    print(separator)
    for i, o in enumerate(outputs):
        print(f"Q{i + 1}: {chats[i][0]['content']}")
        print(f"A{i + 1}: {o.outputs[0].text}\n")
        print(separator)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Run the single-prompt demo followed by the batch demo."""
    components = init_tokenizer_and_llm("meta-llama/Llama-3.2-1B-Instruct")
    hf_tokenizer, input_embedding, engine = components
    # Both demos share the same engine, tokenizer and embedding layer.
    for run_demo in (single_prompt_inference, batch_prompt_inference):
        run_demo(engine, hf_tokenizer, input_embedding)


if __name__ == "__main__":
    main()
|
||||||
197
tests/e2e/singlecard/test_completion_with_prompt_embeds.py
Normal file
197
tests/e2e/singlecard/test_completion_with_prompt_embeds.py
Normal file
@@ -0,0 +1,197 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# This file is a part of the vllm-ascend project.
|
||||||
|
#
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
|
from tests.e2e.conftest import VllmRunner
|
||||||
|
|
||||||
|
# Environment switches applied at import time, before any engine is created.
os.environ.update({
    "VLLM_USE_MODELSCOPE": "True",
    "VLLM_WORKER_MULTIPROC_METHOD": "spawn",
})

# Models every test in this module is parametrized over.
MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"]
|
||||||
|
|
||||||
|
|
||||||
|
def get_prompt_embeds(chat, tokenizer, embedding_layer):
    """Convert chat messages to prompt embeddings.

    The chat template is applied with the generation prompt appended and
    PyTorch tensors requested; the resulting ids are passed through
    ``embedding_layer`` and dimension 0 of the result is squeezed away
    (presumably the singleton batch dimension -- TODO confirm).
    """
    input_ids = tokenizer.apply_chat_template(
        chat,
        add_generation_prompt=True,
        return_tensors='pt',
    )
    batched_embeds = embedding_layer(input_ids)
    return batched_embeds.squeeze(0)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model_name", MODELS)
def test_single_prompt_embeds_inference(model_name):
    """Check that one prompt-embeds request yields non-empty text."""
    # Build the embeddings with the Hugging Face tokenizer and model.
    hf_tokenizer = AutoTokenizer.from_pretrained(model_name)
    hf_model = AutoModelForCausalLM.from_pretrained(model_name)
    input_embedding = hf_model.get_input_embeddings()

    user_turn = {
        "role": "user",
        "content": "Please tell me about the capital of France."
    }
    prompt_embeds = get_prompt_embeds([user_turn], hf_tokenizer,
                                      input_embedding)

    # Generate in eager mode from the embeddings alone.
    runner_kwargs = dict(enable_prompt_embeds=True, enforce_eager=True)
    with VllmRunner(model_name, **runner_kwargs) as vllm_runner:
        outputs = vllm_runner.model.generate(
            dict(prompt_embeds=prompt_embeds))

    # Exactly one request came back and it produced some text.
    assert len(outputs) == 1
    assert len(outputs[0].outputs) > 0
    assert len(outputs[0].outputs[0].text) > 0
    print(f"\n[Single Inference Output]: {outputs[0].outputs[0].text}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model_name", MODELS)
def test_batch_prompt_embeds_inference(model_name):
    """Check that a batch of prompt-embeds requests all yield text."""
    # Build the embeddings with the Hugging Face tokenizer and model.
    hf_tokenizer = AutoTokenizer.from_pretrained(model_name)
    hf_model = AutoModelForCausalLM.from_pretrained(model_name)
    input_embedding = hf_model.get_input_embeddings()

    questions = (
        "Please tell me about the capital of France.",
        "When is the day longest during the year?",
        "Where is bigger, the moon or the sun?",
    )
    # One single-turn conversation per question.
    chats = [[{"role": "user", "content": question}] for question in questions]

    # Embed every conversation first, then wrap each tensor as a request.
    embedded_prompts = [
        get_prompt_embeds(chat, hf_tokenizer, input_embedding)
        for chat in chats
    ]
    requests = [{"prompt_embeds": embeds} for embeds in embedded_prompts]

    # Generate for the whole batch in eager mode.
    runner_kwargs = dict(enable_prompt_embeds=True, enforce_eager=True)
    with VllmRunner(model_name, **runner_kwargs) as vllm_runner:
        outputs = vllm_runner.model.generate(requests)

    # One result per conversation, each with non-empty text.
    assert len(outputs) == len(chats)
    for i, output in enumerate(outputs):
        assert len(output.outputs) > 0
        assert len(output.outputs[0].text) > 0
        print(f"\nQ{i+1}: {chats[i][0]['content']}")
        print(f"A{i+1}: {output.outputs[0].text}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model_name", MODELS)
def test_prompt_embeds_with_aclgraph(model_name):
    """Run the same prompt embeddings with enforce_eager off, then on."""
    # Build the embeddings with the Hugging Face tokenizer and model.
    hf_tokenizer = AutoTokenizer.from_pretrained(model_name)
    hf_model = AutoModelForCausalLM.from_pretrained(model_name)
    input_embedding = hf_model.get_input_embeddings()

    user_turn = {"role": "user", "content": "What is the capital of China?"}
    prompt_embeds = get_prompt_embeds([user_turn], hf_tokenizer,
                                      input_embedding)

    # First pass: enforce_eager=False (ACL graph enabled).
    with VllmRunner(model_name,
                    enable_prompt_embeds=True,
                    enforce_eager=False) as graph_runner:
        aclgraph_outputs = graph_runner.model.generate(
            dict(prompt_embeds=prompt_embeds))

    # Second pass: enforce_eager=True (ACL graph disabled).
    with VllmRunner(model_name,
                    enable_prompt_embeds=True,
                    enforce_eager=True) as eager_runner:
        eager_outputs = eager_runner.model.generate(
            dict(prompt_embeds=prompt_embeds))

    # Each run must return exactly one result with non-empty text.
    assert len(aclgraph_outputs) == 1
    assert len(eager_outputs) == 1
    aclgraph_text = aclgraph_outputs[0].outputs[0].text
    assert len(aclgraph_text) > 0
    eager_text = eager_outputs[0].outputs[0].text
    assert len(eager_text) > 0

    print("\n[ACL Graph Output]:", aclgraph_text)
    print("[Eager Output]:", eager_text)

    # The two texts are deliberately not compared: the execution paths
    # differ, so the outputs may differ slightly while both being valid.
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model_name", MODELS)
def test_mixed_prompt_embeds_and_text(model_name):
    """Check one engine serves a prompt-embeds request and a text request."""
    # Embeddings for the first request.
    hf_tokenizer = AutoTokenizer.from_pretrained(model_name)
    hf_model = AutoModelForCausalLM.from_pretrained(model_name)
    input_embedding = hf_model.get_input_embeddings()

    user_turn = {"role": "user", "content": "What is AI?"}
    prompt_embeds = get_prompt_embeds([user_turn], hf_tokenizer,
                                      input_embedding)

    # Plain-text prompt for the second request.
    text_prompt = "What is machine learning?"

    runner_kwargs = dict(enable_prompt_embeds=True, enforce_eager=True)
    with VllmRunner(model_name, **runner_kwargs) as vllm_runner:
        engine = vllm_runner.model
        # Embeddings request first, then the text request on the same engine.
        embeds_output = engine.generate(dict(prompt_embeds=prompt_embeds))
        text_output = engine.generate(text_prompt)

    # Both input kinds must return exactly one result with non-empty text.
    assert len(embeds_output) == 1
    assert len(text_output) == 1
    embeds_text = embeds_output[0].outputs[0].text
    assert len(embeds_text) > 0
    plain_text = text_output[0].outputs[0].text
    assert len(plain_text) > 0

    print("\n[Prompt Embeds Output]:", embeds_text)
    print("[Text Prompt Output]:", plain_text)
|
||||||
@@ -72,7 +72,7 @@ from vllm.pooling_params import PoolingParams
|
|||||||
from vllm.sampling_params import SamplingType
|
from vllm.sampling_params import SamplingType
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
|
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
|
||||||
from vllm.utils import cdiv
|
from vllm.utils import cdiv, length_from_prompt_token_ids_or_embeds
|
||||||
from vllm.utils.jsontree import json_map_leaves
|
from vllm.utils.jsontree import json_map_leaves
|
||||||
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
|
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
|
||||||
from vllm.v1.attention.backends.utils import (
|
from vllm.v1.attention.backends.utils import (
|
||||||
@@ -346,11 +346,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
self.is_multimodal_model = self.model_config.is_multimodal_model
|
self.is_multimodal_model = self.model_config.is_multimodal_model
|
||||||
self.is_pooling_model = self.model_config.pooler_config is not None
|
self.is_pooling_model = self.model_config.pooler_config is not None
|
||||||
if self.is_multimodal_model:
|
self.enable_prompt_embeds = self.model_config.enable_prompt_embeds
|
||||||
self.inputs_embeds = torch.zeros(
|
if self.is_multimodal_model or self.enable_prompt_embeds:
|
||||||
(self.max_num_tokens, self.model_config.get_hidden_size()),
|
self.inputs_embeds = self._make_buffer(
|
||||||
|
self.max_num_tokens,
|
||||||
|
self.model_config.get_hidden_size(),
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
device=self.device)
|
numpy=False)
|
||||||
|
self.is_token_ids = self._make_buffer(self.max_num_tokens,
|
||||||
|
dtype=torch.bool)
|
||||||
|
|
||||||
# Set up Attention
|
# Set up Attention
|
||||||
self.use_sparse = hasattr(self.vllm_config.model_config.hf_config,
|
self.use_sparse = hasattr(self.vllm_config.model_config.hf_config,
|
||||||
"index_topk")
|
"index_topk")
|
||||||
@@ -721,6 +726,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.requests[req_id] = CachedRequestState(
|
self.requests[req_id] = CachedRequestState(
|
||||||
req_id=req_id,
|
req_id=req_id,
|
||||||
prompt_token_ids=new_req_data.prompt_token_ids,
|
prompt_token_ids=new_req_data.prompt_token_ids,
|
||||||
|
prompt_embeds=new_req_data.prompt_embeds,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
pooling_params=pooling_params,
|
pooling_params=pooling_params,
|
||||||
generator=generator,
|
generator=generator,
|
||||||
@@ -999,7 +1005,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.input_batch.num_computed_tokens_cpu[index]
|
self.input_batch.num_computed_tokens_cpu[index]
|
||||||
num_scheduled_tokens = \
|
num_scheduled_tokens = \
|
||||||
scheduler_output.num_scheduled_tokens[req_id]
|
scheduler_output.num_scheduled_tokens[req_id]
|
||||||
num_prompt_tokens = len(req.prompt_token_ids)
|
num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
|
||||||
|
req.prompt_token_ids, req.prompt_embeds)
|
||||||
|
|
||||||
if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
|
if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
|
||||||
prompt_part_len = max(0,
|
prompt_part_len = max(0,
|
||||||
@@ -1274,6 +1281,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.input_ids[:total_num_scheduled_tokens].copy_(
|
self.input_ids[:total_num_scheduled_tokens].copy_(
|
||||||
self.input_ids_cpu[:total_num_scheduled_tokens],
|
self.input_ids_cpu[:total_num_scheduled_tokens],
|
||||||
non_blocking=True)
|
non_blocking=True)
|
||||||
|
if self.is_multimodal_model or self.enable_prompt_embeds:
|
||||||
|
self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
|
||||||
|
self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Async scheduling case, where some decode requests from the previous
|
# Async scheduling case, where some decode requests from the previous
|
||||||
@@ -1301,6 +1311,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.input_ids[:total_num_scheduled_tokens].copy_(
|
self.input_ids[:total_num_scheduled_tokens].copy_(
|
||||||
self.input_ids_cpu[:total_num_scheduled_tokens],
|
self.input_ids_cpu[:total_num_scheduled_tokens],
|
||||||
non_blocking=True)
|
non_blocking=True)
|
||||||
|
if self.is_multimodal_model or self.enable_prompt_embeds:
|
||||||
|
self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
|
||||||
|
self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
|
||||||
if num_commmon_tokens == 0:
|
if num_commmon_tokens == 0:
|
||||||
# No requests in common with the previous iteration
|
# No requests in common with the previous iteration
|
||||||
# So input_ids_cpu will have all the input ids.
|
# So input_ids_cpu will have all the input ids.
|
||||||
@@ -1314,6 +1327,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.input_batch.prev_sampled_token_ids[:num_commmon_tokens,
|
self.input_batch.prev_sampled_token_ids[:num_commmon_tokens,
|
||||||
0],
|
0],
|
||||||
non_blocking=True)
|
non_blocking=True)
|
||||||
|
self.is_token_ids.gpu[:num_commmon_tokens] = True
|
||||||
return
|
return
|
||||||
# Upload the index tensors asynchronously
|
# Upload the index tensors asynchronously
|
||||||
# so the scatter can be non-blocking.
|
# so the scatter can be non-blocking.
|
||||||
@@ -1481,15 +1495,61 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
# where M is the max_model_len.
|
# where M is the max_model_len.
|
||||||
token_indices = (positions_np +
|
token_indices = (positions_np +
|
||||||
req_indices * self.input_batch.token_ids_cpu.shape[1])
|
req_indices * self.input_batch.token_ids_cpu.shape[1])
|
||||||
|
token_indices_tensor = torch.from_numpy(token_indices)
|
||||||
# Prepare input_ids.
|
# Prepare input_ids.
|
||||||
# NOTE(woosuk): We use torch.index_select instead of np.take here
|
# NOTE(woosuk): We use torch.index_select instead of np.take here
|
||||||
# because torch.index_select is much faster than np.take for large
|
# because torch.index_select is much faster than np.take for large
|
||||||
# tensors.
|
# tensors.
|
||||||
torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
|
torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
|
||||||
0,
|
0,
|
||||||
torch.from_numpy(token_indices),
|
token_indices_tensor,
|
||||||
out=self.input_ids_cpu[:total_num_scheduled_tokens])
|
out=self.input_ids_cpu[:total_num_scheduled_tokens])
|
||||||
|
is_token_ids = self.input_batch.is_token_ids.flatten()
|
||||||
|
torch.index_select(
|
||||||
|
is_token_ids,
|
||||||
|
0,
|
||||||
|
token_indices_tensor,
|
||||||
|
out=self.is_token_ids.cpu[:total_num_scheduled_tokens])
|
||||||
|
|
||||||
|
# Because we did not pre-allocate a massive prompt_embeds CPU tensor on
|
||||||
|
# the InputBatch, we need to fill in the prompt embeds into the expected
|
||||||
|
# spots in the GpuModelRunner's pre-allocated prompt_embeds tensor.
|
||||||
|
if self.input_batch.req_prompt_embeds and (self.is_multimodal_model or
|
||||||
|
self.enable_prompt_embeds):
|
||||||
|
output_idx = 0
|
||||||
|
for req_idx in range(num_reqs):
|
||||||
|
num_sched = num_scheduled_tokens[req_idx]
|
||||||
|
|
||||||
|
# Skip if this request doesn't have embeddings
|
||||||
|
if req_idx not in self.input_batch.req_prompt_embeds:
|
||||||
|
output_idx += num_sched
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Skip if no tokens scheduled
|
||||||
|
if num_sched <= 0:
|
||||||
|
output_idx += num_sched
|
||||||
|
continue
|
||||||
|
|
||||||
|
req_embeds = self.input_batch.req_prompt_embeds[req_idx]
|
||||||
|
start_pos = self.input_batch.num_computed_tokens_cpu[req_idx]
|
||||||
|
|
||||||
|
# Skip if trying to read beyond available embeddings
|
||||||
|
if start_pos >= req_embeds.shape[0]:
|
||||||
|
output_idx += num_sched
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Copy available embeddings
|
||||||
|
end_pos = start_pos + num_sched
|
||||||
|
actual_end = min(end_pos, req_embeds.shape[0])
|
||||||
|
actual_num_sched = actual_end - start_pos
|
||||||
|
|
||||||
|
if actual_num_sched > 0:
|
||||||
|
self.inputs_embeds.cpu[output_idx:output_idx +
|
||||||
|
actual_num_sched].copy_(
|
||||||
|
req_embeds[start_pos:actual_end]
|
||||||
|
)
|
||||||
|
|
||||||
|
output_idx += num_sched
|
||||||
|
|
||||||
self.query_start_loc_np[0] = 0
|
self.query_start_loc_np[0] = 0
|
||||||
self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
|
self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
|
||||||
@@ -1573,9 +1633,34 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# TODO(woosuk): Avoid the copy. Optimize.
|
# TODO(woosuk): Avoid the copy. Optimize.
|
||||||
self.inputs_embeds[:total_num_scheduled_tokens].copy_(
|
self.inputs_embeds.gpu[:total_num_scheduled_tokens].copy_(
|
||||||
inputs_embeds)
|
inputs_embeds)
|
||||||
inputs_embeds = self.inputs_embeds[:num_input_tokens]
|
inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens]
|
||||||
|
input_ids = None
|
||||||
|
elif self.enable_prompt_embeds and get_pp_group().is_first_rank:
|
||||||
|
# Get the input embeddings for the tokens that are not input embeds,
|
||||||
|
# then put them into the appropriate positions.
|
||||||
|
# TODO(qthequartermasterman): Since even when prompt embeds are
|
||||||
|
# enabled, (a) not all requests will use prompt embeds, and (b)
|
||||||
|
# after the initial prompt is processed, the rest of the generated
|
||||||
|
# tokens will be token ids, it is not desirable to have the
|
||||||
|
# embedding layer outside of the acl graph all the time. The v0
|
||||||
|
# engine avoids this by "double compiling" the acl graph, once
|
||||||
|
# with input_ids and again with inputs_embeds, for all num_tokens.
|
||||||
|
# If a batch only has token ids, then including the embedding layer
|
||||||
|
# in the acl graph will be more performant (like in the else case
|
||||||
|
# below).
|
||||||
|
token_ids_idx = self.is_token_ids.gpu[:total_num_scheduled_tokens] \
|
||||||
|
.nonzero(as_tuple=False) \
|
||||||
|
.squeeze(1)
|
||||||
|
# Some tokens ids may need to become embeds
|
||||||
|
if token_ids_idx.numel() > 0:
|
||||||
|
token_ids = self.input_ids[token_ids_idx]
|
||||||
|
tokens_to_embeds = self.model.get_input_embeddings(
|
||||||
|
input_ids=token_ids)
|
||||||
|
self.inputs_embeds.gpu[token_ids_idx] = tokens_to_embeds
|
||||||
|
|
||||||
|
inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens]
|
||||||
input_ids = None
|
input_ids = None
|
||||||
else:
|
else:
|
||||||
# For text-only models, we use token ids as input.
|
# For text-only models, we use token ids as input.
|
||||||
@@ -2404,6 +2489,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
self.input_batch.token_ids_cpu[req_idx,
|
self.input_batch.token_ids_cpu[req_idx,
|
||||||
start_idx:end_idx] = sampled_ids
|
start_idx:end_idx] = sampled_ids
|
||||||
|
self.input_batch.is_token_ids[req_idx,
|
||||||
|
start_idx:end_idx] = True
|
||||||
self.input_batch.num_tokens_no_spec[req_idx] = end_idx
|
self.input_batch.num_tokens_no_spec[req_idx] = end_idx
|
||||||
self.input_batch.num_tokens[req_idx] = end_idx
|
self.input_batch.num_tokens[req_idx] = end_idx
|
||||||
req_id = self.input_batch.req_ids[req_idx]
|
req_id = self.input_batch.req_ids[req_idx]
|
||||||
@@ -2729,7 +2816,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
num_scheduled_tokens):
|
num_scheduled_tokens):
|
||||||
if self.is_multimodal_model:
|
if self.is_multimodal_model:
|
||||||
input_ids = None
|
input_ids = None
|
||||||
inputs_embeds = self.inputs_embeds[:num_tokens]
|
inputs_embeds = self.inputs_embeds.gpu[:num_tokens]
|
||||||
|
elif self.enable_prompt_embeds:
|
||||||
|
input_ids = None
|
||||||
|
inputs_embeds = self.inputs_embeds.gpu[:num_tokens]
|
||||||
else:
|
else:
|
||||||
input_ids = self.input_ids[:num_tokens]
|
input_ids = self.input_ids[:num_tokens]
|
||||||
inputs_embeds = None
|
inputs_embeds = None
|
||||||
@@ -3996,6 +4086,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
# Get metadata for this request.
|
# Get metadata for this request.
|
||||||
request = self.requests[req_id]
|
request = self.requests[req_id]
|
||||||
|
if request.prompt_token_ids is None:
|
||||||
|
# Prompt logprobs is incompatible with prompt embeddings
|
||||||
|
continue
|
||||||
num_prompt_tokens = len(request.prompt_token_ids)
|
num_prompt_tokens = len(request.prompt_token_ids)
|
||||||
prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
|
prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
|
||||||
self.device, non_blocking=True)
|
self.device, non_blocking=True)
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from vllm.multimodal.inputs import (MultiModalFeatureSpec,
|
|||||||
MultiModalKwargsItems, PlaceholderRange)
|
MultiModalKwargsItems, PlaceholderRange)
|
||||||
from vllm.pooling_params import PoolingParams
|
from vllm.pooling_params import PoolingParams
|
||||||
from vllm.sampling_params import SamplingParams, SamplingType
|
from vllm.sampling_params import SamplingParams, SamplingType
|
||||||
|
from vllm.utils import length_from_prompt_token_ids_or_embeds
|
||||||
from vllm.v1.outputs import LogprobsTensors
|
from vllm.v1.outputs import LogprobsTensors
|
||||||
from vllm.v1.pool.metadata import PoolingMetadata
|
from vllm.v1.pool.metadata import PoolingMetadata
|
||||||
from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
|
from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
|
||||||
@@ -51,7 +52,7 @@ else:
|
|||||||
class CachedRequestState:
|
class CachedRequestState:
|
||||||
|
|
||||||
req_id: str
|
req_id: str
|
||||||
prompt_token_ids: list[int]
|
prompt_token_ids: Optional[list[int]]
|
||||||
sampling_params: Optional[SamplingParams]
|
sampling_params: Optional[SamplingParams]
|
||||||
pooling_params: Optional[PoolingParams]
|
pooling_params: Optional[PoolingParams]
|
||||||
generator: Optional[torch.Generator]
|
generator: Optional[torch.Generator]
|
||||||
@@ -70,9 +71,11 @@ class CachedRequestState:
|
|||||||
mm_hashes: Optional[list[PlaceholderRange]] = None
|
mm_hashes: Optional[list[PlaceholderRange]] = None
|
||||||
|
|
||||||
lora_request: Optional[LoRARequest] = None
|
lora_request: Optional[LoRARequest] = None
|
||||||
|
prompt_embeds: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self.num_prompt_tokens = len(self.prompt_token_ids)
|
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
|
||||||
|
self.prompt_token_ids, self.prompt_embeds)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_tokens(self) -> int:
|
def num_tokens(self) -> int:
|
||||||
@@ -91,6 +94,10 @@ class CachedRequestState:
|
|||||||
|
|
||||||
def get_token_id(self, idx: int) -> int:
|
def get_token_id(self, idx: int) -> int:
|
||||||
if idx < self.num_prompt_tokens:
|
if idx < self.num_prompt_tokens:
|
||||||
|
if self.prompt_token_ids is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Tried to access token index {idx}, but that token was "
|
||||||
|
"provided via prompt_embeds, and its ID is unknown.")
|
||||||
return self.prompt_token_ids[idx]
|
return self.prompt_token_ids[idx]
|
||||||
elif idx - self.num_prompt_tokens < len(self.output_token_ids):
|
elif idx - self.num_prompt_tokens < len(self.output_token_ids):
|
||||||
return self.output_token_ids[idx - self.num_prompt_tokens]
|
return self.output_token_ids[idx - self.num_prompt_tokens]
|
||||||
@@ -139,6 +146,14 @@ class InputBatch:
|
|||||||
pin_memory=False,
|
pin_memory=False,
|
||||||
)
|
)
|
||||||
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
|
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
|
||||||
|
self.is_token_ids = torch.zeros((max_num_reqs, max_model_len),
|
||||||
|
device="cpu",
|
||||||
|
dtype=bool,
|
||||||
|
pin_memory=False)
|
||||||
|
# Store prompt embeddings per request to avoid OOM from large upfront
|
||||||
|
# allocation if max_model_len is big.
|
||||||
|
# Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
|
||||||
|
self.req_prompt_embeds: dict[int, torch.Tensor] = {}
|
||||||
self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
|
self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
|
||||||
self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
|
self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
|
||||||
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
|
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
|
||||||
@@ -345,15 +360,23 @@ class InputBatch:
|
|||||||
self.req_id_to_index[req_id] = req_index
|
self.req_id_to_index[req_id] = req_index
|
||||||
|
|
||||||
# Copy the prompt token ids and output token ids.
|
# Copy the prompt token ids and output token ids.
|
||||||
num_prompt_tokens = len(request.prompt_token_ids)
|
num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
|
||||||
|
request.prompt_token_ids, request.prompt_embeds)
|
||||||
self.num_prompt_tokens[req_index] = num_prompt_tokens
|
self.num_prompt_tokens[req_index] = num_prompt_tokens
|
||||||
self.token_ids_cpu[
|
|
||||||
req_index, :num_prompt_tokens] = request.prompt_token_ids
|
|
||||||
start_idx = num_prompt_tokens
|
start_idx = num_prompt_tokens
|
||||||
end_idx = start_idx + len(request.output_token_ids)
|
end_idx = start_idx + len(request.output_token_ids)
|
||||||
|
if request.prompt_token_ids is not None:
|
||||||
|
self.token_ids_cpu[
|
||||||
|
req_index, :num_prompt_tokens] = request.prompt_token_ids
|
||||||
|
self.is_token_ids[req_index, :num_prompt_tokens] = True
|
||||||
|
else:
|
||||||
|
self.is_token_ids[req_index, :num_prompt_tokens] = False
|
||||||
|
if request.prompt_embeds is not None:
|
||||||
|
self.req_prompt_embeds[req_index] = request.prompt_embeds
|
||||||
self.token_ids_cpu[req_index,
|
self.token_ids_cpu[req_index,
|
||||||
start_idx:end_idx] = request.output_token_ids
|
start_idx:end_idx] = request.output_token_ids
|
||||||
# Number of token ids in token_ids_cpu.
|
self.is_token_ids[req_index, start_idx:end_idx] = True
|
||||||
|
# Number of token ids in prompt (token_ids_cpu or prompt_embeds).
|
||||||
# NOTE(woosuk): This may include spec decode tokens.
|
# NOTE(woosuk): This may include spec decode tokens.
|
||||||
self.num_tokens[req_index] = request.num_tokens
|
self.num_tokens[req_index] = request.num_tokens
|
||||||
# Number of tokens without spec decode tokens.
|
# Number of tokens without spec decode tokens.
|
||||||
@@ -553,6 +576,20 @@ class InputBatch:
|
|||||||
self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
|
self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
|
||||||
self.token_ids_cpu[i2, ...] = tmp
|
self.token_ids_cpu[i2, ...] = tmp
|
||||||
|
|
||||||
|
self.is_token_ids[[i1, i2], ...] = self.is_token_ids[[i2, i1], ...]
|
||||||
|
|
||||||
|
# Swap prompt embeddings if they exist
|
||||||
|
embeds_i1 = self.req_prompt_embeds.get(i1)
|
||||||
|
embeds_i2 = self.req_prompt_embeds.get(i2)
|
||||||
|
if embeds_i1 is not None:
|
||||||
|
self.req_prompt_embeds[i2] = embeds_i1
|
||||||
|
else:
|
||||||
|
self.req_prompt_embeds.pop(i2, None)
|
||||||
|
if embeds_i2 is not None:
|
||||||
|
self.req_prompt_embeds[i1] = embeds_i2
|
||||||
|
else:
|
||||||
|
self.req_prompt_embeds.pop(i1, None)
|
||||||
|
|
||||||
swap_dict_values(self.generators, i1, i2)
|
swap_dict_values(self.generators, i1, i2)
|
||||||
swap_dict_values(self.bad_words_token_ids, i1, i2)
|
swap_dict_values(self.bad_words_token_ids, i1, i2)
|
||||||
|
|
||||||
@@ -631,6 +668,11 @@ class InputBatch:
|
|||||||
num_tokens = self.num_tokens[last_req_index]
|
num_tokens = self.num_tokens[last_req_index]
|
||||||
self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
|
self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
|
||||||
last_req_index, :num_tokens]
|
last_req_index, :num_tokens]
|
||||||
|
self.is_token_ids[empty_index, :num_tokens] = self.is_token_ids[
|
||||||
|
last_req_index, :num_tokens]
|
||||||
|
if last_req_index in self.req_prompt_embeds:
|
||||||
|
self.req_prompt_embeds[
|
||||||
|
empty_index] = self.req_prompt_embeds.pop(last_req_index)
|
||||||
self.num_tokens[empty_index] = num_tokens
|
self.num_tokens[empty_index] = num_tokens
|
||||||
self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
|
self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
|
||||||
last_req_index]
|
last_req_index]
|
||||||
|
|||||||
Reference in New Issue
Block a user