Feat: add alternative choices selection methods (#835)

This commit is contained in:
Aidan Cooper
2024-08-05 11:27:49 +01:00
committed by GitHub
parent b216a545b3
commit 94e0115186
10 changed files with 426 additions and 48 deletions

View File

@@ -22,6 +22,11 @@ from sglang.api import (
user_end,
video,
)
from sglang.lang.choices import (
greedy_token_selection,
token_length_normalized,
unconditional_likelihood_normalized,
)
# SGLang DSL APIs
__all__ = [
@@ -45,6 +50,9 @@ __all__ = [
"user_begin",
"user_end",
"video",
"greedy_token_selection",
"token_length_normalized",
"unconditional_likelihood_normalized",
]
# Global Configurations

View File

@@ -6,6 +6,7 @@ from typing import Callable, List, Optional, Union
from sglang.global_config import global_config
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.choices import ChoicesSamplingMethod, token_length_normalized
from sglang.lang.ir import (
SglExpr,
SglExprList,
@@ -73,12 +74,18 @@ def gen(
return_text_in_logprobs: Optional[bool] = None,
dtype: Optional[type] = None,
choices: Optional[List[str]] = None,
choices_method: Optional[ChoicesSamplingMethod] = None,
regex: Optional[str] = None,
):
"""Call the model to generate. See the meaning of the arguments in docs/en/sampling_params.md"""
if choices:
return SglSelect(name, choices, 0.0 if temperature is None else temperature)
return SglSelect(
name,
choices,
0.0 if temperature is None else temperature,
token_length_normalized if choices_method is None else choices_method,
)
# check regex is valid
if regex is not None:
@@ -186,9 +193,10 @@ def select(
name: Optional[str] = None,
choices: Optional[List[str]] = None,
temperature: float = 0.0,
choices_method: ChoicesSamplingMethod = token_length_normalized,
):
assert choices is not None
return SglSelect(name, choices, temperature)
return SglSelect(name, choices, temperature, choices_method)
def _role_common(name: str, expr: Optional[SglExpr] = None):

View File

@@ -1,6 +1,7 @@
from typing import Callable, List, Optional, Union
from sglang.lang.chat_template import get_chat_template
from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams
@@ -64,7 +65,8 @@ class BaseBackend:
s: StreamExecutor,
choices: List[str],
temperature: float,
):
choices_method: Optional[ChoicesSamplingMethod] = None,
) -> ChoicesDecision:
raise NotImplementedError()
def concatenate_and_append(self, src_rids: List[str], dst_rid: str):

View File

@@ -8,6 +8,7 @@ import numpy as np
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import ChatTemplate, get_chat_template_by_model_path
from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams
@@ -296,7 +297,9 @@ class OpenAI(BaseBackend):
s: StreamExecutor,
choices: List[str],
temperature: float,
):
choices_method: ChoicesSamplingMethod,
) -> ChoicesDecision:
"""Note: `choices_method` is not used by the OpenAI backend."""
if self.is_chat_model:
raise NotImplementedError(
"select/choices is not supported for chat models. "
@@ -354,8 +357,10 @@ class OpenAI(BaseBackend):
prompt_tokens.append(ret_token)
decision = choices[np.argmax(scores)]
return decision, scores, None, None
return ChoicesDecision(
decision=choices[np.argmax(scores)],
meta_info={"scores": scores},
)
def openai_completion(

View File

@@ -1,17 +1,21 @@
import json
from typing import List, Optional
import numpy as np
from sglang.global_config import global_config
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.lang.choices import (
ChoicesDecision,
ChoicesSamplingMethod,
token_length_normalized,
)
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams
from sglang.utils import http_request
class RuntimeEndpoint(BaseBackend):
def __init__(
self,
base_url: str,
@@ -208,20 +212,14 @@ class RuntimeEndpoint(BaseBackend):
s: StreamExecutor,
choices: List[str],
temperature: float,
):
choices_method: ChoicesSamplingMethod,
) -> ChoicesDecision:
assert temperature <= 1e-5
# Cache common prefix
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
self._add_images(s, data)
res = http_request(
self.base_url + "/generate",
json=data,
api_key=self.api_key,
verify=self.verify,
)
self._assert_success(res)
prompt_len = res.json()["meta_info"]["prompt_tokens"]
obj = self._generate_http_request(s, data)
prompt_len = obj["meta_info"]["prompt_tokens"]
# Compute logprob
data = {
@@ -230,27 +228,35 @@ class RuntimeEndpoint(BaseBackend):
"return_logprob": True,
"logprob_start_len": max(prompt_len - 2, 0),
}
self._add_images(s, data)
res = http_request(
self.base_url + "/generate",
json=data,
api_key=self.api_key,
verify=self.verify,
)
self._assert_success(res)
obj = res.json()
obj = self._generate_http_request(s, data)
normalized_prompt_logprobs = [
r["meta_info"]["normalized_prompt_logprob"] for r in obj
]
decision = choices[np.argmax(normalized_prompt_logprobs)]
input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj]
output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj]
return (
decision,
normalized_prompt_logprobs,
input_token_logprobs,
output_token_logprobs,
# Compute unconditional logprobs if required
if choices_method.requires_unconditional_logprobs:
input_ids = [[el[1] for el in subl] for subl in input_token_logprobs]
data = {
"input_ids": input_ids,
"sampling_params": {"max_new_tokens": 0},
"return_logprob": True,
}
obj = self._generate_http_request(s, data)
unconditional_token_logprobs = [
r["meta_info"]["input_token_logprobs"] for r in obj
]
else:
unconditional_token_logprobs = None
return choices_method(
choices=choices,
normalized_prompt_logprobs=normalized_prompt_logprobs,
input_token_logprobs=input_token_logprobs,
output_token_logprobs=output_token_logprobs,
unconditional_token_logprobs=unconditional_token_logprobs,
)
def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
@@ -262,6 +268,17 @@ class RuntimeEndpoint(BaseBackend):
)
self._assert_success(res)
def _generate_http_request(self, s: StreamExecutor, data):
self._add_images(s, data)
res = http_request(
self.base_url + "/generate",
json=data,
api_key=self.api_key,
verify=self.verify,
)
self._assert_success(res)
return res.json()
def _add_images(self, s: StreamExecutor, data):
if s.images_:
assert len(s.images_) == 1, "Only support one image."

View File

@@ -0,0 +1,164 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
import numpy as np
@dataclass
class ChoicesDecision:
    """Outcome of a choices sampling method.

    Attributes:
        decision: The selected choice string.
        meta_info: Optional method-specific metadata (e.g. logprobs/scores)
            describing how the decision was reached.
    """

    decision: str
    # Use Optional[...] rather than `X | None`: the `|` union syntax requires
    # Python >= 3.10 (and is evaluated eagerly for dataclass fields), and the
    # rest of this module consistently uses Optional[...].
    meta_info: Optional[Dict[str, Any]] = None
class ChoicesSamplingMethod(ABC):
    """Interface for strategies that select one option from a list of choices
    based on token-level logprobs returned by a backend."""
    @property
    def requires_unconditional_logprobs(self) -> bool:
        """Whether this method needs logprobs of the choice tokens computed
        without the prompt prefix (overridden to True by methods that
        normalize against unconditional likelihoods)."""
        return False
    @abstractmethod
    def __call__(
        self,
        *,
        choices: List[str],
        normalized_prompt_logprobs: List[float],
        input_token_logprobs: List[List[Any]],
        output_token_logprobs: List[List[Any]],
        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
    ) -> ChoicesDecision: ...
class TokenLengthNormalized(ChoicesSamplingMethod):
    """Selects the choice whose length-normalized prompt logprob is largest."""
    def __call__(
        self,
        *,
        choices: List[str],
        normalized_prompt_logprobs: List[float],
        input_token_logprobs: List[List[Any]],
        output_token_logprobs: List[List[Any]],
        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
    ) -> ChoicesDecision:
        """Pick the option with the highest token-length-normalized prompt logprob."""
        winner_idx = int(np.argmax(normalized_prompt_logprobs))
        return ChoicesDecision(
            decision=choices[winner_idx],
            meta_info={
                "normalized_prompt_logprobs": normalized_prompt_logprobs,
                "input_token_logprobs": input_token_logprobs,
                "output_token_logprobs": output_token_logprobs,
            },
        )
token_length_normalized = TokenLengthNormalized()
class GreedyTokenSelection(ChoicesSamplingMethod):
    """Selects the choice by greedy per-position token logprob comparison."""
    def __call__(
        self,
        *,
        choices: List[str],
        normalized_prompt_logprobs: List[float],
        input_token_logprobs: List[List[Any]],
        output_token_logprobs: List[List[Any]],
        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
    ) -> ChoicesDecision:
        """Select the option based on greedy logprob selection. For overlapping
        options where one option is a subset of a longer option, extend the
        shorter option using its average logprob for comparison against the
        longer option."""
        n_opts = len(choices)
        width = max(len(toks) for toks in input_token_logprobs)
        matrix = self._build_logprob_matrix(input_token_logprobs, width, n_opts)
        survivors = self._greedy_selection(matrix, n_opts, width)
        return ChoicesDecision(
            decision=choices[survivors[0]],
            meta_info={
                "normalized_prompt_logprobs": normalized_prompt_logprobs,
                "input_token_logprobs": input_token_logprobs,
                "output_token_logprobs": output_token_logprobs,
                "greedy_logprob_matrix": matrix.tolist(),
            },
        )
    def _build_logprob_matrix(self, input_token_logprobs, max_tokens, num_options):
        # One row per option: its token logprobs, right-padded with the row's
        # own mean so every row spans max_tokens columns and stays comparable.
        rows = []
        for toks in input_token_logprobs:
            vals = [t[0] for t in toks]
            pad_value = float(np.mean(vals))
            rows.append(vals + [pad_value] * (max_tokens - len(vals)))
        return np.array(rows, dtype=float)
    def _greedy_selection(self, logprob_matrix, num_options, max_tokens):
        # Column by column, keep only the options tied for the best logprob;
        # stop as soon as a single survivor remains.
        alive = np.arange(num_options)
        col = 0
        while col < max_tokens and len(alive) > 1:
            column = logprob_matrix[alive, col]
            alive = alive[column == column.max()]
            col += 1
        return alive
greedy_token_selection = GreedyTokenSelection()
class UnconditionalLikelihoodNormalized(ChoicesSamplingMethod):
    """Selects the choice with the highest mean token logprob after
    normalizing by the unconditional token logprobs."""
    @property
    def requires_unconditional_logprobs(self) -> bool:
        # This method compares prompt-conditioned logprobs against logprobs of
        # the bare choice tokens, so the backend must supply the latter.
        return True
    def __call__(
        self,
        *,
        choices: List[str],
        normalized_prompt_logprobs: List[float],
        input_token_logprobs: List[List[Any]],
        output_token_logprobs: List[List[Any]],
        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
    ) -> ChoicesDecision:
        """Select the option with the highest average token logprob once
        normalized by the unconditional token logprobs.

        The first unconditional token logprob is assumed to be None. If so, it
        is replaced with 0 for the purposes of normalization."""
        if unconditional_token_logprobs is None:
            raise ValueError(
                "Unconditional token logprobs are required for this method."
            )
        scores = self._normalize_logprobs(
            input_token_logprobs, unconditional_token_logprobs
        )
        winner_idx = int(np.argmax(scores))
        return ChoicesDecision(
            decision=choices[winner_idx],
            meta_info={
                "normalized_prompt_logprobs": normalized_prompt_logprobs,
                "input_token_logprobs": input_token_logprobs,
                "output_token_logprobs": output_token_logprobs,
                "unconditional_token_logprobs": unconditional_token_logprobs,
                "normalized_unconditional_prompt_logprobs": scores,
            },
        )
    def _normalize_logprobs(self, input_token_logprobs, unconditional_token_logprobs):
        # Per choice: mean of (conditional - unconditional) token logprobs.
        scores = []
        for cond, uncond in zip(input_token_logprobs, unconditional_token_logprobs):
            cond_arr = np.array([t[0] for t in cond])
            uncond_arr = np.array([t[0] for t in uncond])
            # The leading unconditional entry is expected to be None (no prior
            # token); fall back to 0 so the subtraction is well defined.
            if not uncond_arr[0]:
                uncond_arr[0] = 0
            scores.append(float(np.mean(cond_arr - uncond_arr)))
        return scores
unconditional_likelihood_normalized = UnconditionalLikelihoodNormalized()

View File

@@ -538,24 +538,17 @@ class StreamExecutor:
self.stream_var_event[name].set()
def _execute_select(self, expr: SglSelect):
(
decision,
normalized_prompt_logprobs,
input_token_logprobs,
output_token_logprobs,
) = self.backend.select(self, expr.choices, expr.temperature)
choices_decision = self.backend.select(
self, expr.choices, expr.temperature, expr.choices_method
)
if expr.name is not None:
name = expr.name
self.variables[name] = decision
self.meta_info[name] = {
"normalized_prompt_logprobs": normalized_prompt_logprobs,
"input_token_logprobs": input_token_logprobs,
"output_token_logprobs": output_token_logprobs,
}
self.variables[name] = choices_decision.decision
self.meta_info[name] = choices_decision.meta_info
self.variable_event[name].set()
if self.stream_var_event:
self.stream_var_event[name].set()
self.text_ += decision
self.text_ += choices_decision.decision
def _execute_variable(self, expr: SglVariable):
src_executor = expr.source_stream_executor

View File

@@ -6,6 +6,7 @@ import warnings
from typing import List, Optional, Union
from sglang.global_config import global_config
from sglang.lang.choices import ChoicesSamplingMethod
REGEX_INT = r"[-+]?[0-9]+"
REGEX_FLOAT = r"[-+]?[0-9]*\.?[0-9]+"
@@ -461,14 +462,22 @@ class SglRoleEnd(SglExpr):
class SglSelect(SglExpr):
def __init__(self, name: str, choices: List[str], temperature: float):
def __init__(
self,
name: str,
choices: List[str],
temperature: float,
choices_method: ChoicesSamplingMethod,
):
super().__init__()
self.name = name
self.choices = choices
self.temperature = temperature
self.choices_method = choices_method
def __repr__(self):
return f"Select({self.name}, choices={self.choices})"
return f"Select({self.name}, choices={self.choices}, choices_method={self.choices_method})"
class SglFork(SglExpr):