[Feat] Expose logprob options to sgl.gen API (#503)

Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
This commit is contained in:
Author: 胡译文
Date: 2024-07-09 15:35:39 +08:00
Committed by: GitHub
parent d557e9f3b7
commit 02b7258658
7 changed files with 239 additions and 43 deletions

View File

@@ -668,6 +668,10 @@ class StreamExecutor:
"frequency_penalty",
"presence_penalty",
"ignore_eos",
"return_logprob",
"logprob_start_len",
"top_logprobs_num",
"return_text_in_logprobs",
"dtype",
"regex",
]:

View File

@@ -23,6 +23,10 @@ class SglSamplingParams:
frequency_penalty: float = 0.0
presence_penalty: float = 0.0
ignore_eos: bool = False
return_logprob: Optional[bool] = None
logprob_start_len: Optional[int] = None
top_logprobs_num: Optional[int] = None
return_text_in_logprobs: Optional[bool] = None
# for constrained generation, not included in to_xxx_kwargs
dtype: Optional[str] = None
@@ -37,6 +41,11 @@ class SglSamplingParams:
self.top_k,
self.frequency_penalty,
self.presence_penalty,
self.ignore_eos,
self.return_logprob,
self.logprob_start_len,
self.top_logprobs_num,
self.return_text_in_logprobs,
)
def to_openai_kwargs(self):
@@ -139,6 +148,10 @@ class SglFunction:
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
ignore_eos: bool = False,
return_logprob: Optional[bool] = None,
logprob_start_len: Optional[int] = None,
top_logprobs_num: Optional[int] = None,
return_text_in_logprobs: Optional[bool] = None,
stream: bool = False,
backend=None,
**kwargs,
@@ -154,6 +167,10 @@ class SglFunction:
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
ignore_eos=ignore_eos,
return_logprob=return_logprob,
logprob_start_len=logprob_start_len,
top_logprobs_num=top_logprobs_num,
return_text_in_logprobs=return_text_in_logprobs,
)
backend = backend or global_config.default_backend
return run_program(self, backend, args, kwargs, default_sampling_para, stream)
@@ -170,6 +187,10 @@ class SglFunction:
frequency_penalty: float = 0.0,
presence_penalty: float = 0.0,
ignore_eos: bool = False,
return_logprob: Optional[bool] = None,
logprob_start_len: Optional[int] = None,
top_logprobs_num: Optional[int] = None,
return_text_in_logprobs: Optional[bool] = None,
backend=None,
num_threads: Union[str, int] = "auto",
progress_bar: bool = False,
@@ -203,6 +224,10 @@ class SglFunction:
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
ignore_eos=ignore_eos,
return_logprob=return_logprob,
logprob_start_len=logprob_start_len,
top_logprobs_num=top_logprobs_num,
return_text_in_logprobs=return_text_in_logprobs,
)
backend = backend or global_config.default_backend
return run_program_batch(
@@ -350,7 +375,7 @@ class SglArgument(SglExpr):
class SglImage(SglExpr):
def __init__(self, path):
def __init__(self, path: str):
self.path = path
def __repr__(self) -> str:
@@ -358,7 +383,7 @@ class SglImage(SglExpr):
class SglVideo(SglExpr):
def __init__(self, path, num_frames):
def __init__(self, path: str, num_frames: int):
self.path = path
self.num_frames = num_frames
@@ -369,18 +394,23 @@ class SglVideo(SglExpr):
class SglGen(SglExpr):
def __init__(
self,
name,
max_new_tokens,
stop,
temperature,
top_p,
top_k,
frequency_penalty,
presence_penalty,
ignore_eos,
dtype,
regex,
name: Optional[str] = None,
max_new_tokens: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
frequency_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None,
ignore_eos: Optional[bool] = None,
return_logprob: Optional[bool] = None,
logprob_start_len: Optional[int] = None,
top_logprobs_num: Optional[int] = None,
return_text_in_logprobs: Optional[bool] = None,
dtype: Optional[type] = None,
regex: Optional[str] = None,
):
"""Call the model to generate. See the meaning of the arguments in docs/sampling_params.md"""
super().__init__()
self.name = name
self.sampling_params = SglSamplingParams(
@@ -392,6 +422,10 @@ class SglGen(SglExpr):
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
ignore_eos=ignore_eos,
return_logprob=return_logprob,
logprob_start_len=logprob_start_len,
top_logprobs_num=top_logprobs_num,
return_text_in_logprobs=return_text_in_logprobs,
dtype=dtype,
regex=regex,
)
@@ -401,7 +435,7 @@ class SglGen(SglExpr):
class SglConstantText(SglExpr):
def __init__(self, value):
def __init__(self, value: str):
super().__init__()
self.value = value
@@ -410,7 +444,7 @@ class SglConstantText(SglExpr):
class SglRoleBegin(SglExpr):
def __init__(self, role):
def __init__(self, role: str):
super().__init__()
self.role = role
@@ -419,7 +453,7 @@ class SglRoleBegin(SglExpr):
class SglRoleEnd(SglExpr):
def __init__(self, role):
def __init__(self, role: str):
super().__init__()
self.role = role
@@ -428,7 +462,7 @@ class SglRoleEnd(SglExpr):
class SglSelect(SglExpr):
def __init__(self, name, choices, temperature):
def __init__(self, name: str, choices: List[str], temperature: float):
super().__init__()
self.name = name
self.choices = choices
@@ -439,7 +473,7 @@ class SglSelect(SglExpr):
class SglFork(SglExpr):
def __init__(self, number, position_ids_offset=None):
def __init__(self, number: int, position_ids_offset=None):
super().__init__()
self.number = number
self.position_ids_offset = position_ids_offset
@@ -452,7 +486,7 @@ class SglFork(SglExpr):
class SglGetForkItem(SglExpr):
def __init__(self, index):
def __init__(self, index: int):
super().__init__()
self.index = index
@@ -461,7 +495,7 @@ class SglGetForkItem(SglExpr):
class SglVariable(SglExpr):
def __init__(self, name, source):
def __init__(self, name: str, source):
super().__init__()
self.name = name
self.source = source
@@ -471,7 +505,7 @@ class SglVariable(SglExpr):
class SglVarScopeBegin(SglExpr):
def __init__(self, name):
def __init__(self, name: str):
super().__init__()
self.name = name
@@ -480,7 +514,7 @@ class SglVarScopeBegin(SglExpr):
class SglVarScopeEnd(SglExpr):
def __init__(self, name):
def __init__(self, name: str):
super().__init__()
self.name = name
@@ -502,4 +536,4 @@ class SglCommitLazy(SglExpr):
super().__init__()
def __repr__(self):
return f"CommitLazy()"
return "CommitLazy()"