Support token ids in engine.generate (#1820)

This commit is contained in:
Byron Hsu
2024-10-27 14:02:34 -07:00
committed by GitHub
parent c77762d57f
commit 6fcd6d7d6d
3 changed files with 72 additions and 4 deletions

View File

@@ -742,18 +742,20 @@ class Engine:
def generate(
self,
prompt: Union[str, List[str]],
# The input prompt. It can be a single prompt or a batch of prompts.
prompt: Optional[Union[List[str], str]] = None,
sampling_params: Optional[Dict] = None,
# The token ids for text; one can either specify text or input_ids.
input_ids: Optional[Union[List[List[int]], List[int]]] = None,
return_logprob: Optional[Union[List[bool], bool]] = False,
logprob_start_len: Optional[Union[List[int], int]] = None,
top_logprobs_num: Optional[Union[List[int], int]] = None,
lora_path: Optional[List[Optional[str]]] = None,
stream: bool = False,
):
# TODO (ByronHsu): refactor to reduce the duplicated code
obj = GenerateReqInput(
text=prompt,
input_ids=input_ids,
sampling_params=sampling_params,
return_logprob=return_logprob,
logprob_start_len=logprob_start_len,
@@ -791,8 +793,11 @@ class Engine:
async def async_generate(
self,
prompt: Union[str, List[str]],
# The input prompt. It can be a single prompt or a batch of prompts.
prompt: Optional[Union[List[str], str]] = None,
sampling_params: Optional[Dict] = None,
# The token ids for text; one can either specify text or input_ids.
input_ids: Optional[Union[List[List[int]], List[int]]] = None,
return_logprob: Optional[Union[List[bool], bool]] = False,
logprob_start_len: Optional[Union[List[int], int]] = None,
top_logprobs_num: Optional[Union[List[int], int]] = None,
@@ -801,6 +806,7 @@ class Engine:
):
obj = GenerateReqInput(
text=prompt,
input_ids=input_ids,
sampling_params=sampling_params,
return_logprob=return_logprob,
logprob_start_len=logprob_start_len,