[ModelRunner] Support embedding inputs (#916)
### What this PR does / why we need it?
- Adds support for passing prompt_embeds to LLM.generate as
```bash
llm.generate({"prompt_embeds": input_embeds}, sampling_params)
```
or
```bash
llm.generate(
[{"prompt_embeds": input_embeds} for input_embeds in inputs_embeds], sampling_params
)
```
- Add `prompt_embeds` to examples
### How was this patch tested?
CI passed with new added/existing test.
and I have test with the example script in this pr, and the output seems
looks good:
```bash
[Single Inference Output]
------------------------------
The capital of France is Paris. Paris is the largest city in France and is
------------------------------
Adding requests: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3966.87it/s]
Processed prompts: 100%|█████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.99it/s, est. speed input: 177.08 toks/s, output: 63.91 toks/s]
[Batch Inference Outputs]
------------------------------
Q1: Please tell me about the capital of France.
A1: The capital of France is Paris. It is located in the northern part of the
Q2: When is the day longest during the year?
A2: The day is longest during the year at the summer solstice. This typically occurs
Q3: Where is bigger, the moon or the sun?
A3: The sun is significantly bigger than the moon.
The sun has a diameter of
------------------------------
```
---------
Signed-off-by: wangli <wangli858794774@gmail.com>
This commit is contained in:
83
examples/prompt_embedding_inference.py
Normal file
83
examples/prompt_embedding_inference.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import torch
|
||||
from transformers import (AutoModelForCausalLM, AutoTokenizer,
|
||||
PreTrainedTokenizer)
|
||||
from vllm import LLM
|
||||
|
||||
|
||||
def init_tokenizer_and_llm(model_name: str):
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
transformers_model = AutoModelForCausalLM.from_pretrained(model_name)
|
||||
embedding_layer = transformers_model.get_input_embeddings()
|
||||
llm = LLM(model=model_name, enable_prompt_embeds=True)
|
||||
return tokenizer, embedding_layer, llm
|
||||
|
||||
|
||||
def get_prompt_embeds(chat: list[dict[str,
|
||||
str]], tokenizer: PreTrainedTokenizer,
|
||||
embedding_layer: torch.nn.Module):
|
||||
token_ids = tokenizer.apply_chat_template(chat,
|
||||
add_generation_prompt=True,
|
||||
return_tensors='pt')
|
||||
prompt_embeds = embedding_layer(token_ids).squeeze(0)
|
||||
return prompt_embeds
|
||||
|
||||
|
||||
def single_prompt_inference(llm: LLM, tokenizer: PreTrainedTokenizer,
|
||||
embedding_layer: torch.nn.Module):
|
||||
chat = [{
|
||||
"role": "user",
|
||||
"content": "Please tell me about the capital of France."
|
||||
}]
|
||||
prompt_embeds = get_prompt_embeds(chat, tokenizer, embedding_layer)
|
||||
|
||||
outputs = llm.generate({
|
||||
"prompt_embeds": prompt_embeds,
|
||||
})
|
||||
|
||||
print("\n[Single Inference Output]")
|
||||
print("-" * 30)
|
||||
for o in outputs:
|
||||
print(o.outputs[0].text)
|
||||
print("-" * 30)
|
||||
|
||||
|
||||
def batch_prompt_inference(llm: LLM, tokenizer: PreTrainedTokenizer,
|
||||
embedding_layer: torch.nn.Module):
|
||||
chats = [[{
|
||||
"role": "user",
|
||||
"content": "Please tell me about the capital of France."
|
||||
}],
|
||||
[{
|
||||
"role": "user",
|
||||
"content": "When is the day longest during the year?"
|
||||
}],
|
||||
[{
|
||||
"role": "user",
|
||||
"content": "Where is bigger, the moon or the sun?"
|
||||
}]]
|
||||
|
||||
prompt_embeds_list = [
|
||||
get_prompt_embeds(chat, tokenizer, embedding_layer) for chat in chats
|
||||
]
|
||||
|
||||
outputs = llm.generate([{
|
||||
"prompt_embeds": embeds
|
||||
} for embeds in prompt_embeds_list])
|
||||
|
||||
print("\n[Batch Inference Outputs]")
|
||||
print("-" * 30)
|
||||
for i, o in enumerate(outputs):
|
||||
print(f"Q{i+1}: {chats[i][0]['content']}")
|
||||
print(f"A{i+1}: {o.outputs[0].text}\n")
|
||||
print("-" * 30)
|
||||
|
||||
|
||||
def main():
|
||||
model_name = "meta-llama/Llama-3.2-1B-Instruct"
|
||||
tokenizer, embedding_layer, llm = init_tokenizer_and_llm(model_name)
|
||||
single_prompt_inference(llm, tokenizer, embedding_layer)
|
||||
batch_prompt_inference(llm, tokenizer, embedding_layer)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user