Feat: add alternative choices selection methods (#835)
This commit is contained in:
@@ -22,6 +22,11 @@ from sglang.api import (
|
||||
user_end,
|
||||
video,
|
||||
)
|
||||
from sglang.lang.choices import (
|
||||
greedy_token_selection,
|
||||
token_length_normalized,
|
||||
unconditional_likelihood_normalized,
|
||||
)
|
||||
|
||||
# SGLang DSL APIs
|
||||
__all__ = [
|
||||
@@ -45,6 +50,9 @@ __all__ = [
|
||||
"user_begin",
|
||||
"user_end",
|
||||
"video",
|
||||
"greedy_token_selection",
|
||||
"token_length_normalized",
|
||||
"unconditional_likelihood_normalized",
|
||||
]
|
||||
|
||||
# Global Configurations
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Callable, List, Optional, Union
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.lang.backend.base_backend import BaseBackend
|
||||
from sglang.lang.choices import ChoicesSamplingMethod, token_length_normalized
|
||||
from sglang.lang.ir import (
|
||||
SglExpr,
|
||||
SglExprList,
|
||||
@@ -73,12 +74,18 @@ def gen(
|
||||
return_text_in_logprobs: Optional[bool] = None,
|
||||
dtype: Optional[type] = None,
|
||||
choices: Optional[List[str]] = None,
|
||||
choices_method: Optional[ChoicesSamplingMethod] = None,
|
||||
regex: Optional[str] = None,
|
||||
):
|
||||
"""Call the model to generate. See the meaning of the arguments in docs/en/sampling_params.md"""
|
||||
|
||||
if choices:
|
||||
return SglSelect(name, choices, 0.0 if temperature is None else temperature)
|
||||
return SglSelect(
|
||||
name,
|
||||
choices,
|
||||
0.0 if temperature is None else temperature,
|
||||
token_length_normalized if choices_method is None else choices_method,
|
||||
)
|
||||
|
||||
# check regex is valid
|
||||
if regex is not None:
|
||||
@@ -186,9 +193,10 @@ def select(
|
||||
name: Optional[str] = None,
|
||||
choices: Optional[List[str]] = None,
|
||||
temperature: float = 0.0,
|
||||
choices_method: ChoicesSamplingMethod = token_length_normalized,
|
||||
):
|
||||
assert choices is not None
|
||||
return SglSelect(name, choices, temperature)
|
||||
return SglSelect(name, choices, temperature, choices_method)
|
||||
|
||||
|
||||
def _role_common(name: str, expr: Optional[SglExpr] = None):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
from sglang.lang.chat_template import get_chat_template
|
||||
from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
|
||||
from sglang.lang.interpreter import StreamExecutor
|
||||
from sglang.lang.ir import SglSamplingParams
|
||||
|
||||
@@ -64,7 +65,8 @@ class BaseBackend:
|
||||
s: StreamExecutor,
|
||||
choices: List[str],
|
||||
temperature: float,
|
||||
):
|
||||
choices_method: Optional[ChoicesSamplingMethod] = None,
|
||||
) -> ChoicesDecision:
|
||||
raise NotImplementedError()
|
||||
|
||||
def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
|
||||
|
||||
@@ -8,6 +8,7 @@ import numpy as np
|
||||
|
||||
from sglang.lang.backend.base_backend import BaseBackend
|
||||
from sglang.lang.chat_template import ChatTemplate, get_chat_template_by_model_path
|
||||
from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
|
||||
from sglang.lang.interpreter import StreamExecutor
|
||||
from sglang.lang.ir import SglSamplingParams
|
||||
|
||||
@@ -296,7 +297,9 @@ class OpenAI(BaseBackend):
|
||||
s: StreamExecutor,
|
||||
choices: List[str],
|
||||
temperature: float,
|
||||
):
|
||||
choices_method: ChoicesSamplingMethod,
|
||||
) -> ChoicesDecision:
|
||||
"""Note: `choices_method` is not used by the OpenAI backend."""
|
||||
if self.is_chat_model:
|
||||
raise NotImplementedError(
|
||||
"select/choices is not supported for chat models. "
|
||||
@@ -354,8 +357,10 @@ class OpenAI(BaseBackend):
|
||||
|
||||
prompt_tokens.append(ret_token)
|
||||
|
||||
decision = choices[np.argmax(scores)]
|
||||
return decision, scores, None, None
|
||||
return ChoicesDecision(
|
||||
decision=choices[np.argmax(scores)],
|
||||
meta_info={"scores": scores},
|
||||
)
|
||||
|
||||
|
||||
def openai_completion(
|
||||
|
||||
@@ -1,17 +1,21 @@
|
||||
import json
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.lang.backend.base_backend import BaseBackend
|
||||
from sglang.lang.chat_template import get_chat_template_by_model_path
|
||||
from sglang.lang.choices import (
|
||||
ChoicesDecision,
|
||||
ChoicesSamplingMethod,
|
||||
token_length_normalized,
|
||||
)
|
||||
from sglang.lang.interpreter import StreamExecutor
|
||||
from sglang.lang.ir import SglSamplingParams
|
||||
from sglang.utils import http_request
|
||||
|
||||
|
||||
class RuntimeEndpoint(BaseBackend):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
@@ -208,20 +212,14 @@ class RuntimeEndpoint(BaseBackend):
|
||||
s: StreamExecutor,
|
||||
choices: List[str],
|
||||
temperature: float,
|
||||
):
|
||||
choices_method: ChoicesSamplingMethod,
|
||||
) -> ChoicesDecision:
|
||||
assert temperature <= 1e-5
|
||||
|
||||
# Cache common prefix
|
||||
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
|
||||
self._add_images(s, data)
|
||||
res = http_request(
|
||||
self.base_url + "/generate",
|
||||
json=data,
|
||||
api_key=self.api_key,
|
||||
verify=self.verify,
|
||||
)
|
||||
self._assert_success(res)
|
||||
prompt_len = res.json()["meta_info"]["prompt_tokens"]
|
||||
obj = self._generate_http_request(s, data)
|
||||
prompt_len = obj["meta_info"]["prompt_tokens"]
|
||||
|
||||
# Compute logprob
|
||||
data = {
|
||||
@@ -230,27 +228,35 @@ class RuntimeEndpoint(BaseBackend):
|
||||
"return_logprob": True,
|
||||
"logprob_start_len": max(prompt_len - 2, 0),
|
||||
}
|
||||
self._add_images(s, data)
|
||||
res = http_request(
|
||||
self.base_url + "/generate",
|
||||
json=data,
|
||||
api_key=self.api_key,
|
||||
verify=self.verify,
|
||||
)
|
||||
self._assert_success(res)
|
||||
obj = res.json()
|
||||
obj = self._generate_http_request(s, data)
|
||||
|
||||
normalized_prompt_logprobs = [
|
||||
r["meta_info"]["normalized_prompt_logprob"] for r in obj
|
||||
]
|
||||
decision = choices[np.argmax(normalized_prompt_logprobs)]
|
||||
input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj]
|
||||
output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj]
|
||||
|
||||
return (
|
||||
decision,
|
||||
normalized_prompt_logprobs,
|
||||
input_token_logprobs,
|
||||
output_token_logprobs,
|
||||
# Compute unconditional logprobs if required
|
||||
if choices_method.requires_unconditional_logprobs:
|
||||
input_ids = [[el[1] for el in subl] for subl in input_token_logprobs]
|
||||
data = {
|
||||
"input_ids": input_ids,
|
||||
"sampling_params": {"max_new_tokens": 0},
|
||||
"return_logprob": True,
|
||||
}
|
||||
obj = self._generate_http_request(s, data)
|
||||
unconditional_token_logprobs = [
|
||||
r["meta_info"]["input_token_logprobs"] for r in obj
|
||||
]
|
||||
else:
|
||||
unconditional_token_logprobs = None
|
||||
|
||||
return choices_method(
|
||||
choices=choices,
|
||||
normalized_prompt_logprobs=normalized_prompt_logprobs,
|
||||
input_token_logprobs=input_token_logprobs,
|
||||
output_token_logprobs=output_token_logprobs,
|
||||
unconditional_token_logprobs=unconditional_token_logprobs,
|
||||
)
|
||||
|
||||
def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
|
||||
@@ -262,6 +268,17 @@ class RuntimeEndpoint(BaseBackend):
|
||||
)
|
||||
self._assert_success(res)
|
||||
|
||||
def _generate_http_request(self, s: StreamExecutor, data):
    """POST `data` to this endpoint's /generate route and return the decoded JSON.

    Attaches the stream executor's image (if any) to `data` before sending,
    and asserts that the HTTP response indicates success before decoding.
    """
    self._add_images(s, data)
    res = http_request(
        self.base_url + "/generate",
        json=data,
        api_key=self.api_key,
        verify=self.verify,
    )
    self._assert_success(res)
    return res.json()
|
||||
|
||||
def _add_images(self, s: StreamExecutor, data):
|
||||
if s.images_:
|
||||
assert len(s.images_) == 1, "Only support one image."
|
||||
|
||||
164
python/sglang/lang/choices.py
Normal file
164
python/sglang/lang/choices.py
Normal file
@@ -0,0 +1,164 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass
class ChoicesDecision:
    """Outcome of a choices sampling method.

    Uses ``Optional[...]`` (not ``X | None``) for consistency with the rest of
    this module and to avoid requiring Python 3.10+ union syntax.
    """

    # The selected choice string.
    decision: str
    # Backend-specific diagnostics (e.g. the logprobs the selection was based on).
    meta_info: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class ChoicesSamplingMethod(ABC):
    """Strategy interface for selecting one option from a list of choices."""

    @property
    def requires_unconditional_logprobs(self) -> bool:
        """Whether this method needs a second, unconditional scoring pass.

        Defaults to False; methods that normalize by unconditional token
        logprobs override this to True.
        """
        return False

    @abstractmethod
    def __call__(
        self,
        *,
        choices: List[str],
        normalized_prompt_logprobs: List[float],
        input_token_logprobs: List[List[Any]],
        output_token_logprobs: List[List[Any]],
        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
    ) -> ChoicesDecision: ...
|
||||
|
||||
|
||||
class TokenLengthNormalized(ChoicesSamplingMethod):
    """Pick the choice with the largest length-normalized prompt logprob."""

    def __call__(
        self,
        *,
        choices: List[str],
        normalized_prompt_logprobs: List[float],
        input_token_logprobs: List[List[Any]],
        output_token_logprobs: List[List[Any]],
        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
    ) -> ChoicesDecision:
        """Select the option with the highest token length normalized prompt logprob."""
        winner = int(np.argmax(normalized_prompt_logprobs))
        return ChoicesDecision(
            decision=choices[winner],
            meta_info={
                "normalized_prompt_logprobs": normalized_prompt_logprobs,
                "input_token_logprobs": input_token_logprobs,
                "output_token_logprobs": output_token_logprobs,
            },
        )


token_length_normalized = TokenLengthNormalized()
|
||||
|
||||
|
||||
class GreedyTokenSelection(ChoicesSamplingMethod):
    """Pick the choice that survives a column-wise greedy logprob comparison."""

    def __call__(
        self,
        *,
        choices: List[str],
        normalized_prompt_logprobs: List[float],
        input_token_logprobs: List[List[Any]],
        output_token_logprobs: List[List[Any]],
        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
    ) -> ChoicesDecision:
        """Select the option based on greedy logprob selection. For overlapping options
        where one option is a subset of a longer option, extend the shorter option using
        its average logprob for comparison against the longer option."""
        option_count = len(choices)
        width = max(len(option) for option in input_token_logprobs)
        matrix = self._build_logprob_matrix(input_token_logprobs, width, option_count)
        survivors = self._greedy_selection(matrix, option_count, width)
        return ChoicesDecision(
            decision=choices[survivors[0]],
            meta_info={
                "normalized_prompt_logprobs": normalized_prompt_logprobs,
                "input_token_logprobs": input_token_logprobs,
                "output_token_logprobs": output_token_logprobs,
                "greedy_logprob_matrix": matrix.tolist(),
            },
        )

    def _build_logprob_matrix(self, input_token_logprobs, max_tokens, num_options):
        # One row per option; shorter options are padded out with their own mean
        # logprob so every column comparison is well-defined.
        matrix = np.zeros((num_options, max_tokens))
        for row, option in enumerate(input_token_logprobs):
            logprobs = [token[0] for token in option]
            matrix[row, : len(option)] = logprobs
            if len(option) < max_tokens:
                matrix[row, len(option) :] = np.mean(logprobs)
        return matrix

    def _greedy_selection(self, logprob_matrix, num_options, max_tokens):
        # Narrow the candidate rows column by column, keeping only rows tied for
        # the column maximum; stop as soon as one candidate remains.
        candidates = np.arange(num_options)
        for col in range(max_tokens):
            column = logprob_matrix[candidates, col]
            candidates = candidates[column == np.max(column)]
            if len(candidates) == 1:
                break
        return candidates


greedy_token_selection = GreedyTokenSelection()
|
||||
|
||||
|
||||
class UnconditionalLikelihoodNormalized(ChoicesSamplingMethod):
    """Pick the choice whose in-context token logprobs, normalized by the
    unconditional logprobs of the same tokens, have the highest mean."""

    @property
    def requires_unconditional_logprobs(self) -> bool:
        # This method cannot run without a second, unconditional scoring pass.
        return True

    def __call__(
        self,
        *,
        choices: List[str],
        normalized_prompt_logprobs: List[float],
        input_token_logprobs: List[List[Any]],
        output_token_logprobs: List[List[Any]],
        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
    ) -> ChoicesDecision:
        """Select the option with the highest average token logprob once normalized by
        the unconditional token logprobs.

        The first unconditional token logprob is assumed to be None. If so, it is
        replaced with 0 for the purposes of normalization."""
        if unconditional_token_logprobs is None:
            raise ValueError(
                "Unconditional token logprobs are required for this method."
            )

        scores = self._normalize_logprobs(
            input_token_logprobs, unconditional_token_logprobs
        )

        return ChoicesDecision(
            decision=choices[int(np.argmax(scores))],
            meta_info={
                "normalized_prompt_logprobs": normalized_prompt_logprobs,
                "input_token_logprobs": input_token_logprobs,
                "output_token_logprobs": output_token_logprobs,
                "unconditional_token_logprobs": unconditional_token_logprobs,
                "normalized_unconditional_prompt_logprobs": scores,
            },
        )

    def _normalize_logprobs(self, input_token_logprobs, unconditional_token_logprobs):
        # For each option, subtract every token's unconditional logprob from its
        # in-context logprob and average the differences.
        results = []
        for cond, uncond in zip(input_token_logprobs, unconditional_token_logprobs):
            cond_vals = np.array([token[0] for token in cond])
            uncond_vals = np.array([token[0] for token in uncond])
            # The leading unconditional logprob is typically None; treat it as 0.
            uncond_vals[0] = uncond_vals[0] or 0
            results.append(float(np.mean(cond_vals - uncond_vals)))
        return results


unconditional_likelihood_normalized = UnconditionalLikelihoodNormalized()
|
||||
@@ -538,24 +538,17 @@ class StreamExecutor:
|
||||
self.stream_var_event[name].set()
|
||||
|
||||
def _execute_select(self, expr: SglSelect):
|
||||
(
|
||||
decision,
|
||||
normalized_prompt_logprobs,
|
||||
input_token_logprobs,
|
||||
output_token_logprobs,
|
||||
) = self.backend.select(self, expr.choices, expr.temperature)
|
||||
choices_decision = self.backend.select(
|
||||
self, expr.choices, expr.temperature, expr.choices_method
|
||||
)
|
||||
if expr.name is not None:
|
||||
name = expr.name
|
||||
self.variables[name] = decision
|
||||
self.meta_info[name] = {
|
||||
"normalized_prompt_logprobs": normalized_prompt_logprobs,
|
||||
"input_token_logprobs": input_token_logprobs,
|
||||
"output_token_logprobs": output_token_logprobs,
|
||||
}
|
||||
self.variables[name] = choices_decision.decision
|
||||
self.meta_info[name] = choices_decision.meta_info
|
||||
self.variable_event[name].set()
|
||||
if self.stream_var_event:
|
||||
self.stream_var_event[name].set()
|
||||
self.text_ += decision
|
||||
self.text_ += choices_decision.decision
|
||||
|
||||
def _execute_variable(self, expr: SglVariable):
|
||||
src_executor = expr.source_stream_executor
|
||||
|
||||
@@ -6,6 +6,7 @@ import warnings
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.lang.choices import ChoicesSamplingMethod
|
||||
|
||||
REGEX_INT = r"[-+]?[0-9]+"
|
||||
REGEX_FLOAT = r"[-+]?[0-9]*\.?[0-9]+"
|
||||
@@ -461,14 +462,22 @@ class SglRoleEnd(SglExpr):
|
||||
|
||||
|
||||
class SglSelect(SglExpr):
|
||||
def __init__(self, name: str, choices: List[str], temperature: float):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
choices: List[str],
|
||||
temperature: float,
|
||||
choices_method: ChoicesSamplingMethod,
|
||||
):
|
||||
super().__init__()
|
||||
self.name = name
|
||||
self.choices = choices
|
||||
self.temperature = temperature
|
||||
self.choices_method = choices_method
|
||||
|
||||
def __repr__(self):
|
||||
return f"Select({self.name}, choices={self.choices})"
|
||||
return f"Select({self.name}, choices={self.choices}, choices_method={self.choices_method})"
|
||||
|
||||
|
||||
class SglFork(SglExpr):
|
||||
|
||||
Reference in New Issue
Block a user