Input_embeds support (#2052)

This commit is contained in:
Rin Intachuen
2024-11-25 19:35:04 -05:00
committed by GitHub
parent 1f76fc6e3f
commit 1aea19f64b
9 changed files with 204 additions and 15 deletions

View File

@@ -29,8 +29,10 @@ from sglang.srt.sampling.sampling_params import SamplingParams
class GenerateReqInput:
# The input prompt. It can be a single prompt or a batch of prompts.
text: Optional[Union[List[str], str]] = None
# The token ids for text; one can either specify text or input_ids.
# The token ids for text; one can specify either text or input_ids.
input_ids: Optional[Union[List[List[int]], List[int]]] = None
# The embeddings for input_ids; one can specify only one of text, input_ids, or input_embeds.
input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
# The image input. It can be a file name, a url, or base64 encoded string.
# See also python/sglang/srt/utils.py:load_image.
image_data: Optional[Union[List[str], str]] = None
@@ -60,10 +62,16 @@ class GenerateReqInput:
] = None
def normalize_batch_and_arguments(self):
if (self.text is None and self.input_ids is None) or (
self.text is not None and self.input_ids is not None
if (
self.text is None and self.input_ids is None and self.input_embeds is None
) or (
self.text is not None
and self.input_ids is not None
and self.input_embeds is not None
):
raise ValueError("Either text or input_ids should be provided.")
raise ValueError(
"Either text, input_ids or input_embeds should be provided."
)
# Derive the batch size
if self.text is not None:
@@ -73,13 +81,21 @@ class GenerateReqInput:
else:
self.is_single = False
self.batch_size = len(self.text)
else:
self.input_embeds = None
elif self.input_ids is not None:
if isinstance(self.input_ids[0], int):
self.is_single = True
self.batch_size = 1
else:
self.is_single = False
self.batch_size = len(self.input_ids)
self.input_embeds = None
else:
if isinstance(self.input_embeds[0][0], float):
self.is_single = True
self.batch_size = 1
else:
self.batch_size = len(self.input_embeds)
# Handle parallel sampling
# When parallel sampling is used, we always treat the input as a batch.
@@ -202,6 +218,8 @@ class TokenizedGenerateReqInput:
# LoRA related
lora_path: Optional[str] = None # None means just use the base model
# The input embeds
input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
# Session id info for continual prompting
session_id: Optional[str] = None
@@ -218,6 +236,8 @@ class EmbeddingReqInput:
rid: Optional[Union[List[str], str]] = None
# Dummy sampling params for compatibility
sampling_params: Union[List[Dict], Dict] = None
# Dummy input embeds for compatibility
input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
def normalize_batch_and_arguments(self):
if (self.text is None and self.input_ids is None) or (