Support n in OpenAI API completions (#3446)
Co-authored-by: Shan Yu <shanyu1@g.ucla.edu>
Co-authored-by: Yineng Zhang <me@zhyncs.com>
Co-authored-by: chuyue sun <chuyue@lmsys.us-northcentral1-a.compute.internal>
@@ -75,6 +75,7 @@ def gen(
     name: Optional[str] = None,
     max_tokens: Optional[int] = None,
     min_tokens: Optional[int] = None,
+    n: Optional[int] = None,
     stop: Optional[Union[str, List[str]]] = None,
     stop_token_ids: Optional[List[int]] = None,
     temperature: Optional[float] = None,
@@ -115,6 +116,7 @@ def gen(
         name,
         max_tokens,
         min_tokens,
+        n,
         stop,
         stop_token_ids,
         temperature,
@@ -137,6 +139,7 @@ def gen(
 def gen_int(
     name: Optional[str] = None,
     max_tokens: Optional[int] = None,
+    n: Optional[int] = None,
     stop: Optional[Union[str, List[str]]] = None,
     stop_token_ids: Optional[List[int]] = None,
     temperature: Optional[float] = None,
@@ -155,6 +158,7 @@ def gen_int(
         name,
         max_tokens,
         None,
+        n,
         stop,
         stop_token_ids,
         temperature,
@@ -176,6 +180,7 @@ def gen_int(
 def gen_string(
     name: Optional[str] = None,
     max_tokens: Optional[int] = None,
+    n: Optional[int] = None,
     stop: Optional[Union[str, List[str]]] = None,
     stop_token_ids: Optional[List[int]] = None,
     temperature: Optional[float] = None,
@@ -194,6 +199,7 @@ def gen_string(
         name,
         max_tokens,
         None,
+        n,
         stop,
         stop_token_ids,
         temperature,
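The frontend helpers above (gen, gen_int, gen_string) now accept n and forward it to SglGen. A minimal usage sketch, assuming the public sglang frontend API; the function body and prompt are illustrative, not part of this commit:

import sglang as sgl

@sgl.function
def suggest(s, topic):
    s += "Give a one-line tip about " + topic + ": "
    # n=3 requests three parallel completions; with this commit the
    # variable "tip" resolves to a list of three strings.
    s += sgl.gen("tip", max_tokens=32, n=3)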
@@ -165,6 +165,7 @@ class OpenAI(BaseBackend):
                 kwargs.pop("max_tokens", None)
             else:
                 kwargs.pop("max_completion_tokens", None)

             comp = openai_completion(
                 client=self.client,
                 token_usage=self.token_usage,
@@ -173,13 +174,13 @@ class OpenAI(BaseBackend):
                 prompt=prompt,
                 **kwargs,
             )
+            # Keep the returned list (or string) as is.
         elif sampling_params.dtype in [str, "str", "string"]:
             assert (
                 not self.is_chat_model
             ), "constrained type not supported on chat model"
             kwargs = sampling_params.to_openai_kwargs()
             kwargs.pop("stop")

             comp = openai_completion(
                 client=self.client,
                 token_usage=self.token_usage,
@@ -189,7 +190,11 @@ class OpenAI(BaseBackend):
                 stop='"',
                 **kwargs,
             )
-            comp = '"' + comp + '"'
+            # Wrap each element in quotes if we have a list.
+            if isinstance(comp, list):
+                comp = ['"' + x + '"' for x in comp]
+            else:
+                comp = '"' + comp + '"'
         elif sampling_params.dtype in [int, "int"]:
             assert (
                 not self.is_chat_model
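The quote wrapping above extends the old single-string behavior to the list returned when n > 1; a small sketch of the effect (values illustrative):

comp = ["Alice", "Bob"]               # backend returned a list for n > 1
comp = ['"' + x + '"' for x in comp]  # ['"Alice"', '"Bob"'], each element
                                      # quoted, mirroring '"' + comp + '"'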
@@ -206,6 +211,7 @@ class OpenAI(BaseBackend):
                 stop=[" "],
                 **kwargs,
             )
+            # Leave as a list if that's what is returned.
         else:
             raise ValueError(f"Unknown dtype: {sampling_params.dtype}")
@@ -254,7 +260,9 @@ class OpenAI(BaseBackend):
                 prompt=s.messages_,
                 **self.spec_kwargs,
             )
-            if self.spec_pattern_match(comp):
+            # Use a string for pattern matching.
+            comp_for_match = comp[0] if isinstance(comp, list) else comp
+            if self.spec_pattern_match(comp_for_match):
                 break

         for term in self.spec_format:
@@ -370,7 +378,7 @@ class OpenAI(BaseBackend):

 def openai_completion(
     client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
-):
+) -> Union[str, List[str]]:
     # if "ebnf" is in kwargs, warn and remove
     if "ebnf" in kwargs:
         warnings.warn("EBNF is not officially supported by OpenAI endpoints. Ignoring.")
@@ -382,13 +390,18 @@ def openai_completion(
         if "stop" in kwargs and kwargs["stop"] is None:
             kwargs.pop("stop")
         ret = client.chat.completions.create(messages=prompt, **kwargs)
-        comp = ret.choices[0].message.content
+        if len(ret.choices) == 1:
+            comp = ret.choices[0].message.content
+        else:
+            comp = [c.message.content for c in ret.choices]
     else:
         ret = client.completions.create(prompt=prompt, **kwargs)
         if isinstance(prompt, (list, tuple)):
             comp = [c.text for c in ret.choices]
         else:
             comp = ret.choices[0].text
+            if len(ret.choices) > 1:
+                comp = [c.text for c in ret.choices]

     token_usage.prompt_tokens += ret.usage.prompt_tokens
     token_usage.completion_tokens += ret.usage.completion_tokens
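openai_completion now returns either a single string or a list of strings, depending on how many choices the API sends back. A hedged sketch of the underlying client call (model name illustrative; assumes the openai v1 SDK):

from openai import OpenAI

client = OpenAI()
ret = client.completions.create(
    model="gpt-3.5-turbo-instruct",  # illustrative completions model
    prompt="Say hi: ",
    n=3,                             # ask for three choices
    max_tokens=8,
)
texts = [c.text for c in ret.choices]  # len(texts) == 3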
@@ -566,13 +566,13 @@ class StreamExecutor:
     def _execute_gen(self, expr: SglGen):
         sampling_params = self._resolve_sampling_params(expr.sampling_params)
         name = expr.name

         if not self.stream:
             if self.num_api_spec_tokens is None:
                 comp, meta_info = self.backend.generate(
                     self,
                     sampling_params=sampling_params,
                 )

             else:
                 if self.backend.is_chat_model:
                     # Speculative execution on models with only chat interface.
@@ -587,8 +587,11 @@ class StreamExecutor:

                 else:  # Speculative execution on models with completion interface
                     comp, meta_info = self._spec_gen(sampling_params)

-            self.text_ += comp
+            if isinstance(comp, list):
+                self.text_ += comp[0]
+            else:
+                assert isinstance(comp, str)
+                self.text_ += comp

             self.variables[name] = comp
             self.meta_info[name] = meta_info
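With n > 1 the interpreter stores the full list in self.variables[name] but appends only the first completion to the running prompt, since text_ must stay a single string. A sketch of that invariant (values illustrative):

comp = ["first", "second", "third"]
text_ = "prompt so far: "
text_ += comp[0] if isinstance(comp, list) else comp
# text_ continues with "first" only; all three remain in variables[name]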
@@ -747,6 +750,7 @@ class StreamExecutor:
         for item in [
             "max_new_tokens",
             "min_new_tokens",
+            "n",
             "stop",
             "stop_token_ids",
             "temperature",
@@ -18,6 +18,7 @@ REGEX_STR = r"\"[\w\d\s]*\""  # bugs with regex r"\".*\"" in interegular pkg
 class SglSamplingParams:
     max_new_tokens: int = 128
     min_new_tokens: int = 0
+    n: int = 1
     stop: Union[str, List[str]] = ()
     stop_token_ids: Optional[List[int]] = ()
     temperature: float = 1.0
@@ -41,6 +42,7 @@ class SglSamplingParams:
         return SglSamplingParams(
             self.max_new_tokens,
             self.min_new_tokens,
+            self.n,
             self.stop,
             self.stop_token_ids,
             self.temperature,
@@ -64,6 +66,7 @@ class SglSamplingParams:
         return {
             "max_tokens": self.max_new_tokens,
             "max_completion_tokens": self.max_new_tokens,
+            "n": self.n,
             "stop": self.stop or None,
             "temperature": self.temperature,
             "top_p": self.top_p,
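to_openai_kwargs now forwards n straight through to the OpenAI client. Roughly, under the fields shown above (dict contents abridged):

params = SglSamplingParams(max_new_tokens=64, n=2, temperature=0.7)
kwargs = params.to_openai_kwargs()
# kwargs includes {"max_tokens": 64, "max_completion_tokens": 64,
#                  "n": 2, "stop": None, "temperature": 0.7, ...}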
@@ -117,6 +120,7 @@ class SglSamplingParams:
         return {
             "max_new_tokens": self.max_new_tokens,
             "min_new_tokens": self.min_new_tokens,
+            "n": self.n,
             "stop": self.stop,
             "stop_token_ids": self.stop_token_ids,
             "temperature": self.temperature,
@@ -154,6 +158,7 @@ class SglFunction:
         self,
         *args,
         max_new_tokens: int = 128,
+        n: int = 1,
         stop: Optional[Union[str, List[str]]] = None,
         stop_token_ids: Optional[List[int]] = None,
         temperature: float = 1.0,
@@ -182,6 +187,7 @@ class SglFunction:

         default_sampling_para = SglSamplingParams(
             max_new_tokens=max_new_tokens,
+            n=n,
             stop=stop,
             stop_token_ids=stop_token_ids,
             temperature=temperature,
@@ -212,6 +218,7 @@ class SglFunction:
         batch_kwargs,
         *,
         max_new_tokens: int = 128,
+        n: int = 1,
         stop: Optional[Union[str, List[str]]] = None,
         stop_token_ids: Optional[List[int]] = None,
         temperature: float = 1.0,
@@ -257,6 +264,7 @@ class SglFunction:

         default_sampling_para = SglSamplingParams(
             max_new_tokens=max_new_tokens,
+            n=n,
             stop=stop,
             stop_token_ids=stop_token_ids,
             temperature=temperature,
@@ -440,6 +448,7 @@ class SglGen(SglExpr):
         name: Optional[str] = None,
         max_new_tokens: Optional[int] = None,
         min_new_tokens: Optional[int] = None,
+        n: Optional[int] = None,
         stop: Optional[Union[str, List[str]]] = None,
         stop_token_ids: Optional[List[int]] = None,
         temperature: Optional[float] = None,
@@ -463,6 +472,7 @@ class SglGen(SglExpr):
         self.sampling_params = SglSamplingParams(
             max_new_tokens=max_new_tokens,
             min_new_tokens=min_new_tokens,
+            n=n,
             stop=stop,
             stop_token_ids=stop_token_ids,
             temperature=temperature,
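Taken together, the diff threads a single integer from the frontend call through SglGen and SglSamplingParams into the OpenAI request. A hedged end-to-end sketch, assuming the public sglang API (model name and prompts illustrative):

import sglang as sgl

sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo-instruct"))  # illustrative model

@sgl.function
def tip(s, topic):
    s += "A tip about " + topic + ": "
    s += sgl.gen("tip", max_tokens=24)

# run_batch also accepts n after this commit; each state's "tip"
# then holds a 2-element list of completions.
states = tip.run_batch([{"topic": "sleep"}, {"topic": "focus"}], n=2)
for st in states:
    print(st["tip"])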