From 216fc0e8e444a05a765116c940d15520da82378c Mon Sep 17 00:00:00 2001
From: Song Zhixin
Date: Thu, 30 Oct 2025 17:15:57 +0800
Subject: [PATCH] [feature] Prompt Embeddings Support for v1 Engine (#3026)

### What this PR does / why we need it?
This PR is based on [19746](https://github.com/vllm-project/vllm/issues/19746) and adds Prompt Embeddings support for the v1 engine on NPU.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
```python
python examples/prompt_embed_inference.py
```

- vLLM version: v0.11.0
- vLLM main: https://github.com/vllm-project/vllm/commit/releases/v0.11.1

---------

Signed-off-by: jesse
---
 .github/workflows/_e2e_test.yaml           |   1 +
 examples/prompt_embed_inference.py         |  97 +++++++++
 .../test_completion_with_prompt_embeds.py  | 197 ++++++++++++++++++
 vllm_ascend/worker/model_runner_v1.py      | 115 +++++++++-
 vllm_ascend/worker/npu_input_batch.py      |  54 ++++-
 5 files changed, 447 insertions(+), 17 deletions(-)
 create mode 100644 examples/prompt_embed_inference.py
 create mode 100644 tests/e2e/singlecard/test_completion_with_prompt_embeds.py

diff --git a/.github/workflows/_e2e_test.yaml b/.github/workflows/_e2e_test.yaml
index c43fe7c0..e70e63c3 100644
--- a/.github/workflows/_e2e_test.yaml
+++ b/.github/workflows/_e2e_test.yaml
@@ -88,6 +88,7 @@ jobs:
         # We found that if running aclgraph tests in batch, it will cause AclmdlRICaptureBegin error. So we run
         # the test separately.
+        pytest -sv tests/e2e/singlecard/test_completion_with_prompt_embeds.py
         pytest -sv tests/e2e/singlecard/test_aclgraph.py
         pytest -sv tests/e2e/singlecard/test_ascend_scheduler.py
         pytest -sv tests/e2e/singlecard/test_bge_model.py
diff --git a/examples/prompt_embed_inference.py b/examples/prompt_embed_inference.py
new file mode 100644
index 00000000..67c9741d
--- /dev/null
+++ b/examples/prompt_embed_inference.py
@@ -0,0 +1,97 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Demonstrates how to generate prompt embeddings using
+Hugging Face Transformers and use them as input to vLLM
+for both single and batch inference.
+
+Model: meta-llama/Llama-3.2-1B-Instruct
+Note: This model is gated on Hugging Face Hub.
+      You must request access to use it:
+      https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct
+
+Requirements:
+- vLLM
+- transformers
+
+Run:
+    python examples/prompt_embed_inference.py
+"""
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer
+
+from vllm import LLM
+
+
+def init_tokenizer_and_llm(model_name: str):
+    llm = LLM(model=model_name, enable_prompt_embeds=True)
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    transformers_model = AutoModelForCausalLM.from_pretrained(model_name)
+    embedding_layer = transformers_model.get_input_embeddings()
+    return tokenizer, embedding_layer, llm
+
+
+def get_prompt_embeds(
+    chat: list[dict[str, str]],
+    tokenizer: PreTrainedTokenizer,
+    embedding_layer: torch.nn.Module,
+):
+    token_ids = tokenizer.apply_chat_template(
+        chat, add_generation_prompt=True, return_tensors="pt"
+    )
+    prompt_embeds = embedding_layer(token_ids).squeeze(0)
+    return prompt_embeds
+
+
+def single_prompt_inference(
+    llm: LLM, tokenizer: PreTrainedTokenizer, embedding_layer: torch.nn.Module
+):
+    chat = [{"role": "user", "content": "Please tell me about the capital of France."}]
+    prompt_embeds = get_prompt_embeds(chat, tokenizer, embedding_layer)
+
+    outputs = llm.generate(
+        {
+            "prompt_embeds": prompt_embeds,
+        }
+    )
+
+    print("\n[Single Inference Output]")
+    print("-" * 30)
+    for o in outputs:
+        print(o.outputs[0].text)
+    print("-" * 30)
+
+
+def batch_prompt_inference(
+    llm: LLM, tokenizer: PreTrainedTokenizer, embedding_layer: torch.nn.Module
+):
+    chats = [
+        [{"role": "user", "content": "Please tell me about the capital of France."}],
+        [{"role": "user", "content": "When is the day longest during the year?"}],
+        [{"role": "user", "content": "Where is bigger, the moon or the sun?"}],
+    ]
+
+    prompt_embeds_list = [
+        get_prompt_embeds(chat, tokenizer, embedding_layer) for chat in chats
+    ]
+
+    outputs = llm.generate([{"prompt_embeds": embeds} for embeds in prompt_embeds_list])
+
+    print("\n[Batch Inference Outputs]")
+    print("-" * 30)
+    for i, o in enumerate(outputs):
+        print(f"Q{i + 1}: {chats[i][0]['content']}")
+        print(f"A{i + 1}: {o.outputs[0].text}\n")
+    print("-" * 30)
+
+
+def main():
+    model_name = "meta-llama/Llama-3.2-1B-Instruct"
+    tokenizer, embedding_layer, llm = init_tokenizer_and_llm(model_name)
+    single_prompt_inference(llm, tokenizer, embedding_layer)
+    batch_prompt_inference(llm, tokenizer, embedding_layer)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/e2e/singlecard/test_completion_with_prompt_embeds.py b/tests/e2e/singlecard/test_completion_with_prompt_embeds.py
new file mode 100644
index 00000000..b72dc0d0
--- /dev/null
+++ b/tests/e2e/singlecard/test_completion_with_prompt_embeds.py
@@ -0,0 +1,197 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2023 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+# +import os + +import pytest +from transformers import AutoModelForCausalLM, AutoTokenizer + +from tests.e2e.conftest import VllmRunner + +os.environ["VLLM_USE_MODELSCOPE"] = "True" +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + +MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"] + + +def get_prompt_embeds(chat, tokenizer, embedding_layer): + """Convert chat messages to prompt embeddings.""" + token_ids = tokenizer.apply_chat_template(chat, + add_generation_prompt=True, + return_tensors='pt') + prompt_embeds = embedding_layer(token_ids).squeeze(0) + return prompt_embeds + + +@pytest.mark.parametrize("model_name", MODELS) +def test_single_prompt_embeds_inference(model_name): + """Test single prompt inference with prompt embeddings.""" + # Prepare prompt embeddings + tokenizer = AutoTokenizer.from_pretrained(model_name) + transformers_model = AutoModelForCausalLM.from_pretrained(model_name) + embedding_layer = transformers_model.get_input_embeddings() + + chat = [{ + "role": "user", + "content": "Please tell me about the capital of France." + }] + prompt_embeds = get_prompt_embeds(chat, tokenizer, embedding_layer) + + # Run inference with prompt embeddings + with VllmRunner( + model_name, + enable_prompt_embeds=True, + enforce_eager=True, + ) as vllm_runner: + outputs = vllm_runner.model.generate({ + "prompt_embeds": prompt_embeds, + }) + + # Verify output + assert len(outputs) == 1 + assert len(outputs[0].outputs) > 0 + assert len(outputs[0].outputs[0].text) > 0 + print(f"\n[Single Inference Output]: {outputs[0].outputs[0].text}") + + +@pytest.mark.parametrize("model_name", MODELS) +def test_batch_prompt_embeds_inference(model_name): + """Test batch prompt inference with prompt embeddings.""" + # Prepare prompt embeddings + tokenizer = AutoTokenizer.from_pretrained(model_name) + transformers_model = AutoModelForCausalLM.from_pretrained(model_name) + embedding_layer = transformers_model.get_input_embeddings() + + chats = [[{ + "role": "user", + "content": "Please tell me about the capital of France." + }], + [{ + "role": "user", + "content": "When is the day longest during the year?" + }], + [{ + "role": "user", + "content": "Where is bigger, the moon or the sun?" 
+ }]] + + prompt_embeds_list = [ + get_prompt_embeds(chat, tokenizer, embedding_layer) for chat in chats + ] + + # Run batch inference with prompt embeddings + with VllmRunner( + model_name, + enable_prompt_embeds=True, + enforce_eager=True, + ) as vllm_runner: + outputs = vllm_runner.model.generate([{ + "prompt_embeds": embeds + } for embeds in prompt_embeds_list]) + + # Verify outputs + assert len(outputs) == len(chats) + for i, output in enumerate(outputs): + assert len(output.outputs) > 0 + assert len(output.outputs[0].text) > 0 + print(f"\nQ{i+1}: {chats[i][0]['content']}") + print(f"A{i+1}: {output.outputs[0].text}") + + +@pytest.mark.parametrize("model_name", MODELS) +def test_prompt_embeds_with_aclgraph(model_name): + """Test prompt embeddings with ACL graph enabled vs disabled.""" + # Prepare prompt embeddings + tokenizer = AutoTokenizer.from_pretrained(model_name) + transformers_model = AutoModelForCausalLM.from_pretrained(model_name) + embedding_layer = transformers_model.get_input_embeddings() + + chat = [{"role": "user", "content": "What is the capital of China?"}] + prompt_embeds = get_prompt_embeds(chat, tokenizer, embedding_layer) + + # Run with ACL graph enabled (enforce_eager=False) + with VllmRunner( + model_name, + enable_prompt_embeds=True, + enforce_eager=False, + ) as vllm_aclgraph_runner: + aclgraph_outputs = vllm_aclgraph_runner.model.generate({ + "prompt_embeds": + prompt_embeds, + }) + + # Run with ACL graph disabled (enforce_eager=True) + with VllmRunner( + model_name, + enable_prompt_embeds=True, + enforce_eager=True, + ) as vllm_eager_runner: + eager_outputs = vllm_eager_runner.model.generate({ + "prompt_embeds": + prompt_embeds, + }) + + # Verify both produce valid outputs + assert len(aclgraph_outputs) == 1 + assert len(eager_outputs) == 1 + assert len(aclgraph_outputs[0].outputs[0].text) > 0 + assert len(eager_outputs[0].outputs[0].text) > 0 + + print("\n[ACL Graph Output]:", aclgraph_outputs[0].outputs[0].text) + print("[Eager Output]:", eager_outputs[0].outputs[0].text) + + # Note: Outputs may differ slightly due to different execution paths, + # but both should be valid responses + + +@pytest.mark.parametrize("model_name", MODELS) +def test_mixed_prompt_embeds_and_text(model_name): + """Test mixed inputs with both prompt embeddings and text prompts.""" + # Prepare prompt embeddings for first request + tokenizer = AutoTokenizer.from_pretrained(model_name) + transformers_model = AutoModelForCausalLM.from_pretrained(model_name) + embedding_layer = transformers_model.get_input_embeddings() + + chat = [{"role": "user", "content": "What is AI?"}] + prompt_embeds = get_prompt_embeds(chat, tokenizer, embedding_layer) + + # Prepare text prompt for second request + text_prompt = "What is machine learning?" 
+ + # Run inference with mixed inputs + with VllmRunner( + model_name, + enable_prompt_embeds=True, + enforce_eager=True, + ) as vllm_runner: + # Test prompt embeddings + embeds_output = vllm_runner.model.generate({ + "prompt_embeds": + prompt_embeds, + }) + + # Test text prompt + text_output = vllm_runner.model.generate(text_prompt) + + # Verify both types of inputs work + assert len(embeds_output) == 1 + assert len(text_output) == 1 + assert len(embeds_output[0].outputs[0].text) > 0 + assert len(text_output[0].outputs[0].text) > 0 + + print("\n[Prompt Embeds Output]:", embeds_output[0].outputs[0].text) + print("[Text Prompt Output]:", text_output[0].outputs[0].text) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 1dcdb064..a14a2098 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -72,7 +72,7 @@ from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors from vllm.tasks import GenerationTask, PoolingTask, SupportedTask -from vllm.utils import cdiv +from vllm.utils import cdiv, length_from_prompt_token_ids_or_embeds from vllm.utils.jsontree import json_map_leaves from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder from vllm.v1.attention.backends.utils import ( @@ -346,11 +346,16 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.is_multimodal_model = self.model_config.is_multimodal_model self.is_pooling_model = self.model_config.pooler_config is not None - if self.is_multimodal_model: - self.inputs_embeds = torch.zeros( - (self.max_num_tokens, self.model_config.get_hidden_size()), + self.enable_prompt_embeds = self.model_config.enable_prompt_embeds + if self.is_multimodal_model or self.enable_prompt_embeds: + self.inputs_embeds = self._make_buffer( + self.max_num_tokens, + self.model_config.get_hidden_size(), dtype=self.dtype, - device=self.device) + numpy=False) + self.is_token_ids = self._make_buffer(self.max_num_tokens, + dtype=torch.bool) + # Set up Attention self.use_sparse = hasattr(self.vllm_config.model_config.hf_config, "index_topk") @@ -721,6 +726,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.requests[req_id] = CachedRequestState( req_id=req_id, prompt_token_ids=new_req_data.prompt_token_ids, + prompt_embeds=new_req_data.prompt_embeds, sampling_params=sampling_params, pooling_params=pooling_params, generator=generator, @@ -999,7 +1005,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.input_batch.num_computed_tokens_cpu[index] num_scheduled_tokens = \ scheduler_output.num_scheduled_tokens[req_id] - num_prompt_tokens = len(req.prompt_token_ids) + num_prompt_tokens = length_from_prompt_token_ids_or_embeds( + req.prompt_token_ids, req.prompt_embeds) if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens: prompt_part_len = max(0, @@ -1274,6 +1281,9 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.input_ids[:total_num_scheduled_tokens].copy_( self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True) + if self.is_multimodal_model or self.enable_prompt_embeds: + self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens) + self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens) return # Async scheduling case, where some decode requests from the previous @@ -1301,6 +1311,9 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.input_ids[:total_num_scheduled_tokens].copy_( self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True) + if 
self.is_multimodal_model or self.enable_prompt_embeds: + self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens) + self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens) if num_commmon_tokens == 0: # No requests in common with the previous iteration # So input_ids_cpu will have all the input ids. @@ -1314,6 +1327,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.input_batch.prev_sampled_token_ids[:num_commmon_tokens, 0], non_blocking=True) + self.is_token_ids.gpu[:num_commmon_tokens] = True return # Upload the index tensors asynchronously # so the scatter can be non-blocking. @@ -1481,15 +1495,61 @@ class NPUModelRunner(LoRAModelRunnerMixin): # where M is the max_model_len. token_indices = (positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1]) - + token_indices_tensor = torch.from_numpy(token_indices) # Prepare input_ids. # NOTE(woosuk): We use torch.index_select instead of np.take here # because torch.index_select is much faster than np.take for large # tensors. torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), 0, - torch.from_numpy(token_indices), + token_indices_tensor, out=self.input_ids_cpu[:total_num_scheduled_tokens]) + is_token_ids = self.input_batch.is_token_ids.flatten() + torch.index_select( + is_token_ids, + 0, + token_indices_tensor, + out=self.is_token_ids.cpu[:total_num_scheduled_tokens]) + + # Because we did not pre-allocate a massive prompt_embeds CPU tensor on + # the InputBatch, we need to fill in the prompt embeds into the expected + # spots in the GpuModelRunner's pre-allocated prompt_embeds tensor. + if self.input_batch.req_prompt_embeds and (self.is_multimodal_model or + self.enable_prompt_embeds): + output_idx = 0 + for req_idx in range(num_reqs): + num_sched = num_scheduled_tokens[req_idx] + + # Skip if this request doesn't have embeddings + if req_idx not in self.input_batch.req_prompt_embeds: + output_idx += num_sched + continue + + # Skip if no tokens scheduled + if num_sched <= 0: + output_idx += num_sched + continue + + req_embeds = self.input_batch.req_prompt_embeds[req_idx] + start_pos = self.input_batch.num_computed_tokens_cpu[req_idx] + + # Skip if trying to read beyond available embeddings + if start_pos >= req_embeds.shape[0]: + output_idx += num_sched + continue + + # Copy available embeddings + end_pos = start_pos + num_sched + actual_end = min(end_pos, req_embeds.shape[0]) + actual_num_sched = actual_end - start_pos + + if actual_num_sched > 0: + self.inputs_embeds.cpu[output_idx:output_idx + + actual_num_sched].copy_( + req_embeds[start_pos:actual_end] + ) + + output_idx += num_sched self.query_start_loc_np[0] = 0 self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens @@ -1573,9 +1633,34 @@ class NPUModelRunner(LoRAModelRunnerMixin): ) # TODO(woosuk): Avoid the copy. Optimize. - self.inputs_embeds[:total_num_scheduled_tokens].copy_( + self.inputs_embeds.gpu[:total_num_scheduled_tokens].copy_( inputs_embeds) - inputs_embeds = self.inputs_embeds[:num_input_tokens] + inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens] + input_ids = None + elif self.enable_prompt_embeds and get_pp_group().is_first_rank: + # Get the input embeddings for the tokens that are not input embeds, + # then put them into the appropriate positions. 
+ # TODO(qthequartermasterman): Since even when prompt embeds are + # enabled, (a) not all requests will use prompt embeds, and (b) + # after the initial prompt is processed, the rest of the generated + # tokens will be token ids, it is not desirable to have the + # embedding layer outside of the acl graph all the time. The v0 + # engine avoids this by "double compiling" the acl graph, once + # with input_ids and again with inputs_embeds, for all num_tokens. + # If a batch only has token ids, then including the embedding layer + # in the acl graph will be more performant (like in the else case + # below). + token_ids_idx = self.is_token_ids.gpu[:total_num_scheduled_tokens] \ + .nonzero(as_tuple=False) \ + .squeeze(1) + # Some tokens ids may need to become embeds + if token_ids_idx.numel() > 0: + token_ids = self.input_ids[token_ids_idx] + tokens_to_embeds = self.model.get_input_embeddings( + input_ids=token_ids) + self.inputs_embeds.gpu[token_ids_idx] = tokens_to_embeds + + inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens] input_ids = None else: # For text-only models, we use token ids as input. @@ -2404,6 +2489,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids + self.input_batch.is_token_ids[req_idx, + start_idx:end_idx] = True self.input_batch.num_tokens_no_spec[req_idx] = end_idx self.input_batch.num_tokens[req_idx] = end_idx req_id = self.input_batch.req_ids[req_idx] @@ -2729,7 +2816,10 @@ class NPUModelRunner(LoRAModelRunnerMixin): num_scheduled_tokens): if self.is_multimodal_model: input_ids = None - inputs_embeds = self.inputs_embeds[:num_tokens] + inputs_embeds = self.inputs_embeds.gpu[:num_tokens] + elif self.enable_prompt_embeds: + input_ids = None + inputs_embeds = self.inputs_embeds.gpu[:num_tokens] else: input_ids = self.input_ids[:num_tokens] inputs_embeds = None @@ -3996,6 +4086,9 @@ class NPUModelRunner(LoRAModelRunnerMixin): # Get metadata for this request. 
request = self.requests[req_id] + if request.prompt_token_ids is None: + # Prompt logprobs is incompatible with prompt embeddings + continue num_prompt_tokens = len(request.prompt_token_ids) prompt_token_ids = torch.tensor(request.prompt_token_ids).to( self.device, non_blocking=True) diff --git a/vllm_ascend/worker/npu_input_batch.py b/vllm_ascend/worker/npu_input_batch.py index 48c712b1..846a4b29 100644 --- a/vllm_ascend/worker/npu_input_batch.py +++ b/vllm_ascend/worker/npu_input_batch.py @@ -29,6 +29,7 @@ from vllm.multimodal.inputs import (MultiModalFeatureSpec, MultiModalKwargsItems, PlaceholderRange) from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams, SamplingType +from vllm.utils import length_from_prompt_token_ids_or_embeds from vllm.v1.outputs import LogprobsTensors from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.sample.logits_processor import (BatchUpdateBuilder, @@ -51,7 +52,7 @@ else: class CachedRequestState: req_id: str - prompt_token_ids: list[int] + prompt_token_ids: Optional[list[int]] sampling_params: Optional[SamplingParams] pooling_params: Optional[PoolingParams] generator: Optional[torch.Generator] @@ -70,9 +71,11 @@ class CachedRequestState: mm_hashes: Optional[list[PlaceholderRange]] = None lora_request: Optional[LoRARequest] = None + prompt_embeds: Optional[torch.Tensor] = None def __post_init__(self): - self.num_prompt_tokens = len(self.prompt_token_ids) + self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds( + self.prompt_token_ids, self.prompt_embeds) @property def num_tokens(self) -> int: @@ -91,6 +94,10 @@ class CachedRequestState: def get_token_id(self, idx: int) -> int: if idx < self.num_prompt_tokens: + if self.prompt_token_ids is None: + raise ValueError( + f"Tried to access token index {idx}, but that token was " + "provided via prompt_embeds, and its ID is unknown.") return self.prompt_token_ids[idx] elif idx - self.num_prompt_tokens < len(self.output_token_ids): return self.output_token_ids[idx - self.num_prompt_tokens] @@ -139,6 +146,14 @@ class InputBatch: pin_memory=False, ) self.token_ids_cpu = self.token_ids_cpu_tensor.numpy() + self.is_token_ids = torch.zeros((max_num_reqs, max_model_len), + device="cpu", + dtype=bool, + pin_memory=False) + # Store prompt embeddings per request to avoid OOM from large upfront + # allocation if max_model_len is big. + # Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size) + self.req_prompt_embeds: dict[int, torch.Tensor] = {} self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32) self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32) self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32) @@ -345,15 +360,23 @@ class InputBatch: self.req_id_to_index[req_id] = req_index # Copy the prompt token ids and output token ids. 
- num_prompt_tokens = len(request.prompt_token_ids) + num_prompt_tokens = length_from_prompt_token_ids_or_embeds( + request.prompt_token_ids, request.prompt_embeds) self.num_prompt_tokens[req_index] = num_prompt_tokens - self.token_ids_cpu[ - req_index, :num_prompt_tokens] = request.prompt_token_ids start_idx = num_prompt_tokens end_idx = start_idx + len(request.output_token_ids) + if request.prompt_token_ids is not None: + self.token_ids_cpu[ + req_index, :num_prompt_tokens] = request.prompt_token_ids + self.is_token_ids[req_index, :num_prompt_tokens] = True + else: + self.is_token_ids[req_index, :num_prompt_tokens] = False + if request.prompt_embeds is not None: + self.req_prompt_embeds[req_index] = request.prompt_embeds self.token_ids_cpu[req_index, start_idx:end_idx] = request.output_token_ids - # Number of token ids in token_ids_cpu. + self.is_token_ids[req_index, start_idx:end_idx] = True + # Number of token ids in prompt (token_ids_cpu or prompt_embeds). # NOTE(woosuk): This may include spec decode tokens. self.num_tokens[req_index] = request.num_tokens # Number of tokens without spec decode tokens. @@ -553,6 +576,20 @@ class InputBatch: self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...] self.token_ids_cpu[i2, ...] = tmp + self.is_token_ids[[i1, i2], ...] = self.is_token_ids[[i2, i1], ...] + + # Swap prompt embeddings if they exist + embeds_i1 = self.req_prompt_embeds.get(i1) + embeds_i2 = self.req_prompt_embeds.get(i2) + if embeds_i1 is not None: + self.req_prompt_embeds[i2] = embeds_i1 + else: + self.req_prompt_embeds.pop(i2, None) + if embeds_i2 is not None: + self.req_prompt_embeds[i1] = embeds_i2 + else: + self.req_prompt_embeds.pop(i1, None) + swap_dict_values(self.generators, i1, i2) swap_dict_values(self.bad_words_token_ids, i1, i2) @@ -631,6 +668,11 @@ class InputBatch: num_tokens = self.num_tokens[last_req_index] self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[ last_req_index, :num_tokens] + self.is_token_ids[empty_index, :num_tokens] = self.is_token_ids[ + last_req_index, :num_tokens] + if last_req_index in self.req_prompt_embeds: + self.req_prompt_embeds[ + empty_index] = self.req_prompt_embeds.pop(last_req_index) self.num_tokens[empty_index] = num_tokens self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[ last_req_index]
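
For readers skimming the diff, the core runner-side change is the gather step that copies each request's prompt embeddings into the flat `inputs_embeds` buffer for the tokens scheduled in the current step, while `is_token_ids` marks which positions still hold real token ids that must go through the embedding layer. The snippet below is a minimal, self-contained sketch of that bookkeeping only; `TinyInputBatch`, `gather_prompt_embeds`, `HIDDEN_SIZE`, and `MAX_MODEL_LEN` are hypothetical stand-ins, not names from the patch, and spec decode, multimodal inputs, and the host-to-device copy are omitted.

```python
import torch

# Hypothetical, simplified sizes for illustration only.
HIDDEN_SIZE = 64
MAX_MODEL_LEN = 32


class TinyInputBatch:
    """Minimal stand-in for InputBatch: tracks which positions hold real
    token ids and stores prompt embeddings per request index."""

    def __init__(self, max_num_reqs: int):
        self.token_ids_cpu = torch.zeros((max_num_reqs, MAX_MODEL_LEN),
                                         dtype=torch.int64)
        self.is_token_ids = torch.zeros((max_num_reqs, MAX_MODEL_LEN),
                                        dtype=torch.bool)
        # req_index -> (num_prompt_tokens, hidden_size), stored per request
        # instead of one huge (max_num_reqs, max_model_len, hidden) buffer.
        self.req_prompt_embeds: dict[int, torch.Tensor] = {}
        self.num_computed_tokens_cpu = torch.zeros(max_num_reqs,
                                                   dtype=torch.int64)


def gather_prompt_embeds(
    batch: TinyInputBatch,
    num_scheduled_tokens: list[int],
    inputs_embeds_cpu: torch.Tensor,
) -> None:
    """Fill a flat inputs_embeds buffer from per-request prompt embeddings,
    mirroring the per-request gather loop this patch adds to the runner."""
    output_idx = 0
    for req_idx, num_sched in enumerate(num_scheduled_tokens):
        req_embeds = batch.req_prompt_embeds.get(req_idx)
        if req_embeds is None or num_sched <= 0:
            # No embeddings for this request (or nothing scheduled):
            # these positions are token ids and are embedded later.
            output_idx += num_sched
            continue
        start = int(batch.num_computed_tokens_cpu[req_idx])
        end = min(start + num_sched, req_embeds.shape[0])
        if end > start:
            inputs_embeds_cpu[output_idx:output_idx + (end - start)] = \
                req_embeds[start:end]
        output_idx += num_sched


# Usage sketch: two requests, the first driven by prompt embeddings.
batch = TinyInputBatch(max_num_reqs=2)
batch.req_prompt_embeds[0] = torch.randn(5, HIDDEN_SIZE)
scheduled = [5, 3]  # request 0: 5 embed-backed tokens; request 1: 3 token ids
buf = torch.zeros(sum(scheduled), HIDDEN_SIZE)
gather_prompt_embeds(batch, scheduled, buf)
```

Keeping the embeddings in a per-request dict, as the patch's comment in npu_input_batch.py notes, avoids a large upfront (max_num_reqs, max_model_len, hidden_size) allocation; the trade-off is the explicit gather above and the extra swap/condense handling when request slots move.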