[ModelRunner] Support embedding inputs (#916)
### What this PR does / why we need it?
- Adds support for passing prompt_embeds to LLM.generate as
```bash
llm.generate({"prompt_embeds": input_embeds}, sampling_params)
```
or
```bash
llm.generate(
[{"prompt_embeds": input_embeds} for input_embeds in inputs_embeds], sampling_params
)
```
- Add `prompt_embeds` to examples
### How was this patch tested?
CI passed with new added/existing test.
and I have test with the example script in this pr, and the output seems
looks good:
```bash
[Single Inference Output]
------------------------------
The capital of France is Paris. Paris is the largest city in France and is
------------------------------
Adding requests: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3966.87it/s]
Processed prompts: 100%|█████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.99it/s, est. speed input: 177.08 toks/s, output: 63.91 toks/s]
[Batch Inference Outputs]
------------------------------
Q1: Please tell me about the capital of France.
A1: The capital of France is Paris. It is located in the northern part of the
Q2: When is the day longest during the year?
A2: The day is longest during the year at the summer solstice. This typically occurs
Q3: Where is bigger, the moon or the sun?
A3: The sun is significantly bigger than the moon.
The sun has a diameter of
------------------------------
```
---------
Signed-off-by: wangli <wangli858794774@gmail.com>
This commit is contained in:
@@ -33,7 +33,7 @@ from vllm.attention import AttentionMetadata, get_attn_backend
|
||||
from vllm.attention.backends.utils import CommonAttentionState
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.core.scheduler import SchedulerOutputs
|
||||
from vllm.distributed import get_dp_group, get_pp_group
|
||||
from vllm.distributed import broadcast_tensor_dict, get_dp_group, get_pp_group
|
||||
from vllm.distributed.kv_transfer import get_kv_transfer_group
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.inputs import INPUT_REGISTRY, InputRegistry
|
||||
@@ -43,7 +43,8 @@ from vllm.lora.request import LoRARequest
|
||||
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
|
||||
from vllm.model_executor import SamplingMetadata, SamplingMetadataCache
|
||||
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.layers.sampler import (Sampler, SamplerOutput,
|
||||
get_sampler)
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
||||
from vllm.model_executor.models import supports_lora, supports_multimodal
|
||||
@@ -84,6 +85,7 @@ class ModelInputForNPU(ModelRunnerInputBase):
|
||||
additional fields.
|
||||
"""
|
||||
input_tokens: Optional[torch.Tensor] = None
|
||||
inputs_embeds: Optional[torch.Tensor] = None
|
||||
input_positions: Optional[torch.Tensor] = None
|
||||
token_types: Optional[torch.Tensor] = None
|
||||
seq_lens: Optional[List[int]] = None
|
||||
@@ -103,6 +105,7 @@ class ModelInputForNPU(ModelRunnerInputBase):
|
||||
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
|
||||
tensor_dict = {
|
||||
"input_tokens": self.input_tokens,
|
||||
"inputs_embeds": self.inputs_embeds,
|
||||
"input_positions": self.input_positions,
|
||||
"lora_requests": self.lora_requests,
|
||||
"lora_mapping": self.lora_mapping,
|
||||
@@ -151,6 +154,7 @@ class ModelInputForNPUWithSamplingMetadata(ModelInputForNPU):
|
||||
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
|
||||
tensor_dict = {
|
||||
"input_tokens": self.input_tokens,
|
||||
"inputs_embeds": self.inputs_embeds,
|
||||
"input_positions": self.input_positions,
|
||||
"lora_requests": self.lora_requests,
|
||||
"lora_mapping": self.lora_mapping,
|
||||
@@ -188,6 +192,7 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
|
||||
|
||||
def simple_reinit(self):
|
||||
self.input_tokens[0].clear() # type: ignore
|
||||
self.inputs_embeds = None # type: ignore
|
||||
self.input_positions[0].clear() # type: ignore
|
||||
self.token_types[0].clear() # type: ignore
|
||||
self.mrope_input_positions = None # type: ignore
|
||||
@@ -213,6 +218,7 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
|
||||
|
||||
# Input tokens and positions.
|
||||
input_tokens: Optional[List[List[int]]] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
input_positions: Optional[List[List[int]]] = None,
|
||||
token_types: Optional[List[List[int]]] = None,
|
||||
mrope_input_positions: Optional[List[List[List[int]]]] = None,
|
||||
@@ -268,6 +274,7 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
|
||||
else:
|
||||
for seq_id in range(len(self.seq_ids)):
|
||||
self.input_tokens[seq_id].clear()
|
||||
self.inputs_embeds = inputs_embeds
|
||||
|
||||
if input_positions:
|
||||
self.input_positions = input_positions
|
||||
@@ -329,6 +336,7 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
|
||||
|
||||
else:
|
||||
self.input_tokens = input_tokens or []
|
||||
self.inputs_embeds = inputs_embeds
|
||||
self.input_positions = input_positions or []
|
||||
self.token_types = token_types or []
|
||||
self.mrope_input_positions = mrope_input_positions or None
|
||||
@@ -368,6 +376,26 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
|
||||
self.lora_index_mapping = []
|
||||
self.lora_prompt_mapping = []
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"InterDataForSeqGroup("
|
||||
f"request_id={self.request_id}, "
|
||||
f"seq_ids={self.seq_ids}, "
|
||||
f"is_prompt={self.is_prompt}, "
|
||||
f"block_tables={self.block_tables}, "
|
||||
f"computed_block_nums={self.computed_block_nums}, "
|
||||
f"n_seqs={self.n_seqs}, "
|
||||
f"input_tokens={self.input_tokens}, "
|
||||
f"inputs_embeds.shape="
|
||||
f"{getattr(self.inputs_embeds, 'shape', None)}, "
|
||||
f"input_positions={self.input_positions}, "
|
||||
f"token_types={self.token_types}, "
|
||||
f"mrope_input_positions={self.mrope_input_positions}, "
|
||||
f"seq_lens={self.seq_lens}, "
|
||||
f"orig_seq_lens={self.orig_seq_lens}, "
|
||||
f"query_lens={self.query_lens}, "
|
||||
f"context_lens={self.context_lens}, "
|
||||
f"multi_modal_kwargs={self.multi_modal_kwargs}")
|
||||
|
||||
def __init__(self,
|
||||
runner,
|
||||
finished_requests_ids: Optional[List[str]] = None):
|
||||
@@ -492,11 +520,30 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
|
||||
create on-device tensors.
|
||||
"""
|
||||
# Combine and flatten intermediate data.
|
||||
input_tokens = [
|
||||
flatten_2d_lists(inter_data.input_tokens)
|
||||
for inter_data in self.inter_data_list
|
||||
]
|
||||
if not input_tokens:
|
||||
input_tokens = list[int]()
|
||||
inputs_embeds_list = list[torch.Tensor]()
|
||||
token_types = list[int]()
|
||||
for inter_data in self.inter_data_list:
|
||||
for cur_input_tokens in inter_data.input_tokens:
|
||||
input_tokens.extend(cur_input_tokens)
|
||||
for cur_token_types in inter_data.token_types:
|
||||
token_types.extend(cur_token_types)
|
||||
if inter_data.inputs_embeds is not None:
|
||||
inputs_embeds_list.append(
|
||||
inter_data.inputs_embeds.to(
|
||||
dtype=self.runner.model_config.dtype,
|
||||
device=self.runner.device))
|
||||
|
||||
inputs_embeds: Optional[torch.Tensor]
|
||||
if len(inputs_embeds_list) == 0:
|
||||
inputs_embeds = None
|
||||
else:
|
||||
inputs_embeds = torch.cat(inputs_embeds_list, dim=0).to(
|
||||
dtype=self.runner.model_config.dtype,
|
||||
device=self.runner.device)
|
||||
assert len(inputs_embeds) == len(input_tokens)
|
||||
|
||||
if not input_tokens and inputs_embeds is None:
|
||||
# This may happen when all prefill requests hit
|
||||
# prefix caching and there is no decode request.
|
||||
return self.model_input_cls()
|
||||
@@ -548,10 +595,6 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
|
||||
else:
|
||||
graph_pad_size = -1
|
||||
|
||||
#print(f"before tensor input_tokens: {input_tokens}")
|
||||
#print(f"before tensor input_positions: {input_positions}")
|
||||
#print(f"before list seq_lens: {seq_lens}")
|
||||
input_tokens = flatten_2d_lists(input_tokens)
|
||||
if input_positions:
|
||||
input_positions = flatten_2d_lists(input_positions)
|
||||
if graph_pad_size != -1 and not is_prompt:
|
||||
@@ -563,6 +606,10 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
|
||||
input_tokens_tensor = torch.tensor(input_tokens,
|
||||
dtype=torch.long,
|
||||
device=self.runner.device)
|
||||
token_types_tensor = torch.tensor(token_types,
|
||||
dtype=torch.long,
|
||||
device=self.runner.device) \
|
||||
if token_types else None
|
||||
if mrope_input_positions is not None:
|
||||
input_positions_tensor = torch.tensor(mrope_input_positions,
|
||||
dtype=torch.long,
|
||||
@@ -613,6 +660,8 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
|
||||
|
||||
return self.model_input_cls(
|
||||
input_tokens=input_tokens_tensor,
|
||||
inputs_embeds=inputs_embeds,
|
||||
token_types=token_types_tensor,
|
||||
input_positions=input_positions_tensor,
|
||||
attn_metadata=attn_metadata,
|
||||
seq_lens=seq_lens,
|
||||
@@ -645,13 +694,23 @@ class ModelInputForNPUBuilder(ModelRunnerInputBuilderBase[ModelInputForNPU]):
|
||||
context_len = seq_data.get_num_computed_tokens()
|
||||
|
||||
# Compute tokens.
|
||||
tokens = seq_data.get_token_ids()[context_len:seq_len]
|
||||
# Fixme: this is for the version compatibility, remove this once vllm v0.8.5 does not be supported.
|
||||
if not hasattr(seq_data,
|
||||
"prompt_embeds") or seq_data.prompt_embeds is None:
|
||||
tokens = seq_data.get_token_ids()[context_len:seq_len]
|
||||
prompt_embeds = None
|
||||
else:
|
||||
tokens = [0] * (seq_len - context_len)
|
||||
prompt_embeds = seq_data.get_token_embeddings(
|
||||
)[context_len:seq_len]
|
||||
|
||||
token_types = seq_group_metadata.token_type_ids
|
||||
|
||||
inter_data.seq_lens[seq_idx] = seq_len
|
||||
inter_data.orig_seq_lens[seq_idx] = seq_len
|
||||
inter_data.context_lens[seq_idx] = context_len
|
||||
inter_data.input_tokens[seq_idx].extend(tokens)
|
||||
inter_data.inputs_embeds = prompt_embeds
|
||||
inter_data.input_positions[seq_idx].extend(range(context_len, seq_len))
|
||||
inter_data.token_types[seq_idx].extend(
|
||||
token_types if token_types else [])
|
||||
@@ -1379,6 +1438,7 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
|
||||
model_kwargs["attn_metadata"] = model_input.attn_metadata
|
||||
hidden_or_intermediate_states = model_executable(
|
||||
input_ids=model_input.input_tokens,
|
||||
inputs_embeds=model_input.inputs_embeds,
|
||||
positions=model_input.input_positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
|
||||
@@ -1422,34 +1482,61 @@ class NPUModelRunner(NPUModelRunnerBase[ModelInputForNPUWithSamplingMetadata]):
|
||||
hidden_or_intermediate_states,
|
||||
)
|
||||
|
||||
if self.is_driver_worker:
|
||||
if model_input.async_callback is not None:
|
||||
model_input.async_callback()
|
||||
|
||||
# Sample the next token.
|
||||
assert isinstance(self.sampler, Sampler)
|
||||
orig_include_gpu_probs = self.sampler.include_gpu_probs_tensor
|
||||
if model_input.inputs_embeds is not None:
|
||||
self.sampler.include_gpu_probs_tensor = True
|
||||
|
||||
output: SamplerOutput = self.sampler(
|
||||
logits=logits,
|
||||
sampling_metadata=model_input.sampling_metadata,
|
||||
)
|
||||
if (self.observability_config is not None
|
||||
and self.observability_config.collect_model_forward_time
|
||||
and output is not None):
|
||||
model_forward_end.synchronize()
|
||||
model_forward_time = model_forward_start.elapsed_time(
|
||||
model_forward_end)
|
||||
orig_model_forward_time = 0.0
|
||||
if intermediate_tensors is not None:
|
||||
orig_model_forward_time = intermediate_tensors.tensors.get(
|
||||
"model_forward_time", torch.tensor(0.0)).item()
|
||||
# If there are multiple workers, we are still tracking the
|
||||
# latency from the start time of the driver worker to the end
|
||||
# time of the driver worker. The model forward time will then
|
||||
# end up covering the communication time as well.
|
||||
output.model_forward_time = (orig_model_forward_time +
|
||||
model_forward_time)
|
||||
|
||||
if model_input.inputs_embeds is not None:
|
||||
if self.is_driver_worker:
|
||||
sampled = broadcast_tensor_dict(
|
||||
{"token_ids": output.sampled_token_ids})
|
||||
else:
|
||||
sampled = broadcast_tensor_dict()
|
||||
if sampled["token_ids"] is not None:
|
||||
sampled_token_embeds = self.model.get_input_embeddings(
|
||||
sampled["token_ids"].squeeze(1))
|
||||
if self.is_driver_worker:
|
||||
self.sampler.include_gpu_probs_tensor = \
|
||||
orig_include_gpu_probs
|
||||
|
||||
output.sampled_token_embeds = sampled_token_embeds
|
||||
|
||||
for token_embed, sequence_group_output in zip(
|
||||
output.sampled_token_embeds, output.outputs):
|
||||
assert len(sequence_group_output.samples) == 1
|
||||
sequence_group_output.samples[
|
||||
0].output_embed = token_embed
|
||||
|
||||
if not self.is_driver_worker:
|
||||
return []
|
||||
|
||||
if model_input.async_callback is not None:
|
||||
model_input.async_callback()
|
||||
|
||||
# Sample the next token.
|
||||
output = self.sampler(
|
||||
logits=logits,
|
||||
sampling_metadata=model_input.sampling_metadata,
|
||||
)
|
||||
if (self.observability_config is not None
|
||||
and self.observability_config.collect_model_forward_time
|
||||
and output is not None):
|
||||
model_forward_end.synchronize()
|
||||
model_forward_time = model_forward_start.elapsed_time(
|
||||
model_forward_end)
|
||||
orig_model_forward_time = 0.0
|
||||
if intermediate_tensors is not None:
|
||||
orig_model_forward_time = intermediate_tensors.tensors.get(
|
||||
"model_forward_time", torch.tensor(0.0)).item()
|
||||
# If there are multiple workers, we are still tracking the latency
|
||||
# from the start time of the driver worker to the end time of the
|
||||
# driver worker. The model forward time will then end up covering
|
||||
# the communication time as well.
|
||||
output.model_forward_time = (orig_model_forward_time +
|
||||
model_forward_time)
|
||||
|
||||
if self.return_hidden_states:
|
||||
# we only need to pass hidden states of most recent token
|
||||
assert model_input.sampling_metadata is not None
|
||||
|
||||
Reference in New Issue
Block a user