[ModelRunner] Support embedding inputs (#916)

### What this PR does / why we need it?
- Adds support for passing `prompt_embeds` to `LLM.generate`, as in the snippets below (a sketch of producing such embeddings follows this list):
```python
llm.generate({"prompt_embeds": input_embeds}, sampling_params)
```
or
```python
llm.generate(
    [{"prompt_embeds": input_embeds} for input_embeds in inputs_embeds],
    sampling_params,
)
```
- Adds a `prompt_embeds` example to the example scripts
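
For context, a minimal sketch of how a caller might produce `prompt_embeds`. The model name and the use of the HF embedding layer here are illustrative assumptions, not part of this PR, and depending on the vLLM version an `enable_prompt_embeds`-style engine flag may also be required:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import LLM, SamplingParams

# Hypothetical model choice; any causal LM whose hidden size matches the
# vLLM model can supply the embeddings.
MODEL = "Qwen/Qwen2.5-0.5B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(MODEL)
hf_model = AutoModelForCausalLM.from_pretrained(MODEL)

# NOTE: some vLLM versions require prompt-embeds support to be enabled
# on the engine before embedding inputs are accepted.
llm = LLM(model=MODEL)
sampling_params = SamplingParams(temperature=0.0, max_tokens=16)

prompt = "Please tell me about the capital of France."
token_ids = tokenizer(prompt, return_tensors="pt").input_ids
with torch.no_grad():
    # (1, seq_len, hidden_size) -> (seq_len, hidden_size)
    input_embeds = hf_model.get_input_embeddings()(token_ids).squeeze(0)

outputs = llm.generate({"prompt_embeds": input_embeds}, sampling_params)
print(outputs[0].outputs[0].text)
```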

### How was this patch tested?
CI passed with the newly added and existing tests. I also tested with the example script in this PR, and the output looks good:
```bash

[Single Inference Output]
------------------------------
The capital of France is Paris. Paris is the largest city in France and is
------------------------------
Adding requests: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3966.87it/s]
Processed prompts: 100%|█████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00,  3.99it/s, est. speed input: 177.08 toks/s, output: 63.91 toks/s]

[Batch Inference Outputs]
------------------------------
Q1: Please tell me about the capital of France.
A1: The capital of France is Paris. It is located in the northern part of the

Q2: When is the day longest during the year?
A2: The day is longest during the year at the summer solstice. This typically occurs

Q3: Where is bigger, the moon or the sun?
A3: The sun is significantly bigger than the moon. 

The sun has a diameter of

------------------------------
```

---------

Signed-off-by: wangli <wangli858794774@gmail.com>
Author: Li Wang
Committed by: GitHub
Date: 2025-06-06 20:21:13 +08:00
Commit: 11a7df4270 (parent: c7f1c59911)
5 changed files with 600 additions and 39 deletions


@@ -33,7 +33,7 @@ from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.attention.backends.utils import CommonAttentionState
from vllm.config import VllmConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.distributed import get_dp_group, get_pp_group
from vllm.distributed import broadcast_tensor_dict, get_dp_group, get_pp_group
from vllm.distributed.kv_transfer import get_kv_transfer_group
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry
@@ -43,7 +43,8 @@ from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata, SamplingMetadataCache
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.sampler import (Sampler, SamplerOutput,
                                                get_sampler)
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models import supports_lora, supports_multimodal
@@ -84,6 +85,7 @@ class ModelInputForNPU(ModelRunnerInputBase):
additional fields.
"""
input_tokens: Optional[torch.Tensor] = None
inputs_embeds: Optional[torch.Tensor] = None
input_positions: Optional[torch.Tensor] = None
token_types: Optional[torch.Tensor] = None
seq_lens: Optional[List[int]] = None
@@ -103,6 +105,7 @@ class ModelInputForNPU(ModelRunnerInputBase):
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
"input_tokens": self.input_tokens,
"inputs_embeds": self.inputs_embeds,
"input_positions": self.input_positions,
"lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping,
@@ -151,6 +154,7 @@ class ModelInputForNPUWithSamplingMetadata(ModelInputForNPU):
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
"input_tokens": self.input_tokens,
"inputs_embeds": self.inputs_embeds,
"input_positions": self.input_positions,
"lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping,
@@ -188,6 +192,7 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
def simple_reinit(self):
self.input_tokens[0].clear() # type: ignore
self.inputs_embeds = None # type: ignore
self.input_positions[0].clear() # type: ignore
self.token_types[0].clear() # type: ignore
self.mrope_input_positions = None # type: ignore
@@ -213,6 +218,7 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
# Input tokens and positions.
input_tokens: Optional[List[List[int]]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
input_positions: Optional[List[List[int]]] = None,
token_types: Optional[List[List[int]]] = None,
mrope_input_positions: Optional[List[List[List[int]]]] = None,
@@ -268,6 +274,7 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
else:
for seq_id in range(len(self.seq_ids)):
self.input_tokens[seq_id].clear()
self.inputs_embeds = inputs_embeds
if input_positions:
self.input_positions = input_positions
@@ -329,6 +336,7 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
else:
self.input_tokens = input_tokens or []
self.inputs_embeds = inputs_embeds
self.input_positions = input_positions or []
self.token_types = token_types or []
self.mrope_input_positions = mrope_input_positions or None
@@ -368,6 +376,26 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
self.lora_index_mapping = []
self.lora_prompt_mapping = []
def __repr__(self) -> str:
return (f"InterDataForSeqGroup("
f"request_id={self.request_id}, "
f"seq_ids={self.seq_ids}, "
f"is_prompt={self.is_prompt}, "
f"block_tables={self.block_tables}, "
f"computed_block_nums={self.computed_block_nums}, "
f"n_seqs={self.n_seqs}, "
f"input_tokens={self.input_tokens}, "
f"inputs_embeds.shape="
f"{getattr(self.inputs_embeds, 'shape', None)}, "
f"input_positions={self.input_positions}, "
f"token_types={self.token_types}, "
f"mrope_input_positions={self.mrope_input_positions}, "
f"seq_lens={self.seq_lens}, "
f"orig_seq_lens={self.orig_seq_lens}, "
f"query_lens={self.query_lens}, "
f"context_lens={self.context_lens}, "
f"multi_modal_kwargs={self.multi_modal_kwargs}")
def __init__(self,
runner,
finished_requests_ids: Optional[List[str]] = None):
@@ -492,11 +520,30 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
create on-device tensors.
"""
# Combine and flatten intermediate data.
input_tokens = [
flatten_2d_lists(inter_data.input_tokens)
for inter_data in self.inter_data_list
]
if not input_tokens:
input_tokens = list[int]()
inputs_embeds_list = list[torch.Tensor]()
token_types = list[int]()
for inter_data in self.inter_data_list:
for cur_input_tokens in inter_data.input_tokens:
input_tokens.extend(cur_input_tokens)
for cur_token_types in inter_data.token_types:
token_types.extend(cur_token_types)
if inter_data.inputs_embeds is not None:
inputs_embeds_list.append(
inter_data.inputs_embeds.to(
dtype=self.runner.model_config.dtype,
device=self.runner.device))
inputs_embeds: Optional[torch.Tensor]
if len(inputs_embeds_list) == 0:
inputs_embeds = None
else:
inputs_embeds = torch.cat(inputs_embeds_list, dim=0).to(
dtype=self.runner.model_config.dtype,
device=self.runner.device)
assert len(inputs_embeds) == len(input_tokens)
if not input_tokens and inputs_embeds is None:
# This may happen when all prefill requests hit
# prefix caching and there is no decode request.
return self.model_input_cls()
@@ -548,10 +595,6 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
else:
graph_pad_size = -1
#print(f"before tensor input_tokens: {input_tokens}")
#print(f"before tensor input_positions: {input_positions}")
#print(f"before list seq_lens: {seq_lens}")
input_tokens = flatten_2d_lists(input_tokens)
if input_positions:
input_positions = flatten_2d_lists(input_positions)
if graph_pad_size != -1 and not is_prompt:
@@ -563,6 +606,10 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
input_tokens_tensor = torch.tensor(input_tokens,
dtype=torch.long,
device=self.runner.device)
token_types_tensor = torch.tensor(token_types,
dtype=torch.long,
device=self.runner.device) \
if token_types else None
if mrope_input_positions is not None:
input_positions_tensor = torch.tensor(mrope_input_positions,
dtype=torch.long,
@@ -613,6 +660,8 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
return self.model_input_cls(
input_tokens=input_tokens_tensor,
inputs_embeds=inputs_embeds,
token_types=token_types_tensor,
input_positions=input_positions_tensor,
attn_metadata=attn_metadata,
seq_lens=seq_lens,
@@ -645,13 +694,23 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
context_len = seq_data.get_num_computed_tokens()
# Compute tokens.
tokens = seq_data.get_token_ids()[context_len:seq_len]
# FIXME: this is for version compatibility; remove this once vllm v0.8.5 is no longer supported.
if not hasattr(seq_data,
"prompt_embeds") or seq_data.prompt_embeds is None:
tokens = seq_data.get_token_ids()[context_len:seq_len]
prompt_embeds = None
else:
tokens = [0] * (seq_len - context_len)
prompt_embeds = seq_data.get_token_embeddings(
)[context_len:seq_len]
token_types = seq_group_metadata.token_type_ids
inter_data.seq_lens[seq_idx] = seq_len
inter_data.orig_seq_lens[seq_idx] = seq_len
inter_data.context_lens[seq_idx] = context_len
inter_data.input_tokens[seq_idx].extend(tokens)
inter_data.inputs_embeds = prompt_embeds
inter_data.input_positions[seq_idx].extend(range(context_len, seq_len))
inter_data.token_types[seq_idx].extend(
token_types if token_types else [])
@@ -1379,6 +1438,7 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
model_kwargs["attn_metadata"] = model_input.attn_metadata
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
inputs_embeds=model_input.inputs_embeds,
positions=model_input.input_positions,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
@@ -1422,34 +1482,61 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
hidden_or_intermediate_states,
)
if self.is_driver_worker:
if model_input.async_callback is not None:
model_input.async_callback()
# Sample the next token.
assert isinstance(self.sampler, Sampler)
orig_include_gpu_probs = self.sampler.include_gpu_probs_tensor
if model_input.inputs_embeds is not None:
self.sampler.include_gpu_probs_tensor = True
output: SamplerOutput = self.sampler(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time
and output is not None):
model_forward_end.synchronize()
model_forward_time = model_forward_start.elapsed_time(
model_forward_end)
orig_model_forward_time = 0.0
if intermediate_tensors is not None:
orig_model_forward_time = intermediate_tensors.tensors.get(
"model_forward_time", torch.tensor(0.0)).item()
# If there are multiple workers, we are still tracking the
# latency from the start time of the driver worker to the end
# time of the driver worker. The model forward time will then
# end up covering the communication time as well.
output.model_forward_time = (orig_model_forward_time +
model_forward_time)
if model_input.inputs_embeds is not None:
if self.is_driver_worker:
sampled = broadcast_tensor_dict(
{"token_ids": output.sampled_token_ids})
else:
sampled = broadcast_tensor_dict()
if sampled["token_ids"] is not None:
sampled_token_embeds = self.model.get_input_embeddings(
sampled["token_ids"].squeeze(1))
if self.is_driver_worker:
self.sampler.include_gpu_probs_tensor = \
orig_include_gpu_probs
output.sampled_token_embeds = sampled_token_embeds
for token_embed, sequence_group_output in zip(
output.sampled_token_embeds, output.outputs):
assert len(sequence_group_output.samples) == 1
sequence_group_output.samples[
0].output_embed = token_embed
if not self.is_driver_worker:
return []
if model_input.async_callback is not None:
model_input.async_callback()
# Sample the next token.
output = self.sampler(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time
and output is not None):
model_forward_end.synchronize()
model_forward_time = model_forward_start.elapsed_time(
model_forward_end)
orig_model_forward_time = 0.0
if intermediate_tensors is not None:
orig_model_forward_time = intermediate_tensors.tensors.get(
"model_forward_time", torch.tensor(0.0)).item()
# If there are multiple workers, we are still tracking the latency
# from the start time of the driver worker to the end time of the
# driver worker. The model forward time will then end up covering
# the communication time as well.
output.model_forward_time = (orig_model_forward_time +
model_forward_time)
if self.return_hidden_states:
# we only need to pass hidden states of most recent token
assert model_input.sampling_metadata is not None
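
For readers tracing the large hunk above: when `prompt_embeds` are in use, the next decode step needs the embeddings of the freshly sampled tokens on every rank, so the driver samples, broadcasts the sampled token ids, and each rank embeds them locally. A condensed sketch of that flow (the helper boundary is an assumption; the calls mirror the diff):

```python
from vllm.distributed import broadcast_tensor_dict


def _sample_and_share_embeds(self, logits, model_input):
    """Assumed helper: driver-side sampling plus the all-rank broadcast."""
    output = None
    if self.is_driver_worker:
        # Only the driver samples; it then shares the sampled token ids.
        output = self.sampler(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )
        sampled = broadcast_tensor_dict(
            {"token_ids": output.sampled_token_ids})
    else:
        # Non-driver ranks block here until the driver broadcasts.
        sampled = broadcast_tensor_dict()

    if sampled["token_ids"] is not None:
        # Every rank embeds the sampled tokens so the next decode step can
        # run from inputs_embeds instead of token ids.
        token_embeds = self.model.get_input_embeddings(
            sampled["token_ids"].squeeze(1))
        if self.is_driver_worker:
            output.sampled_token_embeds = token_embeds
            for embed, seq_group_output in zip(token_embeds, output.outputs):
                assert len(seq_group_output.samples) == 1
                seq_group_output.samples[0].output_embed = embed
    return output
```

The diff toggles `include_gpu_probs_tensor` on the sampler beforehand so that `sampled_token_ids` comes back as an on-device tensor suitable for broadcasting.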