Allow input_ids in the input of the /generate endpoint (#363)

This commit is contained in:
Shannon Shen
2024-05-12 12:29:00 -10:00
committed by GitHub
parent 6e09cf6a15
commit 04c0b21488
5 changed files with 43 additions and 10 deletions

View File

@@ -147,11 +147,15 @@ class TokenizerManager:
if self.to_create_loop:
await self.create_handle_loop()
-is_single = isinstance(obj.text, str)
+is_single = obj.is_single
if is_single:
rid = obj.rid
-input_ids = self.tokenizer.encode(obj.text)
+if obj.input_ids is None:
+input_ids = self.tokenizer.encode(obj.text)
+else:
+input_ids = obj.input_ids
sampling_params = SamplingParams(**obj.sampling_params)
if sampling_params.max_new_tokens != 0:
sampling_params.normalize(self.tokenizer)
@@ -204,10 +208,22 @@ class TokenizerManager:
event.clear()
else:
assert obj.stream is False
-bs = len(obj.text)
+if obj.input_ids is None:
+bs = len(obj.text)
+else:
+bs = len(obj.input_ids)
for i in range(bs):
rid = obj.rid[i]
-input_ids = self.tokenizer.encode(obj.text[i])
+if obj.input_ids is None:
+input_text = obj.text[i]
+input_ids = self.tokenizer.encode(obj.text[i])
+else:
+input_text = None
+input_ids = obj.input_ids[i]
sampling_params = SamplingParams(**obj.sampling_params[i])
if sampling_params.max_new_tokens != 0:
sampling_params.normalize(self.tokenizer)
@@ -220,7 +236,7 @@ class TokenizerManager:
)
tokenized_obj = TokenizedGenerateReqInput(
rid=rid,
-input_text=obj.text[i],
+input_text=input_text,
input_ids=input_ids,
pixel_values=pixel_values,
image_hash=image_hash,