[feature] Prompt Embeddings Support for v1 Engine (#3026)

### What this PR does / why we need it?
This PR is based on
[19746](https://github.com/vllm-project/vllm/issues/19746) and adds
Prompt Embeddings support for the v1 engine on NPU.

### Does this PR introduce _any_ user-facing change?

Yes. When the engine is created with `enable_prompt_embeds=True`, requests can now pass precomputed `prompt_embeds` tensors to `generate()` in place of token ids.

### How was this patch tested?

```bash
python examples/prompt_embed_inference.py
```
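
For context, the example exercises the new input path roughly as follows (a minimal sketch of the API added here; the model name is taken from the e2e tests below, while the full script added under `examples/` uses the gated `meta-llama/Llama-3.2-1B-Instruct`):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from vllm import LLM

model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # same model the e2e tests use

# Build prompt embeddings offline with Hugging Face Transformers.
tokenizer = AutoTokenizer.from_pretrained(model_name)
hf_model = AutoModelForCausalLM.from_pretrained(model_name)
token_ids = tokenizer("Hello, world!", return_tensors="pt").input_ids
with torch.no_grad():
    prompt_embeds = hf_model.get_input_embeddings()(token_ids).squeeze(0)

# Pass the embeddings to vLLM in place of token ids.
llm = LLM(model=model_name, enable_prompt_embeds=True)
outputs = llm.generate({"prompt_embeds": prompt_embeds})
print(outputs[0].outputs[0].text)
```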


- vLLM version: v0.11.0
- vLLM main:
https://github.com/vllm-project/vllm/commit/releases/v0.11.1

---------

Signed-off-by: jesse <szxfml@gmail.com>
Author: Song Zhixin
Date: 2025-10-30 17:15:57 +08:00
Committed by: GitHub
Parent: f6149f3894
Commit: 216fc0e8e4
5 changed files with 447 additions and 17 deletions

CI workflow (singlecard e2e test list):

@@ -88,6 +88,7 @@ jobs:
           # We found that if running aclgraph tests in batch, it will cause AclmdlRICaptureBegin error. So we run
           # the test separately.
+          pytest -sv tests/e2e/singlecard/test_completion_with_prompt_embeds.py
           pytest -sv tests/e2e/singlecard/test_aclgraph.py
           pytest -sv tests/e2e/singlecard/test_ascend_scheduler.py
           pytest -sv tests/e2e/singlecard/test_bge_model.py

examples/prompt_embed_inference.py (new file):

@@ -0,0 +1,97 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrates how to generate prompt embeddings using
Hugging Face Transformers and use them as input to vLLM
for both single and batch inference.

Model: meta-llama/Llama-3.2-1B-Instruct

Note: This model is gated on Hugging Face Hub.
You must request access to use it:
https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct

Requirements:
- vLLM
- transformers

Run:
    python examples/prompt_embed_inference.py
"""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer

from vllm import LLM


def init_tokenizer_and_llm(model_name: str):
    llm = LLM(model=model_name, enable_prompt_embeds=True)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    transformers_model = AutoModelForCausalLM.from_pretrained(model_name)
    embedding_layer = transformers_model.get_input_embeddings()
    return tokenizer, embedding_layer, llm


def get_prompt_embeds(
    chat: list[dict[str, str]],
    tokenizer: PreTrainedTokenizer,
    embedding_layer: torch.nn.Module,
):
    token_ids = tokenizer.apply_chat_template(
        chat, add_generation_prompt=True, return_tensors="pt"
    )
    prompt_embeds = embedding_layer(token_ids).squeeze(0)
    return prompt_embeds


def single_prompt_inference(
    llm: LLM, tokenizer: PreTrainedTokenizer, embedding_layer: torch.nn.Module
):
    chat = [{"role": "user", "content": "Please tell me about the capital of France."}]
    prompt_embeds = get_prompt_embeds(chat, tokenizer, embedding_layer)

    outputs = llm.generate(
        {
            "prompt_embeds": prompt_embeds,
        }
    )

    print("\n[Single Inference Output]")
    print("-" * 30)
    for o in outputs:
        print(o.outputs[0].text)
    print("-" * 30)


def batch_prompt_inference(
    llm: LLM, tokenizer: PreTrainedTokenizer, embedding_layer: torch.nn.Module
):
    chats = [
        [{"role": "user", "content": "Please tell me about the capital of France."}],
        [{"role": "user", "content": "When is the day longest during the year?"}],
        [{"role": "user", "content": "Which is bigger, the moon or the sun?"}],
    ]

    prompt_embeds_list = [
        get_prompt_embeds(chat, tokenizer, embedding_layer) for chat in chats
    ]

    outputs = llm.generate([{"prompt_embeds": embeds} for embeds in prompt_embeds_list])

    print("\n[Batch Inference Outputs]")
    print("-" * 30)
    for i, o in enumerate(outputs):
        print(f"Q{i + 1}: {chats[i][0]['content']}")
        print(f"A{i + 1}: {o.outputs[0].text}\n")
    print("-" * 30)


def main():
    model_name = "meta-llama/Llama-3.2-1B-Instruct"
    tokenizer, embedding_layer, llm = init_tokenizer_and_llm(model_name)
    single_prompt_inference(llm, tokenizer, embedding_layer)
    batch_prompt_inference(llm, tokenizer, embedding_layer)


if __name__ == "__main__":
    main()

tests/e2e/singlecard/test_completion_with_prompt_embeds.py (new file):

@@ -0,0 +1,197 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import os

import pytest
from transformers import AutoModelForCausalLM, AutoTokenizer

from tests.e2e.conftest import VllmRunner

os.environ["VLLM_USE_MODELSCOPE"] = "True"
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"]


def get_prompt_embeds(chat, tokenizer, embedding_layer):
    """Convert chat messages to prompt embeddings."""
    token_ids = tokenizer.apply_chat_template(chat,
                                              add_generation_prompt=True,
                                              return_tensors='pt')
    prompt_embeds = embedding_layer(token_ids).squeeze(0)
    return prompt_embeds


@pytest.mark.parametrize("model_name", MODELS)
def test_single_prompt_embeds_inference(model_name):
    """Test single prompt inference with prompt embeddings."""
    # Prepare prompt embeddings
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    transformers_model = AutoModelForCausalLM.from_pretrained(model_name)
    embedding_layer = transformers_model.get_input_embeddings()
    chat = [{
        "role": "user",
        "content": "Please tell me about the capital of France."
    }]
    prompt_embeds = get_prompt_embeds(chat, tokenizer, embedding_layer)

    # Run inference with prompt embeddings
    with VllmRunner(
            model_name,
            enable_prompt_embeds=True,
            enforce_eager=True,
    ) as vllm_runner:
        outputs = vllm_runner.model.generate({
            "prompt_embeds": prompt_embeds,
        })

        # Verify output
        assert len(outputs) == 1
        assert len(outputs[0].outputs) > 0
        assert len(outputs[0].outputs[0].text) > 0
        print(f"\n[Single Inference Output]: {outputs[0].outputs[0].text}")


@pytest.mark.parametrize("model_name", MODELS)
def test_batch_prompt_embeds_inference(model_name):
    """Test batch prompt inference with prompt embeddings."""
    # Prepare prompt embeddings
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    transformers_model = AutoModelForCausalLM.from_pretrained(model_name)
    embedding_layer = transformers_model.get_input_embeddings()
    chats = [[{
        "role": "user",
        "content": "Please tell me about the capital of France."
    }],
             [{
                 "role": "user",
                 "content": "When is the day longest during the year?"
             }],
             [{
                 "role": "user",
                 "content": "Which is bigger, the moon or the sun?"
             }]]
    prompt_embeds_list = [
        get_prompt_embeds(chat, tokenizer, embedding_layer) for chat in chats
    ]

    # Run batch inference with prompt embeddings
    with VllmRunner(
            model_name,
            enable_prompt_embeds=True,
            enforce_eager=True,
    ) as vllm_runner:
        outputs = vllm_runner.model.generate([{
            "prompt_embeds": embeds
        } for embeds in prompt_embeds_list])

        # Verify outputs
        assert len(outputs) == len(chats)
        for i, output in enumerate(outputs):
            assert len(output.outputs) > 0
            assert len(output.outputs[0].text) > 0
            print(f"\nQ{i+1}: {chats[i][0]['content']}")
            print(f"A{i+1}: {output.outputs[0].text}")


@pytest.mark.parametrize("model_name", MODELS)
def test_prompt_embeds_with_aclgraph(model_name):
    """Test prompt embeddings with ACL graph enabled vs disabled."""
    # Prepare prompt embeddings
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    transformers_model = AutoModelForCausalLM.from_pretrained(model_name)
    embedding_layer = transformers_model.get_input_embeddings()
    chat = [{"role": "user", "content": "What is the capital of China?"}]
    prompt_embeds = get_prompt_embeds(chat, tokenizer, embedding_layer)

    # Run with ACL graph enabled (enforce_eager=False)
    with VllmRunner(
            model_name,
            enable_prompt_embeds=True,
            enforce_eager=False,
    ) as vllm_aclgraph_runner:
        aclgraph_outputs = vllm_aclgraph_runner.model.generate({
            "prompt_embeds":
            prompt_embeds,
        })

    # Run with ACL graph disabled (enforce_eager=True)
    with VllmRunner(
            model_name,
            enable_prompt_embeds=True,
            enforce_eager=True,
    ) as vllm_eager_runner:
        eager_outputs = vllm_eager_runner.model.generate({
            "prompt_embeds":
            prompt_embeds,
        })

    # Verify both produce valid outputs
    assert len(aclgraph_outputs) == 1
    assert len(eager_outputs) == 1
    assert len(aclgraph_outputs[0].outputs[0].text) > 0
    assert len(eager_outputs[0].outputs[0].text) > 0
    print("\n[ACL Graph Output]:", aclgraph_outputs[0].outputs[0].text)
    print("[Eager Output]:", eager_outputs[0].outputs[0].text)
    # Note: Outputs may differ slightly due to different execution paths,
    # but both should be valid responses


@pytest.mark.parametrize("model_name", MODELS)
def test_mixed_prompt_embeds_and_text(model_name):
    """Test mixed inputs with both prompt embeddings and text prompts."""
    # Prepare prompt embeddings for first request
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    transformers_model = AutoModelForCausalLM.from_pretrained(model_name)
    embedding_layer = transformers_model.get_input_embeddings()
    chat = [{"role": "user", "content": "What is AI?"}]
    prompt_embeds = get_prompt_embeds(chat, tokenizer, embedding_layer)

    # Prepare text prompt for second request
    text_prompt = "What is machine learning?"

    # Run inference with mixed inputs
    with VllmRunner(
            model_name,
            enable_prompt_embeds=True,
            enforce_eager=True,
    ) as vllm_runner:
        # Test prompt embeddings
        embeds_output = vllm_runner.model.generate({
            "prompt_embeds":
            prompt_embeds,
        })

        # Test text prompt
        text_output = vllm_runner.model.generate(text_prompt)

        # Verify both types of inputs work
        assert len(embeds_output) == 1
        assert len(text_output) == 1
        assert len(embeds_output[0].outputs[0].text) > 0
        assert len(text_output[0].outputs[0].text) > 0
        print("\n[Prompt Embeds Output]:", embeds_output[0].outputs[0].text)
        print("[Text Prompt Output]:", text_output[0].outputs[0].text)

NPU model runner (class NPUModelRunner):

@@ -72,7 +72,7 @@ from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
-from vllm.utils import cdiv
+from vllm.utils import cdiv, length_from_prompt_token_ids_or_embeds
 from vllm.utils.jsontree import json_map_leaves
 from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
 from vllm.v1.attention.backends.utils import (
@@ -346,11 +346,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         self.is_multimodal_model = self.model_config.is_multimodal_model
         self.is_pooling_model = self.model_config.pooler_config is not None
-        if self.is_multimodal_model:
-            self.inputs_embeds = torch.zeros(
-                (self.max_num_tokens, self.model_config.get_hidden_size()),
+        self.enable_prompt_embeds = self.model_config.enable_prompt_embeds
+        if self.is_multimodal_model or self.enable_prompt_embeds:
+            self.inputs_embeds = self._make_buffer(
+                self.max_num_tokens,
+                self.model_config.get_hidden_size(),
                 dtype=self.dtype,
-                device=self.device)
+                numpy=False)
+            self.is_token_ids = self._make_buffer(self.max_num_tokens,
+                                                  dtype=torch.bool)
         # Set up Attention
         self.use_sparse = hasattr(self.vllm_config.model_config.hf_config,
                                   "index_topk")
@@ -721,6 +726,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             self.requests[req_id] = CachedRequestState(
                 req_id=req_id,
                 prompt_token_ids=new_req_data.prompt_token_ids,
+                prompt_embeds=new_req_data.prompt_embeds,
                 sampling_params=sampling_params,
                 pooling_params=pooling_params,
                 generator=generator,
@@ -999,7 +1005,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 self.input_batch.num_computed_tokens_cpu[index]
             num_scheduled_tokens = \
                 scheduler_output.num_scheduled_tokens[req_id]
-            num_prompt_tokens = len(req.prompt_token_ids)
+            num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
+                req.prompt_token_ids, req.prompt_embeds)
             if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
                 prompt_part_len = max(0,
@@ -1274,6 +1281,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             self.input_ids[:total_num_scheduled_tokens].copy_(
                 self.input_ids_cpu[:total_num_scheduled_tokens],
                 non_blocking=True)
+            if self.is_multimodal_model or self.enable_prompt_embeds:
+                self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
+                self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
             return

         # Async scheduling case, where some decode requests from the previous
@@ -1301,6 +1311,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         self.input_ids[:total_num_scheduled_tokens].copy_(
             self.input_ids_cpu[:total_num_scheduled_tokens],
             non_blocking=True)
+        if self.is_multimodal_model or self.enable_prompt_embeds:
+            self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
+            self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
         if num_commmon_tokens == 0:
             # No requests in common with the previous iteration
             # So input_ids_cpu will have all the input ids.
@@ -1314,6 +1327,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 self.input_batch.prev_sampled_token_ids[:num_commmon_tokens,
                                                         0],
                 non_blocking=True)
+            self.is_token_ids.gpu[:num_commmon_tokens] = True
             return
         # Upload the index tensors asynchronously
         # so the scatter can be non-blocking.
@@ -1481,15 +1495,61 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         # where M is the max_model_len.
         token_indices = (positions_np +
                          req_indices * self.input_batch.token_ids_cpu.shape[1])
+        token_indices_tensor = torch.from_numpy(token_indices)

         # Prepare input_ids.
         # NOTE(woosuk): We use torch.index_select instead of np.take here
         # because torch.index_select is much faster than np.take for large
         # tensors.
         torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
                            0,
-                           torch.from_numpy(token_indices),
+                           token_indices_tensor,
                            out=self.input_ids_cpu[:total_num_scheduled_tokens])
+        is_token_ids = self.input_batch.is_token_ids.flatten()
+        torch.index_select(
+            is_token_ids,
+            0,
+            token_indices_tensor,
+            out=self.is_token_ids.cpu[:total_num_scheduled_tokens])
+
+        # Because we did not pre-allocate a massive prompt_embeds CPU tensor on
+        # the InputBatch, we need to fill in the prompt embeds into the
+        # expected spots in the model runner's pre-allocated inputs_embeds
+        # tensor.
+        if self.input_batch.req_prompt_embeds and (self.is_multimodal_model or
+                                                   self.enable_prompt_embeds):
+            output_idx = 0
+            for req_idx in range(num_reqs):
+                num_sched = num_scheduled_tokens[req_idx]
+
+                # Skip if this request doesn't have embeddings
+                if req_idx not in self.input_batch.req_prompt_embeds:
+                    output_idx += num_sched
+                    continue
+
+                # Skip if no tokens scheduled
+                if num_sched <= 0:
+                    output_idx += num_sched
+                    continue
+
+                req_embeds = self.input_batch.req_prompt_embeds[req_idx]
+                start_pos = self.input_batch.num_computed_tokens_cpu[req_idx]
+
+                # Skip if trying to read beyond available embeddings
+                if start_pos >= req_embeds.shape[0]:
+                    output_idx += num_sched
+                    continue
+
+                # Copy available embeddings
+                end_pos = start_pos + num_sched
+                actual_end = min(end_pos, req_embeds.shape[0])
+                actual_num_sched = actual_end - start_pos
+                if actual_num_sched > 0:
+                    self.inputs_embeds.cpu[output_idx:output_idx +
+                                           actual_num_sched].copy_(
+                                               req_embeds[start_pos:actual_end])
+
+                output_idx += num_sched

         self.query_start_loc_np[0] = 0
         self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
@@ -1573,9 +1633,34 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             )
             # TODO(woosuk): Avoid the copy. Optimize.
-            self.inputs_embeds[:total_num_scheduled_tokens].copy_(
+            self.inputs_embeds.gpu[:total_num_scheduled_tokens].copy_(
                 inputs_embeds)
-            inputs_embeds = self.inputs_embeds[:num_input_tokens]
+            inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens]
             input_ids = None
+        elif self.enable_prompt_embeds and get_pp_group().is_first_rank:
+            # Get the input embeddings for the tokens that are not input embeds,
+            # then put them into the appropriate positions.
+            # TODO(qthequartermasterman): Since even when prompt embeds are
+            # enabled, (a) not all requests will use prompt embeds, and (b)
+            # after the initial prompt is processed, the rest of the generated
+            # tokens will be token ids, it is not desirable to have the
+            # embedding layer outside of the acl graph all the time. The v0
+            # engine avoids this by "double compiling" the acl graph, once
+            # with input_ids and again with inputs_embeds, for all num_tokens.
+            # If a batch only has token ids, then including the embedding layer
+            # in the acl graph will be more performant (like in the else case
+            # below).
+            token_ids_idx = self.is_token_ids.gpu[:total_num_scheduled_tokens] \
+                .nonzero(as_tuple=False) \
+                .squeeze(1)
+            # Some token ids may need to become embeds
+            if token_ids_idx.numel() > 0:
+                token_ids = self.input_ids[token_ids_idx]
+                tokens_to_embeds = self.model.get_input_embeddings(
+                    input_ids=token_ids)
+                self.inputs_embeds.gpu[token_ids_idx] = tokens_to_embeds
+            inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens]
+            input_ids = None
         else:
             # For text-only models, we use token ids as input.
@@ -2404,6 +2489,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 self.input_batch.token_ids_cpu[req_idx,
                                                start_idx:end_idx] = sampled_ids
+                self.input_batch.is_token_ids[req_idx,
+                                              start_idx:end_idx] = True
                 self.input_batch.num_tokens_no_spec[req_idx] = end_idx
                 self.input_batch.num_tokens[req_idx] = end_idx
                 req_id = self.input_batch.req_ids[req_idx]
@@ -2729,7 +2816,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 num_scheduled_tokens):
             if self.is_multimodal_model:
                 input_ids = None
-                inputs_embeds = self.inputs_embeds[:num_tokens]
+                inputs_embeds = self.inputs_embeds.gpu[:num_tokens]
+            elif self.enable_prompt_embeds:
+                input_ids = None
+                inputs_embeds = self.inputs_embeds.gpu[:num_tokens]
             else:
                 input_ids = self.input_ids[:num_tokens]
                 inputs_embeds = None
@@ -3996,6 +4086,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             # Get metadata for this request.
             request = self.requests[req_id]
+            if request.prompt_token_ids is None:
+                # Prompt logprobs are incompatible with prompt embeddings
+                continue
             num_prompt_tokens = len(request.prompt_token_ids)
             prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
                 self.device, non_blocking=True)
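
Both runner hunks above lean on `length_from_prompt_token_ids_or_embeds` from `vllm.utils`. A minimal sketch of the semantics implied by its call sites here (not the vLLM source itself):

```python
from typing import Optional

import torch


def length_from_prompt_token_ids_or_embeds(
        prompt_token_ids: Optional[list[int]],
        prompt_embeds: Optional[torch.Tensor]) -> int:
    """Prompt length comes from token ids when present, otherwise from
    the leading (num_prompt_tokens) dimension of the embeddings."""
    if prompt_token_ids is not None:
        return len(prompt_token_ids)
    assert prompt_embeds is not None, "need token ids or embeds"
    return prompt_embeds.shape[0]
```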

Input batch and cached request state (CachedRequestState, InputBatch):

@@ -29,6 +29,7 @@ from vllm.multimodal.inputs import (MultiModalFeatureSpec,
                                     MultiModalKwargsItems, PlaceholderRange)
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams, SamplingType
+from vllm.utils import length_from_prompt_token_ids_or_embeds
 from vllm.v1.outputs import LogprobsTensors
 from vllm.v1.pool.metadata import PoolingMetadata
 from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
@@ -51,7 +52,7 @@ else:
 class CachedRequestState:

     req_id: str
-    prompt_token_ids: list[int]
+    prompt_token_ids: Optional[list[int]]
     sampling_params: Optional[SamplingParams]
     pooling_params: Optional[PoolingParams]
     generator: Optional[torch.Generator]
@@ -70,9 +71,11 @@ class CachedRequestState:
     mm_hashes: Optional[list[PlaceholderRange]] = None
     lora_request: Optional[LoRARequest] = None
+    prompt_embeds: Optional[torch.Tensor] = None

     def __post_init__(self):
-        self.num_prompt_tokens = len(self.prompt_token_ids)
+        self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
+            self.prompt_token_ids, self.prompt_embeds)

     @property
     def num_tokens(self) -> int:
@@ -91,6 +94,10 @@ class CachedRequestState:
     def get_token_id(self, idx: int) -> int:
         if idx < self.num_prompt_tokens:
+            if self.prompt_token_ids is None:
+                raise ValueError(
+                    f"Tried to access token index {idx}, but that token was "
+                    "provided via prompt_embeds, and its ID is unknown.")
             return self.prompt_token_ids[idx]
         elif idx - self.num_prompt_tokens < len(self.output_token_ids):
             return self.output_token_ids[idx - self.num_prompt_tokens]
@@ -139,6 +146,14 @@ class InputBatch:
             pin_memory=False,
         )
         self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
+        self.is_token_ids = torch.zeros((max_num_reqs, max_model_len),
+                                        device="cpu",
+                                        dtype=bool,
+                                        pin_memory=False)
+        # Store prompt embeddings per request to avoid OOM from large upfront
+        # allocation if max_model_len is big.
+        # Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
+        self.req_prompt_embeds: dict[int, torch.Tensor] = {}
         self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
         self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
         self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
@@ -345,15 +360,23 @@ class InputBatch:
         self.req_id_to_index[req_id] = req_index

         # Copy the prompt token ids and output token ids.
-        num_prompt_tokens = len(request.prompt_token_ids)
+        num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
+            request.prompt_token_ids, request.prompt_embeds)
         self.num_prompt_tokens[req_index] = num_prompt_tokens
-        self.token_ids_cpu[
-            req_index, :num_prompt_tokens] = request.prompt_token_ids
         start_idx = num_prompt_tokens
         end_idx = start_idx + len(request.output_token_ids)
+        if request.prompt_token_ids is not None:
+            self.token_ids_cpu[
+                req_index, :num_prompt_tokens] = request.prompt_token_ids
+            self.is_token_ids[req_index, :num_prompt_tokens] = True
+        else:
+            self.is_token_ids[req_index, :num_prompt_tokens] = False
+        if request.prompt_embeds is not None:
+            self.req_prompt_embeds[req_index] = request.prompt_embeds
         self.token_ids_cpu[req_index,
                            start_idx:end_idx] = request.output_token_ids
-        # Number of token ids in token_ids_cpu.
+        self.is_token_ids[req_index, start_idx:end_idx] = True
+        # Number of token ids in prompt (token_ids_cpu or prompt_embeds).
         # NOTE(woosuk): This may include spec decode tokens.
         self.num_tokens[req_index] = request.num_tokens
         # Number of tokens without spec decode tokens.
@@ -553,6 +576,20 @@ class InputBatch:
             self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
             self.token_ids_cpu[i2, ...] = tmp
+        self.is_token_ids[[i1, i2], ...] = self.is_token_ids[[i2, i1], ...]
+
+        # Swap prompt embeddings if they exist
+        embeds_i1 = self.req_prompt_embeds.get(i1)
+        embeds_i2 = self.req_prompt_embeds.get(i2)
+        if embeds_i1 is not None:
+            self.req_prompt_embeds[i2] = embeds_i1
+        else:
+            self.req_prompt_embeds.pop(i2, None)
+        if embeds_i2 is not None:
+            self.req_prompt_embeds[i1] = embeds_i2
+        else:
+            self.req_prompt_embeds.pop(i1, None)
+
         swap_dict_values(self.generators, i1, i2)
         swap_dict_values(self.bad_words_token_ids, i1, i2)
@@ -631,6 +668,11 @@ class InputBatch:
             num_tokens = self.num_tokens[last_req_index]
             self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
                 last_req_index, :num_tokens]
+            self.is_token_ids[empty_index, :num_tokens] = self.is_token_ids[
+                last_req_index, :num_tokens]
+            if last_req_index in self.req_prompt_embeds:
+                self.req_prompt_embeds[
+                    empty_index] = self.req_prompt_embeds.pop(last_req_index)
             self.num_tokens[empty_index] = num_tokens
             self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
                 last_req_index]
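
Taken together, the `is_token_ids` mask is what lets one batch mix embedding-backed prompt positions with ordinary token ids (every generated token, for instance). A condensed, self-contained sketch of the gather-and-embed step from the runner hunk above, with hypothetical toy tensors standing in for the runner's buffers:

```python
import torch

# Toy stand-ins for the runner's buffers (hypothetical sizes).
hidden_size, vocab_size, num_tokens = 8, 100, 6
embedding = torch.nn.Embedding(vocab_size, hidden_size)  # plays the role of model.get_input_embeddings()

# Positions 0-2 were already filled from prompt_embeds; 3-5 carry real token ids.
inputs_embeds = torch.randn(num_tokens, hidden_size)
input_ids = torch.tensor([0, 0, 0, 41, 42, 43])
is_token_ids = torch.tensor([False, False, False, True, True, True])

# Embed only the token-id positions and scatter them into the buffer,
# mirroring the elif branch added in the model runner.
idx = is_token_ids.nonzero(as_tuple=False).squeeze(1)
if idx.numel() > 0:
    with torch.no_grad():
        inputs_embeds[idx] = embedding(input_ids[idx])
```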