From 94e0115186d91fe00910a17b02da2a62de6b2d45 Mon Sep 17 00:00:00 2001 From: Aidan Cooper <30752032+AidanCooper@users.noreply.github.com> Date: Mon, 5 Aug 2024 11:27:49 +0100 Subject: [PATCH] Feat: add alternative choices selection methods (#835) --- docs/en/choices_methods.md | 77 ++++++++ python/sglang/__init__.py | 8 + python/sglang/api.py | 12 +- python/sglang/lang/backend/base_backend.py | 4 +- python/sglang/lang/backend/openai.py | 11 +- .../sglang/lang/backend/runtime_endpoint.py | 71 +++++--- python/sglang/lang/choices.py | 164 ++++++++++++++++++ python/sglang/lang/interpreter.py | 19 +- python/sglang/lang/ir.py | 13 +- test/lang/test_choices.py | 95 ++++++++++ 10 files changed, 426 insertions(+), 48 deletions(-) create mode 100644 docs/en/choices_methods.md create mode 100644 python/sglang/lang/choices.py create mode 100644 test/lang/test_choices.py diff --git a/docs/en/choices_methods.md b/docs/en/choices_methods.md new file mode 100644 index 000000000..e0f3ed313 --- /dev/null +++ b/docs/en/choices_methods.md @@ -0,0 +1,77 @@ +# Choices Methods in SGLang +This doc describes the choices methods supported by SGLang. + +The optional `choices_method` arg determines how options supplied to SGLang's `choices` primitive are selected. Only the `RuntimeEndpoint` backend supports the `choices_method` arg. Other backends, such as `OpenAI`, have bespoke selection implementations due to API limitations. + +## Methods + +### Token Length Normalized + +Token length normalized is the default SGLang choices method. It selects the option with the highest average logprob across all of its tokens. + +Usage example (alternatively, simply omit the `choices_method` arg): +```python +@sgl.function +def example(s): + s += sgl.user("What is the capital of France?") + s += sgl.assistant( + sgl.gen( + "answer", + choices=["London", "Paris", "Berlin"], + choices_method=sgl.token_length_normalized, + ) + ) +``` + + +This can perform poorly if an option contains many tokens, where its later tokens are predicted with high confidence based on its earlier tokens. For instance, even strong models will fail the above example if the specified options are `["Paris", "Antidisestablishmentarianism"]`. + +### Greedy Token Selection + +Greedy token selection simply selects the option with the highest logprob for its initial token. For overlapping options where one option is a subset of a longer option, the logprobs of the shorter option are extended using its average logprob for comparison against the longer option. + +Usage example: +```python +@sgl.function +def example(s): + s += sgl.user("What is the capital of France?") + s += sgl.assistant( + sgl.gen( + "answer", + choices=["London", "Paris", "Berlin"], + choices_method=sgl.greedy_token_selection, + ) + ) +``` + +This can perform poorly if an option misleads the model down a bad path based on an attractive initial token. For instance, greedy selection will result in an incorrect response for this example: +```python +@sgl.function +def us_president_example(s): + s += sgl.user("Name a US president.") + s += sgl.assistant( + sgl.gen( + "answer", + choices=["Donald Duck", "Millard Fillmore"], + choices_method=sgl.greedy_token_selection, + ) + ) +``` + +### Unconditional Likelihood Normalized + +Unconditional likelihood normalized selects the option with the highest average token logprob once normalized by the unconditional token logprobs, as described in [this EleutherAI blogpost](https://blog.eleuther.ai/multiple-choice-normalization/). This method incurs an additional LLM call to obtain the unconditional likelihoods. + +Usage example: +```python +@sgl.function +def example(s): + s += sgl.user("What is the capital of France?") + s += sgl.assistant( + sgl.gen( + "answer", + choices=["London", "Paris", "Berlin"], + choices_method=sgl.unconditional_likelihood_normalized, + ) + ) +``` \ No newline at end of file diff --git a/python/sglang/__init__.py b/python/sglang/__init__.py index f4eec131e..71d7bfecc 100644 --- a/python/sglang/__init__.py +++ b/python/sglang/__init__.py @@ -22,6 +22,11 @@ from sglang.api import ( user_end, video, ) +from sglang.lang.choices import ( + greedy_token_selection, + token_length_normalized, + unconditional_likelihood_normalized, +) # SGLang DSL APIs __all__ = [ @@ -45,6 +50,9 @@ __all__ = [ "user_begin", "user_end", "video", + "greedy_token_selection", + "token_length_normalized", + "unconditional_likelihood_normalized", ] # Global Configurations diff --git a/python/sglang/api.py b/python/sglang/api.py index e6b6715a8..5a177c36b 100644 --- a/python/sglang/api.py +++ b/python/sglang/api.py @@ -6,6 +6,7 @@ from typing import Callable, List, Optional, Union from sglang.global_config import global_config from sglang.lang.backend.base_backend import BaseBackend +from sglang.lang.choices import ChoicesSamplingMethod, token_length_normalized from sglang.lang.ir import ( SglExpr, SglExprList, @@ -73,12 +74,18 @@ def gen( return_text_in_logprobs: Optional[bool] = None, dtype: Optional[type] = None, choices: Optional[List[str]] = None, + choices_method: Optional[ChoicesSamplingMethod] = None, regex: Optional[str] = None, ): """Call the model to generate. See the meaning of the arguments in docs/en/sampling_params.md""" if choices: - return SglSelect(name, choices, 0.0 if temperature is None else temperature) + return SglSelect( + name, + choices, + 0.0 if temperature is None else temperature, + token_length_normalized if choices_method is None else choices_method, + ) # check regex is valid if regex is not None: @@ -186,9 +193,10 @@ def select( name: Optional[str] = None, choices: Optional[List[str]] = None, temperature: float = 0.0, + choices_method: ChoicesSamplingMethod = token_length_normalized, ): assert choices is not None - return SglSelect(name, choices, temperature) + return SglSelect(name, choices, temperature, choices_method) def _role_common(name: str, expr: Optional[SglExpr] = None): diff --git a/python/sglang/lang/backend/base_backend.py b/python/sglang/lang/backend/base_backend.py index cb504f51b..185f2e297 100644 --- a/python/sglang/lang/backend/base_backend.py +++ b/python/sglang/lang/backend/base_backend.py @@ -1,6 +1,7 @@ from typing import Callable, List, Optional, Union from sglang.lang.chat_template import get_chat_template +from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod from sglang.lang.interpreter import StreamExecutor from sglang.lang.ir import SglSamplingParams @@ -64,7 +65,8 @@ class BaseBackend: s: StreamExecutor, choices: List[str], temperature: float, - ): + choices_method: Optional[ChoicesSamplingMethod] = None, + ) -> ChoicesDecision: raise NotImplementedError() def concatenate_and_append(self, src_rids: List[str], dst_rid: str): diff --git a/python/sglang/lang/backend/openai.py b/python/sglang/lang/backend/openai.py index 48dcc080e..6fa93d9b2 100644 --- a/python/sglang/lang/backend/openai.py +++ b/python/sglang/lang/backend/openai.py @@ -8,6 +8,7 @@ import numpy as np from sglang.lang.backend.base_backend import BaseBackend from sglang.lang.chat_template import ChatTemplate, get_chat_template_by_model_path +from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod from sglang.lang.interpreter import StreamExecutor from sglang.lang.ir import SglSamplingParams @@ -296,7 +297,9 @@ class OpenAI(BaseBackend): s: StreamExecutor, choices: List[str], temperature: float, - ): + choices_method: ChoicesSamplingMethod, + ) -> ChoicesDecision: + """Note: `choices_method` is not used by the OpenAI backend.""" if self.is_chat_model: raise NotImplementedError( "select/choices is not supported for chat models. " @@ -354,8 +357,10 @@ class OpenAI(BaseBackend): prompt_tokens.append(ret_token) - decision = choices[np.argmax(scores)] - return decision, scores, None, None + return ChoicesDecision( + decision=choices[np.argmax(scores)], + meta_info={"scores": scores}, + ) def openai_completion( diff --git a/python/sglang/lang/backend/runtime_endpoint.py b/python/sglang/lang/backend/runtime_endpoint.py index 49df598de..7f0db5b35 100644 --- a/python/sglang/lang/backend/runtime_endpoint.py +++ b/python/sglang/lang/backend/runtime_endpoint.py @@ -1,17 +1,21 @@ import json from typing import List, Optional -import numpy as np - from sglang.global_config import global_config from sglang.lang.backend.base_backend import BaseBackend from sglang.lang.chat_template import get_chat_template_by_model_path +from sglang.lang.choices import ( + ChoicesDecision, + ChoicesSamplingMethod, + token_length_normalized, +) from sglang.lang.interpreter import StreamExecutor from sglang.lang.ir import SglSamplingParams from sglang.utils import http_request class RuntimeEndpoint(BaseBackend): + def __init__( self, base_url: str, @@ -208,20 +212,14 @@ class RuntimeEndpoint(BaseBackend): s: StreamExecutor, choices: List[str], temperature: float, - ): + choices_method: ChoicesSamplingMethod, + ) -> ChoicesDecision: assert temperature <= 1e-5 # Cache common prefix data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}} - self._add_images(s, data) - res = http_request( - self.base_url + "/generate", - json=data, - api_key=self.api_key, - verify=self.verify, - ) - self._assert_success(res) - prompt_len = res.json()["meta_info"]["prompt_tokens"] + obj = self._generate_http_request(s, data) + prompt_len = obj["meta_info"]["prompt_tokens"] # Compute logprob data = { @@ -230,27 +228,35 @@ class RuntimeEndpoint(BaseBackend): "return_logprob": True, "logprob_start_len": max(prompt_len - 2, 0), } - self._add_images(s, data) - res = http_request( - self.base_url + "/generate", - json=data, - api_key=self.api_key, - verify=self.verify, - ) - self._assert_success(res) - obj = res.json() + obj = self._generate_http_request(s, data) + normalized_prompt_logprobs = [ r["meta_info"]["normalized_prompt_logprob"] for r in obj ] - decision = choices[np.argmax(normalized_prompt_logprobs)] input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj] output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj] - return ( - decision, - normalized_prompt_logprobs, - input_token_logprobs, - output_token_logprobs, + # Compute unconditional logprobs if required + if choices_method.requires_unconditional_logprobs: + input_ids = [[el[1] for el in subl] for subl in input_token_logprobs] + data = { + "input_ids": input_ids, + "sampling_params": {"max_new_tokens": 0}, + "return_logprob": True, + } + obj = self._generate_http_request(s, data) + unconditional_token_logprobs = [ + r["meta_info"]["input_token_logprobs"] for r in obj + ] + else: + unconditional_token_logprobs = None + + return choices_method( + choices=choices, + normalized_prompt_logprobs=normalized_prompt_logprobs, + input_token_logprobs=input_token_logprobs, + output_token_logprobs=output_token_logprobs, + unconditional_token_logprobs=unconditional_token_logprobs, ) def concatenate_and_append(self, src_rids: List[str], dst_rid: str): @@ -262,6 +268,17 @@ class RuntimeEndpoint(BaseBackend): ) self._assert_success(res) + def _generate_http_request(self, s: StreamExecutor, data): + self._add_images(s, data) + res = http_request( + self.base_url + "/generate", + json=data, + api_key=self.api_key, + verify=self.verify, + ) + self._assert_success(res) + return res.json() + def _add_images(self, s: StreamExecutor, data): if s.images_: assert len(s.images_) == 1, "Only support one image." diff --git a/python/sglang/lang/choices.py b/python/sglang/lang/choices.py new file mode 100644 index 000000000..f10fbff97 --- /dev/null +++ b/python/sglang/lang/choices.py @@ -0,0 +1,164 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +import numpy as np + + +@dataclass +class ChoicesDecision: + decision: str + meta_info: Dict[str, Any] | None = None + + +class ChoicesSamplingMethod(ABC): + + @property + def requires_unconditional_logprobs(self) -> bool: + return False + + @abstractmethod + def __call__( + self, + *, + choices: List[str], + normalized_prompt_logprobs: List[float], + input_token_logprobs: List[List[Any]], + output_token_logprobs: List[List[Any]], + unconditional_token_logprobs: Optional[List[List[Any]]] = None, + ) -> ChoicesDecision: ... + + +class TokenLengthNormalized(ChoicesSamplingMethod): + + def __call__( + self, + *, + choices: List[str], + normalized_prompt_logprobs: List[float], + input_token_logprobs: List[List[Any]], + output_token_logprobs: List[List[Any]], + unconditional_token_logprobs: Optional[List[List[Any]]] = None, + ) -> ChoicesDecision: + """Select the option with the highest token length normalized prompt logprob.""" + best_choice = choices[np.argmax(normalized_prompt_logprobs)] + meta_info = { + "normalized_prompt_logprobs": normalized_prompt_logprobs, + "input_token_logprobs": input_token_logprobs, + "output_token_logprobs": output_token_logprobs, + } + return ChoicesDecision(decision=best_choice, meta_info=meta_info) + + +token_length_normalized = TokenLengthNormalized() + + +class GreedyTokenSelection(ChoicesSamplingMethod): + + def __call__( + self, + *, + choices: List[str], + normalized_prompt_logprobs: List[float], + input_token_logprobs: List[List[Any]], + output_token_logprobs: List[List[Any]], + unconditional_token_logprobs: Optional[List[List[Any]]] = None, + ) -> ChoicesDecision: + """Select the option based on greedy logprob selection. For overlapping options + where one option is a subset of a longer option, extend the shorter option using + its average logprob for comparison against the longer option.""" + + num_options = len(choices) + max_tokens = max(len(option) for option in input_token_logprobs) + logprob_matrix = self._build_logprob_matrix( + input_token_logprobs, max_tokens, num_options + ) + remaining = self._greedy_selection(logprob_matrix, num_options, max_tokens) + + best_choice = choices[remaining[0]] + meta_info = { + "normalized_prompt_logprobs": normalized_prompt_logprobs, + "input_token_logprobs": input_token_logprobs, + "output_token_logprobs": output_token_logprobs, + "greedy_logprob_matrix": logprob_matrix.tolist(), + } + return ChoicesDecision(decision=best_choice, meta_info=meta_info) + + def _build_logprob_matrix(self, input_token_logprobs, max_tokens, num_options): + logprob_matrix = np.zeros((num_options, max_tokens)) + for i, option in enumerate(input_token_logprobs): + actual_logprobs = [token[0] for token in option] + avg_logprob = np.mean(actual_logprobs) + logprob_matrix[i, : len(option)] = actual_logprobs + if len(option) < max_tokens: + logprob_matrix[i, len(option) :] = avg_logprob + return logprob_matrix + + def _greedy_selection(self, logprob_matrix, num_options, max_tokens): + remaining = np.arange(num_options) + for j in range(max_tokens): + max_logprob = np.max(logprob_matrix[remaining, j]) + remaining = remaining[logprob_matrix[remaining, j] == max_logprob] + if len(remaining) == 1: + break + return remaining + + +greedy_token_selection = GreedyTokenSelection() + + +class UnconditionalLikelihoodNormalized(ChoicesSamplingMethod): + + @property + def requires_unconditional_logprobs(self) -> bool: + return True + + def __call__( + self, + *, + choices: List[str], + normalized_prompt_logprobs: List[float], + input_token_logprobs: List[List[Any]], + output_token_logprobs: List[List[Any]], + unconditional_token_logprobs: Optional[List[List[Any]]] = None, + ) -> ChoicesDecision: + """Select the option with the highest average token logprob once normalized by + the unconditional token logprobs. + + The first unconditional token logprob is assumed to be None. If so, it is + replaced with 0 for the purposes of normalization.""" + + if unconditional_token_logprobs is None: + raise ValueError( + "Unconditional token logprobs are required for this method." + ) + + normalized_unconditional_prompt_logprobs = self._normalize_logprobs( + input_token_logprobs, unconditional_token_logprobs + ) + + best_choice = choices[np.argmax(normalized_unconditional_prompt_logprobs)] + meta_info = { + "normalized_prompt_logprobs": normalized_prompt_logprobs, + "input_token_logprobs": input_token_logprobs, + "output_token_logprobs": output_token_logprobs, + "unconditional_token_logprobs": unconditional_token_logprobs, + "normalized_unconditional_prompt_logprobs": normalized_unconditional_prompt_logprobs, + } + return ChoicesDecision(decision=best_choice, meta_info=meta_info) + + def _normalize_logprobs(self, input_token_logprobs, unconditional_token_logprobs): + normalized_unconditional_prompt_logprobs = [] + for inputs, unconditionals in zip( + input_token_logprobs, unconditional_token_logprobs + ): + inputs_logprobs = np.array([token[0] for token in inputs]) + unconditionals_logprobs = np.array([token[0] for token in unconditionals]) + unconditionals_logprobs[0] = unconditionals_logprobs[0] or 0 + normalized_unconditional_prompt_logprobs.append( + float(np.mean(inputs_logprobs - unconditionals_logprobs)) + ) + return normalized_unconditional_prompt_logprobs + + +unconditional_likelihood_normalized = UnconditionalLikelihoodNormalized() diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py index 2096e49e8..cf53fac30 100644 --- a/python/sglang/lang/interpreter.py +++ b/python/sglang/lang/interpreter.py @@ -538,24 +538,17 @@ class StreamExecutor: self.stream_var_event[name].set() def _execute_select(self, expr: SglSelect): - ( - decision, - normalized_prompt_logprobs, - input_token_logprobs, - output_token_logprobs, - ) = self.backend.select(self, expr.choices, expr.temperature) + choices_decision = self.backend.select( + self, expr.choices, expr.temperature, expr.choices_method + ) if expr.name is not None: name = expr.name - self.variables[name] = decision - self.meta_info[name] = { - "normalized_prompt_logprobs": normalized_prompt_logprobs, - "input_token_logprobs": input_token_logprobs, - "output_token_logprobs": output_token_logprobs, - } + self.variables[name] = choices_decision.decision + self.meta_info[name] = choices_decision.meta_info self.variable_event[name].set() if self.stream_var_event: self.stream_var_event[name].set() - self.text_ += decision + self.text_ += choices_decision.decision def _execute_variable(self, expr: SglVariable): src_executor = expr.source_stream_executor diff --git a/python/sglang/lang/ir.py b/python/sglang/lang/ir.py index 2ee167f86..d902497c7 100644 --- a/python/sglang/lang/ir.py +++ b/python/sglang/lang/ir.py @@ -6,6 +6,7 @@ import warnings from typing import List, Optional, Union from sglang.global_config import global_config +from sglang.lang.choices import ChoicesSamplingMethod REGEX_INT = r"[-+]?[0-9]+" REGEX_FLOAT = r"[-+]?[0-9]*\.?[0-9]+" @@ -461,14 +462,22 @@ class SglRoleEnd(SglExpr): class SglSelect(SglExpr): - def __init__(self, name: str, choices: List[str], temperature: float): + + def __init__( + self, + name: str, + choices: List[str], + temperature: float, + choices_method: ChoicesSamplingMethod, + ): super().__init__() self.name = name self.choices = choices self.temperature = temperature + self.choices_method = choices_method def __repr__(self): - return f"Select({self.name}, choices={self.choices})" + return f"Select({self.name}, choices={self.choices}, choices_method={self.choices_method})" class SglFork(SglExpr): diff --git a/test/lang/test_choices.py b/test/lang/test_choices.py new file mode 100644 index 000000000..da25e9e49 --- /dev/null +++ b/test/lang/test_choices.py @@ -0,0 +1,95 @@ +import unittest + +import numpy as np + +from sglang.lang.choices import ( + greedy_token_selection, + token_length_normalized, + unconditional_likelihood_normalized, +) + +MOCK_CHOICES_INPUT_DATA = { + "choices": [ + "organ", # ["organ"] + "organism", # ["organ", "ism"] + "antidisestablishmentarianism", # ["ant", "id", "is", "est", "ablish", "ment", "arian", "ism"] + ], + "normalized_prompt_logprobs": [-0.1, -0.2, -0.05], + "input_token_logprobs": [ + [[-0.1, 1, None]], + [[-0.1, 1, None], [-0.3, 2, None]], + [ + [-0.4, 3, None], + [-0.25, 4, None], + [-0.1, 5, None], + [-0.01, 6, None], + [-0.01, 7, None], + [-0.01, 8, None], + [-0.01, 9, None], + [-0.01, 2, None], + ], + ], + "output_token_logprobs": [ + [[-0.1, 10, None]], + [[-0.1, 10, None]], + [[-0.1, 10, None]], + ], + "unconditional_token_logprobs": [ + [[None, 1, None]], + [[None, 1, None], [-1.4, 2, None]], + [ + [None, 3, None], + [-0.25, 4, None], + [-0.1, 5, None], + [-0.01, 6, None], + [-0.01, 7, None], + [-0.01, 8, None], + [-0.01, 9, None], + [-0.01, 2, None], + ], + ], +} + + +class TestChoices(unittest.TestCase): + + def test_token_length_normalized(self): + """Confirm 'antidisestablishmentarianism' is selected due to high confidences for + its later tokens resulting in highest token length normalized prompt logprob.""" + decision = token_length_normalized(**MOCK_CHOICES_INPUT_DATA) + assert decision.decision == "antidisestablishmentarianism" + + def test_greedy_token_selection(self): + """Confirm 'organ' is selected due it having the joint highest initial token + logprob, and a higher average logprob than organism's second token.""" + decision = greedy_token_selection(**MOCK_CHOICES_INPUT_DATA) + assert decision.decision == "organ" + assert np.allclose( + decision.meta_info["greedy_logprob_matrix"], + [ + [-0.1, -0.1, -0.1, -0.1, -0.1, -0.1, -0.1, -0.1], + [-0.1, -0.3, -0.2, -0.2, -0.2, -0.2, -0.2, -0.2], + [-0.4, -0.25, -0.1, -0.01, -0.01, -0.01, -0.01, -0.01], + ], + atol=0.01, + ) + + def test_unconditional_likelihood_normalized(self): + """Confirm 'organism' is selected due to it having the highest average token logprob + once normalized by the unconditional token logprobs.""" + decision = unconditional_likelihood_normalized(**MOCK_CHOICES_INPUT_DATA) + assert decision.decision == "organism" + assert np.allclose( + decision.meta_info["normalized_unconditional_prompt_logprobs"], + [-0.1, 0.5, -0.05], + atol=0.01, + ) + + +if __name__ == "__main__": + unittest.main(warnings="ignore") + + # t = TestChoices() + # t.test_token_length_normalized() + # t.test_greedy_token_selection() + # t.test_unconditional_likelihood_normalized()