Input_embeds support (#2052)

This commit is contained in:
Rin Intachuen
2024-11-25 19:35:04 -05:00
committed by GitHub
parent 1f76fc6e3f
commit 1aea19f64b
9 changed files with 204 additions and 15 deletions

View File

@@ -201,8 +201,18 @@ class TokenizerManager:
):
"""Tokenize one request."""
# Tokenize
input_embeds = None
input_text = obj.text
if obj.input_ids is None:
if obj.input_embeds is not None:
if not self.server_args.disable_radix_cache:
raise ValueError(
"input_embeds is provided while disable_radix_cache is False. "
"Please add `--disable-radix-cache` when you launch the server "
"if you want to use input_embeds as inputs."
)
input_embeds = obj.input_embeds
input_ids = obj.input_ids
elif obj.input_ids is None:
input_ids = self.tokenizer.encode(input_text)
else:
input_ids = obj.input_ids
@@ -219,7 +229,7 @@ class TokenizerManager:
session_id = obj.session[0] if obj.session else None
session_rid = obj.session[1] if obj.session else None
if len(input_ids) >= self.context_len:
if obj.input_ids is not None and len(input_ids) >= self.context_len:
raise ValueError(
f"The input ({len(input_ids)} tokens) is longer than the "
f"model's context length ({self.context_len} tokens)."
@@ -242,7 +252,8 @@ class TokenizerManager:
logprob_start_len,
top_logprobs_num,
obj.stream,
obj.lora_path,
lora_path=obj.lora_path,
input_embeds=input_embeds,
session_id=session_id,
session_rid=session_rid,
)