Allow input_ids in the input of the /generate endpoint (#363)
This commit is contained in:
@@ -30,7 +30,7 @@ if __name__ == "__main__":
|
|||||||
response = requests.post(
|
response = requests.post(
|
||||||
url + "/generate",
|
url + "/generate",
|
||||||
json={
|
json={
|
||||||
"text": f"{a}, ",
|
"input_ids": [[1,2,3], [1,2,3]],
|
||||||
"sampling_params": {
|
"sampling_params": {
|
||||||
"temperature": 0,
|
"temperature": 0,
|
||||||
"max_new_tokens": max_new_tokens,
|
"max_new_tokens": max_new_tokens,
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ The `/generate` endpoint accepts the following arguments in the JSON format.
|
|||||||
class GenerateReqInput:
|
class GenerateReqInput:
|
||||||
# The input prompt
|
# The input prompt
|
||||||
text: Union[List[str], str]
|
text: Union[List[str], str]
|
||||||
|
# The token ids for text; one can either specify text or input_ids
|
||||||
|
input_ids: Optional[Union[List[List[int]], List[int]]] = None
|
||||||
# The image input
|
# The image input
|
||||||
image_data: Optional[Union[List[str], str]] = None
|
image_data: Optional[Union[List[str], str]] = None
|
||||||
# The sampling_params
|
# The sampling_params
|
||||||
|
|||||||
@@ -8,7 +8,9 @@ from sglang.srt.sampling_params import SamplingParams
|
|||||||
@dataclass
|
@dataclass
|
||||||
class GenerateReqInput:
|
class GenerateReqInput:
|
||||||
# The input prompt
|
# The input prompt
|
||||||
text: Union[List[str], str]
|
text: Optional[Union[List[str], str]] = None
|
||||||
|
# The token ids for text; one can either specify text or input_ids
|
||||||
|
input_ids: Optional[Union[List[List[int]], List[int]]] = None
|
||||||
# The image input
|
# The image input
|
||||||
image_data: Optional[Union[List[str], str]] = None
|
image_data: Optional[Union[List[str], str]] = None
|
||||||
# The sampling_params
|
# The sampling_params
|
||||||
@@ -28,7 +30,17 @@ class GenerateReqInput:
|
|||||||
# TODO: make all parameters a Union[List[T], T] to allow for batched requests
|
# TODO: make all parameters a Union[List[T], T] to allow for batched requests
|
||||||
|
|
||||||
def post_init(self):
|
def post_init(self):
|
||||||
is_single = isinstance(self.text, str)
|
|
||||||
|
if self.text is None:
|
||||||
|
assert self.input_ids is not None, "Either text or input_ids should be provided"
|
||||||
|
else:
|
||||||
|
assert self.input_ids is None, "Either text or input_ids should be provided"
|
||||||
|
|
||||||
|
if self.text is not None:
|
||||||
|
is_single = isinstance(self.text, str)
|
||||||
|
else:
|
||||||
|
is_single = isinstance(self.input_ids[0], int)
|
||||||
|
self.is_single = is_single
|
||||||
|
|
||||||
if is_single:
|
if is_single:
|
||||||
if self.sampling_params is None:
|
if self.sampling_params is None:
|
||||||
@@ -42,7 +54,7 @@ class GenerateReqInput:
|
|||||||
if self.top_logprobs_num is None:
|
if self.top_logprobs_num is None:
|
||||||
self.top_logprobs_num = 0
|
self.top_logprobs_num = 0
|
||||||
else:
|
else:
|
||||||
num = len(self.text)
|
num = len(self.text) if self.text is not None else len(self.input_ids)
|
||||||
|
|
||||||
if self.image_data is None:
|
if self.image_data is None:
|
||||||
self.image_data = [None] * num
|
self.image_data = [None] * num
|
||||||
|
|||||||
@@ -85,6 +85,9 @@ class Req:
|
|||||||
)
|
)
|
||||||
if first_token.startswith("▁"):
|
if first_token.startswith("▁"):
|
||||||
old_output_str = " " + old_output_str
|
old_output_str = " " + old_output_str
|
||||||
|
if self.input_text is None:
|
||||||
|
# TODO(lmzheng): This can be wrong. Check with Liangsheng.
|
||||||
|
self.input_text = self.tokenizer.decode(self.input_ids)
|
||||||
new_input_string = (
|
new_input_string = (
|
||||||
self.input_text
|
self.input_text
|
||||||
+ self.output_and_jump_forward_str
|
+ self.output_and_jump_forward_str
|
||||||
|
|||||||
@@ -147,11 +147,15 @@ class TokenizerManager:
|
|||||||
if self.to_create_loop:
|
if self.to_create_loop:
|
||||||
await self.create_handle_loop()
|
await self.create_handle_loop()
|
||||||
|
|
||||||
is_single = isinstance(obj.text, str)
|
is_single = obj.is_single
|
||||||
|
|
||||||
if is_single:
|
if is_single:
|
||||||
rid = obj.rid
|
rid = obj.rid
|
||||||
input_ids = self.tokenizer.encode(obj.text)
|
|
||||||
|
if obj.input_ids is None:
|
||||||
|
input_ids = self.tokenizer.encode(obj.text)
|
||||||
|
else:
|
||||||
|
input_ids = obj.input_ids
|
||||||
|
|
||||||
sampling_params = SamplingParams(**obj.sampling_params)
|
sampling_params = SamplingParams(**obj.sampling_params)
|
||||||
if sampling_params.max_new_tokens != 0:
|
if sampling_params.max_new_tokens != 0:
|
||||||
sampling_params.normalize(self.tokenizer)
|
sampling_params.normalize(self.tokenizer)
|
||||||
@@ -204,10 +208,22 @@ class TokenizerManager:
|
|||||||
event.clear()
|
event.clear()
|
||||||
else:
|
else:
|
||||||
assert obj.stream is False
|
assert obj.stream is False
|
||||||
bs = len(obj.text)
|
|
||||||
|
if obj.input_ids is None:
|
||||||
|
bs = len(obj.text)
|
||||||
|
else:
|
||||||
|
bs = len(obj.input_ids)
|
||||||
|
|
||||||
for i in range(bs):
|
for i in range(bs):
|
||||||
rid = obj.rid[i]
|
rid = obj.rid[i]
|
||||||
input_ids = self.tokenizer.encode(obj.text[i])
|
|
||||||
|
if obj.input_ids is None:
|
||||||
|
input_text = obj.text[i]
|
||||||
|
input_ids = self.tokenizer.encode(obj.text[i])
|
||||||
|
else:
|
||||||
|
input_text = None
|
||||||
|
input_ids = obj.input_ids[i]
|
||||||
|
|
||||||
sampling_params = SamplingParams(**obj.sampling_params[i])
|
sampling_params = SamplingParams(**obj.sampling_params[i])
|
||||||
if sampling_params.max_new_tokens != 0:
|
if sampling_params.max_new_tokens != 0:
|
||||||
sampling_params.normalize(self.tokenizer)
|
sampling_params.normalize(self.tokenizer)
|
||||||
@@ -220,7 +236,7 @@ class TokenizerManager:
|
|||||||
)
|
)
|
||||||
tokenized_obj = TokenizedGenerateReqInput(
|
tokenized_obj = TokenizedGenerateReqInput(
|
||||||
rid=rid,
|
rid=rid,
|
||||||
input_text=obj.text[i],
|
input_text=input_text,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
pixel_values=pixel_values,
|
pixel_values=pixel_values,
|
||||||
image_hash=image_hash,
|
image_hash=image_hash,
|
||||||
|
|||||||
Reference in New Issue
Block a user