[Feat] Expose logprob options to sgl.gen API (#503)
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
This commit is contained in:
@@ -67,10 +67,16 @@ def gen(
|
||||
frequency_penalty: Optional[float] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
ignore_eos: Optional[bool] = None,
|
||||
return_logprob: Optional[bool] = None,
|
||||
logprob_start_len: Optional[int] = None,
|
||||
top_logprobs_num: Optional[int] = None,
|
||||
return_text_in_logprobs: Optional[bool] = None,
|
||||
dtype: Optional[type] = None,
|
||||
choices: Optional[List[str]] = None,
|
||||
regex: Optional[str] = None,
|
||||
):
|
||||
"""Call the model to generate. See the meaning of the arguments in docs/sampling_params.md"""
|
||||
|
||||
if choices:
|
||||
return SglSelect(name, choices, 0.0 if temperature is None else temperature)
|
||||
|
||||
@@ -91,6 +97,10 @@ def gen(
|
||||
frequency_penalty,
|
||||
presence_penalty,
|
||||
ignore_eos,
|
||||
return_logprob,
|
||||
logprob_start_len,
|
||||
top_logprobs_num,
|
||||
return_text_in_logprobs,
|
||||
dtype,
|
||||
regex,
|
||||
)
|
||||
@@ -106,6 +116,10 @@ def gen_int(
|
||||
frequency_penalty: Optional[float] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
ignore_eos: Optional[bool] = None,
|
||||
return_logprob: Optional[bool] = None,
|
||||
logprob_start_len: Optional[int] = None,
|
||||
top_logprobs_num: Optional[int] = None,
|
||||
return_text_in_logprobs: Optional[bool] = None,
|
||||
):
|
||||
return SglGen(
|
||||
name,
|
||||
@@ -117,6 +131,10 @@ def gen_int(
|
||||
frequency_penalty,
|
||||
presence_penalty,
|
||||
ignore_eos,
|
||||
return_logprob,
|
||||
logprob_start_len,
|
||||
top_logprobs_num,
|
||||
return_text_in_logprobs,
|
||||
int,
|
||||
None,
|
||||
)
|
||||
@@ -132,6 +150,10 @@ def gen_string(
|
||||
frequency_penalty: Optional[float] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
ignore_eos: Optional[bool] = None,
|
||||
return_logprob: Optional[bool] = None,
|
||||
logprob_start_len: Optional[int] = None,
|
||||
top_logprobs_num: Optional[int] = None,
|
||||
return_text_in_logprobs: Optional[bool] = None,
|
||||
):
|
||||
return SglGen(
|
||||
name,
|
||||
@@ -143,6 +165,10 @@ def gen_string(
|
||||
frequency_penalty,
|
||||
presence_penalty,
|
||||
ignore_eos,
|
||||
return_logprob,
|
||||
logprob_start_len,
|
||||
top_logprobs_num,
|
||||
return_text_in_logprobs,
|
||||
str,
|
||||
None,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user