Input_embeds support (#2052)
This commit is contained in:
@@ -29,8 +29,10 @@ from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
class GenerateReqInput:
|
||||
# The input prompt. It can be a single prompt or a batch of prompts.
|
||||
text: Optional[Union[List[str], str]] = None
|
||||
# The token ids for text; one can either specify text or input_ids.
|
||||
# The token ids for text; one can specify either text or input_ids
|
||||
input_ids: Optional[Union[List[List[int]], List[int]]] = None
|
||||
# The embeddings for input_ids; one can specify any one of text, input_ids, or input_embeds.
|
||||
input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
|
||||
# The image input. It can be a file name, a url, or base64 encoded string.
|
||||
# See also python/sglang/srt/utils.py:load_image.
|
||||
image_data: Optional[Union[List[str], str]] = None
|
||||
@@ -60,10 +62,16 @@ class GenerateReqInput:
|
||||
] = None
|
||||
|
||||
def normalize_batch_and_arguments(self):
|
||||
if (self.text is None and self.input_ids is None) or (
|
||||
self.text is not None and self.input_ids is not None
|
||||
if (
|
||||
self.text is None and self.input_ids is None and self.input_embeds is None
|
||||
) or (
|
||||
self.text is not None
|
||||
and self.input_ids is not None
|
||||
and self.input_embeds is not None
|
||||
):
|
||||
raise ValueError("Either text or input_ids should be provided.")
|
||||
raise ValueError(
|
||||
"Either text, input_ids or input_embeds should be provided."
|
||||
)
|
||||
|
||||
# Derive the batch size
|
||||
if self.text is not None:
|
||||
@@ -73,13 +81,21 @@ class GenerateReqInput:
|
||||
else:
|
||||
self.is_single = False
|
||||
self.batch_size = len(self.text)
|
||||
else:
|
||||
self.input_embeds = None
|
||||
elif self.input_ids is not None:
|
||||
if isinstance(self.input_ids[0], int):
|
||||
self.is_single = True
|
||||
self.batch_size = 1
|
||||
else:
|
||||
self.is_single = False
|
||||
self.batch_size = len(self.input_ids)
|
||||
self.input_embeds = None
|
||||
else:
|
||||
if isinstance(self.input_embeds[0][0], float):
|
||||
self.is_single = True
|
||||
self.batch_size = 1
|
||||
else:
|
||||
self.batch_size = len(self.input_embeds)
|
||||
|
||||
# Handle parallel sampling
|
||||
# When parallel sampling is used, we always treat the input as a batch.
|
||||
@@ -202,6 +218,8 @@ class TokenizedGenerateReqInput:
|
||||
|
||||
# LoRA related
|
||||
lora_path: Optional[str] = None # None means just use the base model
|
||||
# The input embeds
|
||||
input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
|
||||
|
||||
# Session id info for continual prompting
|
||||
session_id: Optional[str] = None
|
||||
@@ -218,6 +236,8 @@ class EmbeddingReqInput:
|
||||
rid: Optional[Union[List[str], str]] = None
|
||||
# Dummy sampling params for compatibility
|
||||
sampling_params: Union[List[Dict], Dict] = None
|
||||
# Dummy input embeds for compatibility
|
||||
input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
|
||||
|
||||
def normalize_batch_and_arguments(self):
|
||||
if (self.text is None and self.input_ids is None) or (
|
||||
|
||||
Reference in New Issue
Block a user