Feat: add alternative choices selection methods (#835)

This commit is contained in:
Aidan Cooper
2024-08-05 11:27:49 +01:00
committed by GitHub
parent b216a545b3
commit 94e0115186
10 changed files with 426 additions and 48 deletions

View File

@@ -6,6 +6,7 @@ from typing import Callable, List, Optional, Union
from sglang.global_config import global_config
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.choices import ChoicesSamplingMethod, token_length_normalized
from sglang.lang.ir import (
SglExpr,
SglExprList,
@@ -73,12 +74,18 @@ def gen(
return_text_in_logprobs: Optional[bool] = None,
dtype: Optional[type] = None,
choices: Optional[List[str]] = None,
choices_method: Optional[ChoicesSamplingMethod] = None,
regex: Optional[str] = None,
):
"""Call the model to generate. See the meaning of the arguments in docs/en/sampling_params.md"""
if choices:
return SglSelect(name, choices, 0.0 if temperature is None else temperature)
return SglSelect(
name,
choices,
0.0 if temperature is None else temperature,
token_length_normalized if choices_method is None else choices_method,
)
# check regex is valid
if regex is not None:
@@ -186,9 +193,10 @@ def select(
name: Optional[str] = None,
choices: Optional[List[str]] = None,
temperature: float = 0.0,
choices_method: ChoicesSamplingMethod = token_length_normalized,
):
assert choices is not None
return SglSelect(name, choices, temperature)
return SglSelect(name, choices, temperature, choices_method)
def _role_common(name: str, expr: Optional[SglExpr] = None):