Input_embeds support (#2052)
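
For context, a request that supplies precomputed embeddings instead of text could look like the sketch below. This is a minimal illustration, not part of the commit: the host, port, and hidden size (4) are assumed, and the server must be launched with --disable-radix-cache, as required by the TokenizerManager change further down. The payload fields mirror GenerateReqInput.

    # Sketch of a /generate call carrying input_embeds (toy hidden size of 4).
    import requests

    input_embeds = [
        [0.10, 0.20, 0.30, 0.40],  # one embedding vector per input token
        [0.50, 0.60, 0.70, 0.80],
    ]

    response = requests.post(
        "http://localhost:30000/generate",  # assumed default host/port
        json={
            "input_embeds": input_embeds,
            "sampling_params": {"max_new_tokens": 16, "temperature": 0},
        },
    )
    print(response.json())
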
@@ -29,8 +29,10 @@ from sglang.srt.sampling.sampling_params import SamplingParams
 class GenerateReqInput:
     # The input prompt. It can be a single prompt or a batch of prompts.
     text: Optional[Union[List[str], str]] = None
-    # The token ids for text; one can either specify text or input_ids.
+    # The token ids for text; one can specify either text or input_ids
     input_ids: Optional[Union[List[List[int]], List[int]]] = None
+    # The embeddings for input_ids; one can specify either text or input_ids or input_embeds.
+    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
     # The image input. It can be a file name, a url, or base64 encoded string.
     # See also python/sglang/srt/utils.py:load_image.
     image_data: Optional[Union[List[str], str]] = None
@@ -60,10 +62,16 @@ class GenerateReqInput:
     ] = None

     def normalize_batch_and_arguments(self):
-        if (self.text is None and self.input_ids is None) or (
-            self.text is not None and self.input_ids is not None
+        if (
+            self.text is None and self.input_ids is None and self.input_embeds is None
+        ) or (
+            self.text is not None
+            and self.input_ids is not None
+            and self.input_embeds is not None
         ):
-            raise ValueError("Either text or input_ids should be provided.")
+            raise ValueError(
+                "Either text, input_ids or input_embeds should be provided."
+            )

         # Derive the batch size
         if self.text is not None:
@@ -73,13 +81,21 @@ class GenerateReqInput:
             else:
                 self.is_single = False
                 self.batch_size = len(self.text)
-        else:
+            self.input_embeds = None
+        elif self.input_ids is not None:
             if isinstance(self.input_ids[0], int):
                 self.is_single = True
                 self.batch_size = 1
             else:
                 self.is_single = False
                 self.batch_size = len(self.input_ids)
+            self.input_embeds = None
+        else:
+            if isinstance(self.input_embeds[0][0], float):
+                self.is_single = True
+                self.batch_size = 1
+            else:
+                self.batch_size = len(self.input_embeds)

         # Handle parallel sampling
         # When parallel sampling is used, we always treat the input as a batch.
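
As the batch-size derivation above shows, a single request passes input_embeds as a 2-D list (tokens x hidden) and a batch passes a 3-D list (requests x tokens x hidden). A toy sketch of the two accepted shapes (hidden size 2 is arbitrary, and the import path assumes the io_struct module this diff modifies):

    from sglang.srt.managers.io_struct import GenerateReqInput

    # Single request: List[List[float]] -> batch_size == 1
    single = GenerateReqInput(input_embeds=[[0.1, 0.2], [0.3, 0.4]])

    # Batch of two requests: List[List[List[float]]] -> batch_size == 2
    batch = GenerateReqInput(
        input_embeds=[
            [[0.1, 0.2], [0.3, 0.4]],
            [[0.5, 0.6]],
        ]
    )
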
@@ -202,6 +218,8 @@ class TokenizedGenerateReqInput:

     # LoRA related
     lora_path: Optional[str] = None  # None means just use the base model
+    # The input embeds
+    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None

     # Session id info for continual prompting
     session_id: Optional[str] = None
@@ -218,6 +236,8 @@ class EmbeddingReqInput:
     rid: Optional[Union[List[str], str]] = None
     # Dummy sampling params for compatibility
     sampling_params: Union[List[Dict], Dict] = None
+    # Dummy input embeds for compatibility
+    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None

     def normalize_batch_and_arguments(self):
         if (self.text is None and self.input_ids is None) or (
@@ -178,6 +178,7 @@ class Req:
         origin_input_ids: Tuple[int],
         sampling_params: SamplingParams,
         lora_path: Optional[str] = None,
+        input_embeds: Optional[List[List[float]]] = None,
         session_id: Optional[str] = None,
     ):
         # Input and output info
@@ -191,6 +192,7 @@ class Req:

         self.sampling_params = sampling_params
         self.lora_path = lora_path
+        self.input_embeds = input_embeds

         # Memory pool info
         self.req_pool_idx = None
@@ -448,6 +450,7 @@ class ScheduleBatch:

     # Batched arguments to model runner
     input_ids: torch.Tensor = None
+    input_embeds: torch.Tensor = None
     req_pool_indices: torch.Tensor = None
     seq_lens: torch.Tensor = None
     # The output locations of the KV cache
@@ -631,6 +634,9 @@ class ScheduleBatch:
         req_pool_indices = self.alloc_req_slots(bs)
         out_cache_loc = self.alloc_token_slots(extend_num_tokens)
+
+        input_embeds = []
+
         pt = 0
         for i, req in enumerate(reqs):
             already_computed = (
                 req.extend_logprob_start_len + 1 + req.cached_tokens
@@ -649,6 +655,11 @@ class ScheduleBatch:
                 (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
             )

+            # If input_embeds are available, store them
+            if req.input_embeds is not None:
+                # If req.input_embeds is already a list, append its content directly
+                input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting
+
             # Compute the relative logprob_start_len in an extend batch
             if req.logprob_start_len >= pre_len:
                 extend_logprob_start_len = min(
@@ -671,6 +682,12 @@ class ScheduleBatch:
         self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
             self.device, non_blocking=True
         )
+        self.input_embeds = (
+            torch.tensor(input_embeds).to(self.device, non_blocking=True)
+            if input_embeds
+            else None
+        )
+
         self.out_cache_loc = out_cache_loc

         self.seq_lens_sum = sum(seq_lens)
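
To make the assembly above concrete: the per-request embedding lists are flattened token-wise, so the batch tensor has one row per extend token across all requests. A standalone toy sketch (hidden size 3 is arbitrary):

    import torch

    # Two requests in one extend batch: 3 tokens and 2 tokens, hidden size 3.
    req0_embeds = [[0.1, 0.1, 0.1], [0.2, 0.2, 0.2], [0.3, 0.3, 0.3]]
    req1_embeds = [[0.4, 0.4, 0.4], [0.5, 0.5, 0.5]]

    input_embeds = []
    input_embeds.extend(req0_embeds)  # same pattern as the per-request loop above
    input_embeds.extend(req1_embeds)

    batch_embeds = torch.tensor(input_embeds)
    print(batch_embeds.shape)  # torch.Size([5, 3]): one row per extend token
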
@@ -1053,6 +1070,7 @@ class ScheduleBatch:
             encoder_out_cache_loc=self.encoder_out_cache_loc,
             lora_paths=[req.lora_path for req in self.reqs],
             sampling_info=self.sampling_info,
+            input_embeds=self.input_embeds,
         )

     def copy(self):
@@ -1123,6 +1141,9 @@ class ModelWorkerBatch:
     # Sampling info
     sampling_info: SamplingBatchInfo

+    # The input Embeds
+    input_embeds: Optional[torch.tensor] = None
+

 @triton.jit
 def write_req_to_token_pool_triton(
@@ -526,12 +526,20 @@ class Scheduler:
         recv_req: TokenizedGenerateReqInput,
     ):
         if recv_req.session_id is None or recv_req.session_id not in self.sessions:
+            # Check if input_embeds is present and create dummy input_ids
+            if recv_req.input_embeds is not None:
+                # Generate fake input_ids based on the length of input_embeds
+                seq_length = len(recv_req.input_embeds)
+                fake_input_ids = [1] * seq_length
+                recv_req.input_ids = fake_input_ids
+
             req = Req(
                 recv_req.rid,
                 recv_req.input_text,
                 recv_req.input_ids,
                 recv_req.sampling_params,
                 lora_path=recv_req.lora_path,
+                input_embeds=recv_req.input_embeds,
             )
             req.tokenizer = self.tokenizer
         if recv_req.session_id is not None:
@@ -201,8 +201,18 @@ class TokenizerManager:
     ):
         """Tokenize one request."""
         # Tokenize
+        input_embeds = None
         input_text = obj.text
-        if obj.input_ids is None:
+        if obj.input_embeds is not None:
+            if not self.server_args.disable_radix_cache:
+                raise ValueError(
+                    "input_embeds is provided while disable_radix_cache is False. "
+                    "Please add `--disable-radix-cache` when you launch the server "
+                    "if you want to use input_embeds as inputs."
+                )
+            input_embeds = obj.input_embeds
+            input_ids = obj.input_ids
+        elif obj.input_ids is None:
             input_ids = self.tokenizer.encode(input_text)
         else:
             input_ids = obj.input_ids
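
One way to produce embeddings in the expected format (a sketch under assumptions, not part of this commit): take the token-embedding layer of the same checkpoint the server runs, e.g. via Hugging Face transformers, and convert the result to a nested float list. The model name below is a placeholder.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "your-model-checkpoint"  # placeholder; use the server's checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    input_ids = tokenizer("Hello, world!", return_tensors="pt").input_ids
    with torch.no_grad():
        # [1, seq_len, hidden] -> List[List[float]] as expected by input_embeds
        input_embeds = model.get_input_embeddings()(input_ids)[0].float().tolist()
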
@@ -219,7 +229,7 @@ class TokenizerManager:
         session_id = obj.session[0] if obj.session else None
         session_rid = obj.session[1] if obj.session else None

-        if len(input_ids) >= self.context_len:
+        if obj.input_ids is not None and len(input_ids) >= self.context_len:
             raise ValueError(
                 f"The input ({len(input_ids)} tokens) is longer than the "
                 f"model's context length ({self.context_len} tokens)."
@@ -242,7 +252,8 @@ class TokenizerManager:
             logprob_start_len,
             top_logprobs_num,
             obj.stream,
-            obj.lora_path,
+            lora_path=obj.lora_path,
+            input_embeds=input_embeds,
             session_id=session_id,
             session_rid=session_rid,
         )
@@ -130,6 +130,9 @@ class ForwardBatch:
     # For LoRA
     lora_paths: Optional[List[str]] = None

+    # For input embeddings
+    input_embeds: Optional[torch.tensor] = None
+
     # Sampling info
     sampling_info: SamplingBatchInfo = None
@@ -231,6 +234,7 @@ class ForwardBatch:
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             lora_paths=batch.lora_paths,
             sampling_info=batch.sampling_info,
+            input_embeds=batch.input_embeds,
         )

         if ret.global_num_tokens is not None:
@@ -606,9 +606,17 @@ class ModelRunner:
     def forward_extend(self, forward_batch: ForwardBatch):
         self.attn_backend.init_forward_metadata(forward_batch)
         if self.is_generation:
-            return self.model.forward(
-                forward_batch.input_ids, forward_batch.positions, forward_batch
-            )
+            if forward_batch.input_embeds is None:
+                return self.model.forward(
+                    forward_batch.input_ids, forward_batch.positions, forward_batch
+                )
+            else:
+                return self.model.forward(
+                    forward_batch.input_ids,
+                    forward_batch.positions,
+                    forward_batch,
+                    input_embeds=forward_batch.input_embeds.bfloat16(),
+                )
         else:
             # Only embedding models have get_embedding parameter
             return self.model.forward(
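
For the dispatch above to work, the target model's forward has to accept an input_embeds keyword and skip its own token-embedding lookup when embeddings are provided. A minimal, illustrative sketch of that contract (names and shapes are hypothetical, not the actual sglang model code):

    import torch
    from torch import nn

    class ToyModel(nn.Module):
        def __init__(self, vocab_size: int = 32, hidden_size: int = 8):
            super().__init__()
            self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
            self.lm_head = nn.Linear(hidden_size, vocab_size)

        def forward(self, input_ids, positions, forward_batch=None, input_embeds=None):
            # Use caller-provided embeddings when present; otherwise look them up.
            if input_embeds is None:
                hidden = self.embed_tokens(input_ids)
            else:
                hidden = input_embeds
            return self.lm_head(hidden)
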