[Feat] Expose logprob options to sgl.gen API (#503)

Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
This commit is contained in:
胡译文
2024-07-09 15:35:39 +08:00
committed by GitHub
parent d557e9f3b7
commit 02b7258658
7 changed files with 239 additions and 43 deletions

View File

@@ -67,10 +67,16 @@ def gen(
frequency_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
ignore_eos: Optional[bool] = None,
return_logprob: Optional[bool] = None,
logprob_start_len: Optional[int] = None,
top_logprobs_num: Optional[int] = None,
return_text_in_logprobs: Optional[bool] = None,
dtype: Optional[type] = None,
choices: Optional[List[str]] = None,
regex: Optional[str] = None,
):
"""Call the model to generate. See the meaning of the arguments in docs/sampling_params.md"""
if choices:
return SglSelect(name, choices, 0.0 if temperature is None else temperature)
@@ -91,6 +97,10 @@ def gen(
frequency_penalty,
presence_penalty,
ignore_eos,
return_logprob,
logprob_start_len,
top_logprobs_num,
return_text_in_logprobs,
dtype,
regex,
)
@@ -106,6 +116,10 @@ def gen_int(
frequency_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
ignore_eos: Optional[bool] = None,
return_logprob: Optional[bool] = None,
logprob_start_len: Optional[int] = None,
top_logprobs_num: Optional[int] = None,
return_text_in_logprobs: Optional[bool] = None,
):
return SglGen(
name,
@@ -117,6 +131,10 @@ def gen_int(
frequency_penalty,
presence_penalty,
ignore_eos,
return_logprob,
logprob_start_len,
top_logprobs_num,
return_text_in_logprobs,
int,
None,
)
@@ -132,6 +150,10 @@ def gen_string(
frequency_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
ignore_eos: Optional[bool] = None,
return_logprob: Optional[bool] = None,
logprob_start_len: Optional[int] = None,
top_logprobs_num: Optional[int] = None,
return_text_in_logprobs: Optional[bool] = None,
):
return SglGen(
name,
@@ -143,6 +165,10 @@ def gen_string(
frequency_penalty,
presence_penalty,
ignore_eos,
return_logprob,
logprob_start_len,
top_logprobs_num,
return_text_in_logprobs,
str,
None,
)