fix: Handles input_embeds in GenerateReqInput when n>1 (#7830)
Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
This commit is contained in:
@@ -200,6 +200,8 @@ class GenerateReqInput:
|
||||
self.text = [self.text]
|
||||
if self.input_ids is not None:
|
||||
self.input_ids = [self.input_ids]
|
||||
if self.input_embeds is not None:
|
||||
self.input_embeds = [self.input_embeds]
|
||||
|
||||
def _normalize_single_inputs(self):
|
||||
"""Normalize inputs for a single example."""
|
||||
@@ -324,7 +326,9 @@ class GenerateReqInput:
|
||||
new_rids = [f"{self.rid}_{i}" for i in range(num)]
|
||||
self.rid = new_rids
|
||||
elif isinstance(self.rid, list):
|
||||
if len(self.rid) != num:
|
||||
# Note: the length of rid shall be the same as the batch_size,
|
||||
# as the rid would be expanded for parallel sampling in tokenizer_manager
|
||||
if len(self.rid) != self.batch_size:
|
||||
raise ValueError(
|
||||
"The specified rids length mismatch with the batch_size for batch processing."
|
||||
)
|
||||
@@ -400,6 +404,9 @@ class GenerateReqInput:
|
||||
return GenerateReqInput(
|
||||
text=self.text[i] if self.text is not None else None,
|
||||
input_ids=self.input_ids[i] if self.input_ids is not None else None,
|
||||
input_embeds=(
|
||||
self.input_embeds[i] if self.input_embeds is not None else None
|
||||
),
|
||||
image_data=self.image_data[i],
|
||||
audio_data=self.audio_data[i],
|
||||
sampling_params=self.sampling_params[i],
|
||||
|
||||
Reference in New Issue
Block a user