Input_embeds support (#2052)
This commit is contained in:
@@ -201,8 +201,18 @@ class TokenizerManager:
|
||||
):
|
||||
"""Tokenize one request."""
|
||||
# Tokenize
|
||||
input_embeds = None
|
||||
input_text = obj.text
|
||||
if obj.input_ids is None:
|
||||
if obj.input_embeds is not None:
|
||||
if not self.server_args.disable_radix_cache:
|
||||
raise ValueError(
|
||||
"input_embeds is provided while disable_radix_cache is False. "
|
||||
"Please add `--disable-radix-cach` when you launch the server "
|
||||
"if you want to use input_embeds as inputs."
|
||||
)
|
||||
input_embeds = obj.input_embeds
|
||||
input_ids = obj.input_ids
|
||||
elif obj.input_ids is None:
|
||||
input_ids = self.tokenizer.encode(input_text)
|
||||
else:
|
||||
input_ids = obj.input_ids
|
||||
@@ -219,7 +229,7 @@ class TokenizerManager:
|
||||
session_id = obj.session[0] if obj.session else None
|
||||
session_rid = obj.session[1] if obj.session else None
|
||||
|
||||
if len(input_ids) >= self.context_len:
|
||||
if obj.input_ids is not None and len(input_ids) >= self.context_len:
|
||||
raise ValueError(
|
||||
f"The input ({len(input_ids)} tokens) is longer than the "
|
||||
f"model's context length ({self.context_len} tokens)."
|
||||
@@ -242,7 +252,8 @@ class TokenizerManager:
|
||||
logprob_start_len,
|
||||
top_logprobs_num,
|
||||
obj.stream,
|
||||
obj.lora_path,
|
||||
lora_path=obj.lora_path,
|
||||
input_embeds=input_embeds,
|
||||
session_id=session_id,
|
||||
session_rid=session_rid,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user