[feature] Prompt Embeddings Support for v1 Engine (#3026)

### What this PR does / why we need it?
This PR is based on
[19746](https://github.com/vllm-project/vllm/issues/19746) and adds support
for prompt embeddings in the v1 engine on NPU.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

```bash
python examples/prompt_embed_inference.py
```


- vLLM version: v0.11.0
- vLLM main:
https://github.com/vllm-project/vllm/commit/releases/v0.11.1

---------

Signed-off-by: jesse <szxfml@gmail.com>
This commit is contained in:
Song Zhixin
2025-10-30 17:15:57 +08:00
committed by GitHub
parent f6149f3894
commit 216fc0e8e4
5 changed files with 447 additions and 17 deletions

View File

@@ -72,7 +72,7 @@ from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.utils import cdiv
from vllm.utils import cdiv, length_from_prompt_token_ids_or_embeds
from vllm.utils.jsontree import json_map_leaves
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import (
@@ -346,11 +346,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.is_multimodal_model = self.model_config.is_multimodal_model
self.is_pooling_model = self.model_config.pooler_config is not None
if self.is_multimodal_model:
self.inputs_embeds = torch.zeros(
(self.max_num_tokens, self.model_config.get_hidden_size()),
self.enable_prompt_embeds = self.model_config.enable_prompt_embeds
if self.is_multimodal_model or self.enable_prompt_embeds:
self.inputs_embeds = self._make_buffer(
self.max_num_tokens,
self.model_config.get_hidden_size(),
dtype=self.dtype,
device=self.device)
numpy=False)
self.is_token_ids = self._make_buffer(self.max_num_tokens,
dtype=torch.bool)
# Set up Attention
self.use_sparse = hasattr(self.vllm_config.model_config.hf_config,
"index_topk")
@@ -721,6 +726,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.requests[req_id] = CachedRequestState(
req_id=req_id,
prompt_token_ids=new_req_data.prompt_token_ids,
prompt_embeds=new_req_data.prompt_embeds,
sampling_params=sampling_params,
pooling_params=pooling_params,
generator=generator,
@@ -999,7 +1005,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.input_batch.num_computed_tokens_cpu[index]
num_scheduled_tokens = \
scheduler_output.num_scheduled_tokens[req_id]
num_prompt_tokens = len(req.prompt_token_ids)
num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
req.prompt_token_ids, req.prompt_embeds)
if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
prompt_part_len = max(0,
@@ -1274,6 +1281,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.input_ids[:total_num_scheduled_tokens].copy_(
self.input_ids_cpu[:total_num_scheduled_tokens],
non_blocking=True)
if self.is_multimodal_model or self.enable_prompt_embeds:
self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
return
# Async scheduling case, where some decode requests from the previous
@@ -1301,6 +1311,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.input_ids[:total_num_scheduled_tokens].copy_(
self.input_ids_cpu[:total_num_scheduled_tokens],
non_blocking=True)
if self.is_multimodal_model or self.enable_prompt_embeds:
self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
if num_commmon_tokens == 0:
# No requests in common with the previous iteration
# So input_ids_cpu will have all the input ids.
@@ -1314,6 +1327,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.input_batch.prev_sampled_token_ids[:num_commmon_tokens,
0],
non_blocking=True)
self.is_token_ids.gpu[:num_commmon_tokens] = True
return
# Upload the index tensors asynchronously
# so the scatter can be non-blocking.
@@ -1481,15 +1495,61 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# where M is the max_model_len.
token_indices = (positions_np +
req_indices * self.input_batch.token_ids_cpu.shape[1])
token_indices_tensor = torch.from_numpy(token_indices)
# Prepare input_ids.
# NOTE(woosuk): We use torch.index_select instead of np.take here
# because torch.index_select is much faster than np.take for large
# tensors.
torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
0,
torch.from_numpy(token_indices),
token_indices_tensor,
out=self.input_ids_cpu[:total_num_scheduled_tokens])
is_token_ids = self.input_batch.is_token_ids.flatten()
torch.index_select(
is_token_ids,
0,
token_indices_tensor,
out=self.is_token_ids.cpu[:total_num_scheduled_tokens])
# Because we did not pre-allocate a massive prompt_embeds CPU tensor on
# the InputBatch, we need to fill in the prompt embeds into the expected
# spots in the GpuModelRunner's pre-allocated prompt_embeds tensor.
if self.input_batch.req_prompt_embeds and (self.is_multimodal_model or
self.enable_prompt_embeds):
output_idx = 0
for req_idx in range(num_reqs):
num_sched = num_scheduled_tokens[req_idx]
# Skip if this request doesn't have embeddings
if req_idx not in self.input_batch.req_prompt_embeds:
output_idx += num_sched
continue
# Skip if no tokens scheduled
if num_sched <= 0:
output_idx += num_sched
continue
req_embeds = self.input_batch.req_prompt_embeds[req_idx]
start_pos = self.input_batch.num_computed_tokens_cpu[req_idx]
# Skip if trying to read beyond available embeddings
if start_pos >= req_embeds.shape[0]:
output_idx += num_sched
continue
# Copy available embeddings
end_pos = start_pos + num_sched
actual_end = min(end_pos, req_embeds.shape[0])
actual_num_sched = actual_end - start_pos
if actual_num_sched > 0:
self.inputs_embeds.cpu[output_idx:output_idx +
actual_num_sched].copy_(
req_embeds[start_pos:actual_end]
)
output_idx += num_sched
self.query_start_loc_np[0] = 0
self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
@@ -1573,9 +1633,34 @@ class NPUModelRunner(LoRAModelRunnerMixin):
)
# TODO(woosuk): Avoid the copy. Optimize.
self.inputs_embeds[:total_num_scheduled_tokens].copy_(
self.inputs_embeds.gpu[:total_num_scheduled_tokens].copy_(
inputs_embeds)
inputs_embeds = self.inputs_embeds[:num_input_tokens]
inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens]
input_ids = None
elif self.enable_prompt_embeds and get_pp_group().is_first_rank:
# Get the input embeddings for the tokens that are not input embeds,
# then put them into the appropriate positions.
# TODO(qthequartermasterman): Since even when prompt embeds are
# enabled, (a) not all requests will use prompt embeds, and (b)
# after the initial prompt is processed, the rest of the generated
# tokens will be token ids, it is not desirable to have the
# embedding layer outside of the acl graph all the time. The v0
# engine avoids this by "double compiling" the acl graph, once
# with input_ids and again with inputs_embeds, for all num_tokens.
# If a batch only has token ids, then including the embedding layer
# in the acl graph will be more performant (like in the else case
# below).
token_ids_idx = self.is_token_ids.gpu[:total_num_scheduled_tokens] \
.nonzero(as_tuple=False) \
.squeeze(1)
# Some tokens ids may need to become embeds
if token_ids_idx.numel() > 0:
token_ids = self.input_ids[token_ids_idx]
tokens_to_embeds = self.model.get_input_embeddings(
input_ids=token_ids)
self.inputs_embeds.gpu[token_ids_idx] = tokens_to_embeds
inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens]
input_ids = None
else:
# For text-only models, we use token ids as input.
@@ -2404,6 +2489,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.input_batch.token_ids_cpu[req_idx,
start_idx:end_idx] = sampled_ids
self.input_batch.is_token_ids[req_idx,
start_idx:end_idx] = True
self.input_batch.num_tokens_no_spec[req_idx] = end_idx
self.input_batch.num_tokens[req_idx] = end_idx
req_id = self.input_batch.req_ids[req_idx]
@@ -2729,7 +2816,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_scheduled_tokens):
if self.is_multimodal_model:
input_ids = None
inputs_embeds = self.inputs_embeds[:num_tokens]
inputs_embeds = self.inputs_embeds.gpu[:num_tokens]
elif self.enable_prompt_embeds:
input_ids = None
inputs_embeds = self.inputs_embeds.gpu[:num_tokens]
else:
input_ids = self.input_ids[:num_tokens]
inputs_embeds = None
@@ -3996,6 +4086,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# Get metadata for this request.
request = self.requests[req_id]
if request.prompt_token_ids is None:
# Prompt logprobs is incompatible with prompt embeddings
continue
num_prompt_tokens = len(request.prompt_token_ids)
prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
self.device, non_blocking=True)

View File

@@ -29,6 +29,7 @@ from vllm.multimodal.inputs import (MultiModalFeatureSpec,
MultiModalKwargsItems, PlaceholderRange)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
@@ -51,7 +52,7 @@ else:
class CachedRequestState:
req_id: str
prompt_token_ids: list[int]
prompt_token_ids: Optional[list[int]]
sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams]
generator: Optional[torch.Generator]
@@ -70,9 +71,11 @@ class CachedRequestState:
mm_hashes: Optional[list[PlaceholderRange]] = None
lora_request: Optional[LoRARequest] = None
prompt_embeds: Optional[torch.Tensor] = None
def __post_init__(self):
self.num_prompt_tokens = len(self.prompt_token_ids)
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
self.prompt_token_ids, self.prompt_embeds)
@property
def num_tokens(self) -> int:
@@ -91,6 +94,10 @@ class CachedRequestState:
def get_token_id(self, idx: int) -> int:
if idx < self.num_prompt_tokens:
if self.prompt_token_ids is None:
raise ValueError(
f"Tried to access token index {idx}, but that token was "
"provided via prompt_embeds, and its ID is unknown.")
return self.prompt_token_ids[idx]
elif idx - self.num_prompt_tokens < len(self.output_token_ids):
return self.output_token_ids[idx - self.num_prompt_tokens]
@@ -139,6 +146,14 @@ class InputBatch:
pin_memory=False,
)
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
self.is_token_ids = torch.zeros((max_num_reqs, max_model_len),
device="cpu",
dtype=bool,
pin_memory=False)
# Store prompt embeddings per request to avoid OOM from large upfront
# allocation if max_model_len is big.
# Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
self.req_prompt_embeds: dict[int, torch.Tensor] = {}
self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
@@ -345,15 +360,23 @@ class InputBatch:
self.req_id_to_index[req_id] = req_index
# Copy the prompt token ids and output token ids.
num_prompt_tokens = len(request.prompt_token_ids)
num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
request.prompt_token_ids, request.prompt_embeds)
self.num_prompt_tokens[req_index] = num_prompt_tokens
self.token_ids_cpu[
req_index, :num_prompt_tokens] = request.prompt_token_ids
start_idx = num_prompt_tokens
end_idx = start_idx + len(request.output_token_ids)
if request.prompt_token_ids is not None:
self.token_ids_cpu[
req_index, :num_prompt_tokens] = request.prompt_token_ids
self.is_token_ids[req_index, :num_prompt_tokens] = True
else:
self.is_token_ids[req_index, :num_prompt_tokens] = False
if request.prompt_embeds is not None:
self.req_prompt_embeds[req_index] = request.prompt_embeds
self.token_ids_cpu[req_index,
start_idx:end_idx] = request.output_token_ids
# Number of token ids in token_ids_cpu.
self.is_token_ids[req_index, start_idx:end_idx] = True
# Number of token ids in prompt (token_ids_cpu or prompt_embeds).
# NOTE(woosuk): This may include spec decode tokens.
self.num_tokens[req_index] = request.num_tokens
# Number of tokens without spec decode tokens.
@@ -553,6 +576,20 @@ class InputBatch:
self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
self.token_ids_cpu[i2, ...] = tmp
self.is_token_ids[[i1, i2], ...] = self.is_token_ids[[i2, i1], ...]
# Swap prompt embeddings if they exist
embeds_i1 = self.req_prompt_embeds.get(i1)
embeds_i2 = self.req_prompt_embeds.get(i2)
if embeds_i1 is not None:
self.req_prompt_embeds[i2] = embeds_i1
else:
self.req_prompt_embeds.pop(i2, None)
if embeds_i2 is not None:
self.req_prompt_embeds[i1] = embeds_i2
else:
self.req_prompt_embeds.pop(i1, None)
swap_dict_values(self.generators, i1, i2)
swap_dict_values(self.bad_words_token_ids, i1, i2)
@@ -631,6 +668,11 @@ class InputBatch:
num_tokens = self.num_tokens[last_req_index]
self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
last_req_index, :num_tokens]
self.is_token_ids[empty_index, :num_tokens] = self.is_token_ids[
last_req_index, :num_tokens]
if last_req_index in self.req_prompt_embeds:
self.req_prompt_embeds[
empty_index] = self.req_prompt_embeds.pop(last_req_index)
self.num_tokens[empty_index] = num_tokens
self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
last_req_index]