Support token ids in engine.generate (#1820)
This commit is contained in:
@@ -742,18 +742,20 @@ class Engine:
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
# The input prompt. It can be a single prompt or a batch of prompts.
|
||||
prompt: Optional[Union[List[str], str]] = None,
|
||||
sampling_params: Optional[Dict] = None,
|
||||
# The token ids for text; one can either specify text or input_ids.
|
||||
input_ids: Optional[Union[List[List[int]], List[int]]] = None,
|
||||
return_logprob: Optional[Union[List[bool], bool]] = False,
|
||||
logprob_start_len: Optional[Union[List[int], int]] = None,
|
||||
top_logprobs_num: Optional[Union[List[int], int]] = None,
|
||||
lora_path: Optional[List[Optional[str]]] = None,
|
||||
stream: bool = False,
|
||||
):
|
||||
# TODO (ByronHsu): refactor to reduce the duplicated code
|
||||
|
||||
obj = GenerateReqInput(
|
||||
text=prompt,
|
||||
input_ids=input_ids,
|
||||
sampling_params=sampling_params,
|
||||
return_logprob=return_logprob,
|
||||
logprob_start_len=logprob_start_len,
|
||||
@@ -791,8 +793,11 @@ class Engine:
|
||||
|
||||
async def async_generate(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
# The input prompt. It can be a single prompt or a batch of prompts.
|
||||
prompt: Optional[Union[List[str], str]] = None,
|
||||
sampling_params: Optional[Dict] = None,
|
||||
# The token ids for text; one can either specify text or input_ids.
|
||||
input_ids: Optional[Union[List[List[int]], List[int]]] = None,
|
||||
return_logprob: Optional[Union[List[bool], bool]] = False,
|
||||
logprob_start_len: Optional[Union[List[int], int]] = None,
|
||||
top_logprobs_num: Optional[Union[List[int], int]] = None,
|
||||
@@ -801,6 +806,7 @@ class Engine:
|
||||
):
|
||||
obj = GenerateReqInput(
|
||||
text=prompt,
|
||||
input_ids=input_ids,
|
||||
sampling_params=sampling_params,
|
||||
return_logprob=return_logprob,
|
||||
logprob_start_len=logprob_start_len,
|
||||
|
||||
Reference in New Issue
Block a user