[Feat] Expose logprob options to sgl.gen API (#503)
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
This commit is contained in:
@@ -67,10 +67,16 @@ def gen(
|
||||
frequency_penalty: Optional[float] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
ignore_eos: Optional[bool] = None,
|
||||
return_logprob: Optional[bool] = None,
|
||||
logprob_start_len: Optional[int] = None,
|
||||
top_logprobs_num: Optional[int] = None,
|
||||
return_text_in_logprobs: Optional[bool] = None,
|
||||
dtype: Optional[type] = None,
|
||||
choices: Optional[List[str]] = None,
|
||||
regex: Optional[str] = None,
|
||||
):
|
||||
"""Call the model to generate. See the meaning of the arguments in docs/sampling_params.md"""
|
||||
|
||||
if choices:
|
||||
return SglSelect(name, choices, 0.0 if temperature is None else temperature)
|
||||
|
||||
@@ -91,6 +97,10 @@ def gen(
|
||||
frequency_penalty,
|
||||
presence_penalty,
|
||||
ignore_eos,
|
||||
return_logprob,
|
||||
logprob_start_len,
|
||||
top_logprobs_num,
|
||||
return_text_in_logprobs,
|
||||
dtype,
|
||||
regex,
|
||||
)
|
||||
@@ -106,6 +116,10 @@ def gen_int(
|
||||
frequency_penalty: Optional[float] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
ignore_eos: Optional[bool] = None,
|
||||
return_logprob: Optional[bool] = None,
|
||||
logprob_start_len: Optional[int] = None,
|
||||
top_logprobs_num: Optional[int] = None,
|
||||
return_text_in_logprobs: Optional[bool] = None,
|
||||
):
|
||||
return SglGen(
|
||||
name,
|
||||
@@ -117,6 +131,10 @@ def gen_int(
|
||||
frequency_penalty,
|
||||
presence_penalty,
|
||||
ignore_eos,
|
||||
return_logprob,
|
||||
logprob_start_len,
|
||||
top_logprobs_num,
|
||||
return_text_in_logprobs,
|
||||
int,
|
||||
None,
|
||||
)
|
||||
@@ -132,6 +150,10 @@ def gen_string(
|
||||
frequency_penalty: Optional[float] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
ignore_eos: Optional[bool] = None,
|
||||
return_logprob: Optional[bool] = None,
|
||||
logprob_start_len: Optional[int] = None,
|
||||
top_logprobs_num: Optional[int] = None,
|
||||
return_text_in_logprobs: Optional[bool] = None,
|
||||
):
|
||||
return SglGen(
|
||||
name,
|
||||
@@ -143,6 +165,10 @@ def gen_string(
|
||||
frequency_penalty,
|
||||
presence_penalty,
|
||||
ignore_eos,
|
||||
return_logprob,
|
||||
logprob_start_len,
|
||||
top_logprobs_num,
|
||||
return_text_in_logprobs,
|
||||
str,
|
||||
None,
|
||||
)
|
||||
|
||||
@@ -12,6 +12,7 @@ from sglang.utils import http_request
|
||||
|
||||
|
||||
class RuntimeEndpoint(BaseBackend):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
@@ -37,8 +38,7 @@ class RuntimeEndpoint(BaseBackend):
|
||||
self.model_info = res.json()
|
||||
|
||||
self.chat_template = get_chat_template_by_model_path(
|
||||
self.model_info["model_path"]
|
||||
)
|
||||
self.model_info["model_path"])
|
||||
|
||||
def get_model_name(self):
|
||||
return self.model_info["model_path"]
|
||||
@@ -124,6 +124,11 @@ class RuntimeEndpoint(BaseBackend):
|
||||
else:
|
||||
raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
|
||||
|
||||
for item in ["return_logprob", "logprob_start_len", "top_logprobs_num", "return_text_in_logprobs"]:
|
||||
value = getattr(sampling_params, item, None)
|
||||
if value is not None:
|
||||
data[item] = value
|
||||
|
||||
self._add_images(s, data)
|
||||
|
||||
res = http_request(
|
||||
@@ -166,6 +171,11 @@ class RuntimeEndpoint(BaseBackend):
|
||||
else:
|
||||
raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
|
||||
|
||||
for item in ["return_logprob", "logprob_start_len", "top_logprobs_num", "return_text_in_logprobs"]:
|
||||
value = getattr(sampling_params, item, None)
|
||||
if value is not None:
|
||||
data[item] = value
|
||||
|
||||
data["stream"] = True
|
||||
self._add_images(s, data)
|
||||
|
||||
|
||||
@@ -668,6 +668,10 @@ class StreamExecutor:
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
"ignore_eos",
|
||||
"return_logprob",
|
||||
"logprob_start_len",
|
||||
"top_logprobs_num",
|
||||
"return_text_in_logprobs",
|
||||
"dtype",
|
||||
"regex",
|
||||
]:
|
||||
|
||||
@@ -23,6 +23,10 @@ class SglSamplingParams:
|
||||
frequency_penalty: float = 0.0
|
||||
presence_penalty: float = 0.0
|
||||
ignore_eos: bool = False
|
||||
return_logprob: Optional[bool] = None
|
||||
# NOTE: these defaults must not end with a trailing comma. Inside a class
# body, "x: T = None," assigns the one-element tuple (None,), which is not
# None, so "is not None" checks on the field would wrongly succeed.
logprob_start_len: Optional[int] = None
top_logprobs_num: Optional[int] = None
return_text_in_logprobs: Optional[bool] = None
|
||||
|
||||
# for constrained generation, not included in to_xxx_kwargs
|
||||
dtype: Optional[str] = None
|
||||
@@ -37,6 +41,11 @@ class SglSamplingParams:
|
||||
self.top_k,
|
||||
self.frequency_penalty,
|
||||
self.presence_penalty,
|
||||
self.ignore_eos,
|
||||
self.return_logprob,
|
||||
self.logprob_start_len,
|
||||
self.top_logprobs_num,
|
||||
self.return_text_in_logprobs,
|
||||
)
|
||||
|
||||
def to_openai_kwargs(self):
|
||||
@@ -139,6 +148,10 @@ class SglFunction:
|
||||
frequency_penalty: float = 0.0,
|
||||
presence_penalty: float = 0.0,
|
||||
ignore_eos: bool = False,
|
||||
return_logprob: Optional[bool] = None,
|
||||
logprob_start_len: Optional[int] = None,
|
||||
top_logprobs_num: Optional[int] = None,
|
||||
return_text_in_logprobs: Optional[bool] = None,
|
||||
stream: bool = False,
|
||||
backend=None,
|
||||
**kwargs,
|
||||
@@ -154,6 +167,10 @@ class SglFunction:
|
||||
frequency_penalty=frequency_penalty,
|
||||
presence_penalty=presence_penalty,
|
||||
ignore_eos=ignore_eos,
|
||||
return_logprob=return_logprob,
|
||||
logprob_start_len=logprob_start_len,
|
||||
top_logprobs_num=top_logprobs_num,
|
||||
return_text_in_logprobs=return_text_in_logprobs,
|
||||
)
|
||||
backend = backend or global_config.default_backend
|
||||
return run_program(self, backend, args, kwargs, default_sampling_para, stream)
|
||||
@@ -170,6 +187,10 @@ class SglFunction:
|
||||
frequency_penalty: float = 0.0,
|
||||
presence_penalty: float = 0.0,
|
||||
ignore_eos: bool = False,
|
||||
return_logprob: Optional[bool] = None,
|
||||
logprob_start_len: Optional[int] = None,
|
||||
top_logprobs_num: Optional[int] = None,
|
||||
return_text_in_logprobs: Optional[bool] = None,
|
||||
backend=None,
|
||||
num_threads: Union[str, int] = "auto",
|
||||
progress_bar: bool = False,
|
||||
@@ -203,6 +224,10 @@ class SglFunction:
|
||||
frequency_penalty=frequency_penalty,
|
||||
presence_penalty=presence_penalty,
|
||||
ignore_eos=ignore_eos,
|
||||
return_logprob=return_logprob,
|
||||
logprob_start_len=logprob_start_len,
|
||||
top_logprobs_num=top_logprobs_num,
|
||||
return_text_in_logprobs=return_text_in_logprobs,
|
||||
)
|
||||
backend = backend or global_config.default_backend
|
||||
return run_program_batch(
|
||||
@@ -350,7 +375,7 @@ class SglArgument(SglExpr):
|
||||
|
||||
|
||||
class SglImage(SglExpr):
|
||||
def __init__(self, path):
|
||||
def __init__(self, path: str):
|
||||
self.path = path
|
||||
|
||||
def __repr__(self) -> str:
|
||||
@@ -358,7 +383,7 @@ class SglImage(SglExpr):
|
||||
|
||||
|
||||
class SglVideo(SglExpr):
|
||||
def __init__(self, path, num_frames):
|
||||
def __init__(self, path: str, num_frames: int):
|
||||
self.path = path
|
||||
self.num_frames = num_frames
|
||||
|
||||
@@ -369,18 +394,23 @@ class SglVideo(SglExpr):
|
||||
class SglGen(SglExpr):
|
||||
def __init__(
|
||||
self,
|
||||
name,
|
||||
max_new_tokens,
|
||||
stop,
|
||||
temperature,
|
||||
top_p,
|
||||
top_k,
|
||||
frequency_penalty,
|
||||
presence_penalty,
|
||||
ignore_eos,
|
||||
dtype,
|
||||
regex,
|
||||
name: Optional[str] = None,
|
||||
max_new_tokens: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
frequency_penalty: Optional[float] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
ignore_eos: Optional[bool] = None,
|
||||
return_logprob: Optional[bool] = None,
|
||||
logprob_start_len: Optional[int] = None,
|
||||
top_logprobs_num: Optional[int] = None,
|
||||
return_text_in_logprobs: Optional[bool] = None,
|
||||
dtype: Optional[type] = None,
|
||||
regex: Optional[str] = None,
|
||||
):
|
||||
"""Call the model to generate. See the meaning of the arguments in docs/sampling_params.md"""
|
||||
super().__init__()
|
||||
self.name = name
|
||||
self.sampling_params = SglSamplingParams(
|
||||
@@ -392,6 +422,10 @@ class SglGen(SglExpr):
|
||||
frequency_penalty=frequency_penalty,
|
||||
presence_penalty=presence_penalty,
|
||||
ignore_eos=ignore_eos,
|
||||
return_logprob=return_logprob,
|
||||
logprob_start_len=logprob_start_len,
|
||||
top_logprobs_num=top_logprobs_num,
|
||||
return_text_in_logprobs=return_text_in_logprobs,
|
||||
dtype=dtype,
|
||||
regex=regex,
|
||||
)
|
||||
@@ -401,7 +435,7 @@ class SglGen(SglExpr):
|
||||
|
||||
|
||||
class SglConstantText(SglExpr):
|
||||
def __init__(self, value):
|
||||
def __init__(self, value: str):
|
||||
super().__init__()
|
||||
self.value = value
|
||||
|
||||
@@ -410,7 +444,7 @@ class SglConstantText(SglExpr):
|
||||
|
||||
|
||||
class SglRoleBegin(SglExpr):
|
||||
def __init__(self, role):
|
||||
def __init__(self, role: str):
|
||||
super().__init__()
|
||||
self.role = role
|
||||
|
||||
@@ -419,7 +453,7 @@ class SglRoleBegin(SglExpr):
|
||||
|
||||
|
||||
class SglRoleEnd(SglExpr):
|
||||
def __init__(self, role):
|
||||
def __init__(self, role: str):
|
||||
super().__init__()
|
||||
self.role = role
|
||||
|
||||
@@ -428,7 +462,7 @@ class SglRoleEnd(SglExpr):
|
||||
|
||||
|
||||
class SglSelect(SglExpr):
|
||||
def __init__(self, name, choices, temperature):
|
||||
def __init__(self, name: str, choices: List[str], temperature: float):
|
||||
super().__init__()
|
||||
self.name = name
|
||||
self.choices = choices
|
||||
@@ -439,7 +473,7 @@ class SglSelect(SglExpr):
|
||||
|
||||
|
||||
class SglFork(SglExpr):
|
||||
def __init__(self, number, position_ids_offset=None):
|
||||
def __init__(self, number: int, position_ids_offset=None):
|
||||
super().__init__()
|
||||
self.number = number
|
||||
self.position_ids_offset = position_ids_offset
|
||||
@@ -452,7 +486,7 @@ class SglFork(SglExpr):
|
||||
|
||||
|
||||
class SglGetForkItem(SglExpr):
|
||||
def __init__(self, index):
|
||||
def __init__(self, index: int):
|
||||
super().__init__()
|
||||
self.index = index
|
||||
|
||||
@@ -461,7 +495,7 @@ class SglGetForkItem(SglExpr):
|
||||
|
||||
|
||||
class SglVariable(SglExpr):
|
||||
def __init__(self, name, source):
|
||||
def __init__(self, name: str, source):
|
||||
super().__init__()
|
||||
self.name = name
|
||||
self.source = source
|
||||
@@ -471,7 +505,7 @@ class SglVariable(SglExpr):
|
||||
|
||||
|
||||
class SglVarScopeBegin(SglExpr):
|
||||
def __init__(self, name):
|
||||
def __init__(self, name: str):
|
||||
super().__init__()
|
||||
self.name = name
|
||||
|
||||
@@ -480,7 +514,7 @@ class SglVarScopeBegin(SglExpr):
|
||||
|
||||
|
||||
class SglVarScopeEnd(SglExpr):
|
||||
def __init__(self, name):
|
||||
def __init__(self, name: str):
|
||||
super().__init__()
|
||||
self.name = name
|
||||
|
||||
@@ -502,4 +536,4 @@ class SglCommitLazy(SglExpr):
|
||||
super().__init__()
|
||||
|
||||
def __repr__(self):
|
||||
return f"CommitLazy()"
|
||||
return "CommitLazy()"
|
||||
|
||||
@@ -333,17 +333,18 @@ class TokenizerManager:
|
||||
ret["meta_info"]["decode_token_logprobs"] = self.detokenize_logprob_tokens(
|
||||
ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
|
||||
)
|
||||
if top_logprobs_num > 0:
|
||||
ret["meta_info"][
|
||||
"prefill_top_logprobs"
|
||||
] = self.detokenize_top_logprobs_tokens(
|
||||
ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
|
||||
)
|
||||
ret["meta_info"][
|
||||
"decode_top_logprobs"
|
||||
] = self.detokenize_top_logprobs_tokens(
|
||||
ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
|
||||
)
|
||||
|
||||
if top_logprobs_num > 0:
|
||||
ret["meta_info"][
|
||||
"prefill_top_logprobs"
|
||||
] = self.detokenize_top_logprobs_tokens(
|
||||
ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
|
||||
)
|
||||
ret["meta_info"][
|
||||
"decode_top_logprobs"
|
||||
] = self.detokenize_top_logprobs_tokens(
|
||||
ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
|
||||
)
|
||||
return ret
|
||||
|
||||
def detokenize_logprob_tokens(self, token_logprobs, decode_to_text):
|
||||
@@ -383,7 +384,7 @@ def get_pixel_values(
|
||||
try:
|
||||
processor = processor or global_processor
|
||||
image, image_size = load_image(image_data)
|
||||
if image_size != None:
|
||||
if image_size is not None:
|
||||
image_hash = hash(image_data)
|
||||
pixel_values = processor.image_processor(image)["pixel_values"]
|
||||
for _ in range(len(pixel_values)):
|
||||
|
||||
Reference in New Issue
Block a user