Allow input_ids in the input of the /generate endpoint (#363)
This commit is contained in:
@@ -147,11 +147,15 @@ class TokenizerManager:
|
||||
if self.to_create_loop:
|
||||
await self.create_handle_loop()
|
||||
|
||||
is_single = isinstance(obj.text, str)
|
||||
|
||||
is_single = obj.is_single
|
||||
if is_single:
|
||||
rid = obj.rid
|
||||
input_ids = self.tokenizer.encode(obj.text)
|
||||
|
||||
if obj.input_ids is None:
|
||||
input_ids = self.tokenizer.encode(obj.text)
|
||||
else:
|
||||
input_ids = obj.input_ids
|
||||
|
||||
sampling_params = SamplingParams(**obj.sampling_params)
|
||||
if sampling_params.max_new_tokens != 0:
|
||||
sampling_params.normalize(self.tokenizer)
|
||||
@@ -204,10 +208,22 @@ class TokenizerManager:
|
||||
event.clear()
|
||||
else:
|
||||
assert obj.stream is False
|
||||
bs = len(obj.text)
|
||||
|
||||
if obj.input_ids is None:
|
||||
bs = len(obj.text)
|
||||
else:
|
||||
bs = len(obj.input_ids)
|
||||
|
||||
for i in range(bs):
|
||||
rid = obj.rid[i]
|
||||
input_ids = self.tokenizer.encode(obj.text[i])
|
||||
|
||||
if obj.input_ids is None:
|
||||
input_text = obj.text[i]
|
||||
input_ids = self.tokenizer.encode(obj.text[i])
|
||||
else:
|
||||
input_text = None
|
||||
input_ids = obj.input_ids[i]
|
||||
|
||||
sampling_params = SamplingParams(**obj.sampling_params[i])
|
||||
if sampling_params.max_new_tokens != 0:
|
||||
sampling_params.normalize(self.tokenizer)
|
||||
@@ -220,7 +236,7 @@ class TokenizerManager:
|
||||
)
|
||||
tokenized_obj = TokenizedGenerateReqInput(
|
||||
rid=rid,
|
||||
input_text=obj.text[i],
|
||||
input_text=input_text,
|
||||
input_ids=input_ids,
|
||||
pixel_values=pixel_values,
|
||||
image_hash=image_hash,
|
||||
|
||||
Reference in New Issue
Block a user