Support n in OpenAI API completions (#3446)
Co-authored-by: Shan Yu <shanyu1@g.ucla.edu> Co-authored-by: Yineng Zhang <me@zhyncs.com> Co-authored-by: chuyue sun <chuyue@lmsys.us-northcentral1-a.compute.internal>
This commit is contained in:
73
examples/frontend_language/quick_start/openai_example_n.py
Normal file
73
examples/frontend_language/quick_start/openai_example_n.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
export OPENAI_API_KEY=sk-******
|
||||||
|
python3 openai_example_chat.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
import sglang as sgl
|
||||||
|
|
||||||
|
|
||||||
|
@sgl.function
|
||||||
|
def multi_turn_question(s, question_1, question_2):
|
||||||
|
s += sgl.system("You are a helpful assistant.")
|
||||||
|
s += sgl.user(question_1)
|
||||||
|
s += sgl.assistant(sgl.gen("answer_1", max_tokens=1024, n=2))
|
||||||
|
s += sgl.user(question_2)
|
||||||
|
s += sgl.assistant(
|
||||||
|
sgl.gen(
|
||||||
|
"answer_2",
|
||||||
|
max_tokens=1024,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def single():
|
||||||
|
state = multi_turn_question.run(
|
||||||
|
question_1="What is the capital of the United States?",
|
||||||
|
question_2="List two local attractions.",
|
||||||
|
)
|
||||||
|
|
||||||
|
for m in state.messages():
|
||||||
|
print(m["role"], ":", m["content"])
|
||||||
|
|
||||||
|
print("\n-- answer_1 --\n", state["answer_1"])
|
||||||
|
print("\n-- answer_2 --\n", state["answer_2"])
|
||||||
|
assert isinstance(state["answer_1"], list)
|
||||||
|
assert len(state["answer_1"]) == 2
|
||||||
|
assert isinstance(state["answer_2"], str)
|
||||||
|
|
||||||
|
|
||||||
|
def batch():
|
||||||
|
states = multi_turn_question.run_batch(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"question_1": "What is the capital of the United States?",
|
||||||
|
"question_2": "List two local attractions.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"question_1": "What is the capital of France?",
|
||||||
|
"question_2": "What is the population of this city?",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
for s in states:
|
||||||
|
print(s.messages())
|
||||||
|
print("\n-- answer_1 --\n", s["answer_1"])
|
||||||
|
print("\n-- answer_2 --\n", s["answer_2"])
|
||||||
|
assert isinstance(s["answer_1"], list)
|
||||||
|
assert len(s["answer_1"]) == 2
|
||||||
|
assert isinstance(s["answer_2"], str)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sgl.set_default_backend(sgl.OpenAI("o1"))
|
||||||
|
|
||||||
|
# Run a single request
|
||||||
|
print("\n========== single ==========\n")
|
||||||
|
single()
|
||||||
|
# Run a batch of requests
|
||||||
|
print("\n========== batch ==========\n")
|
||||||
|
batch()
|
||||||
@@ -75,6 +75,7 @@ def gen(
|
|||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
max_tokens: Optional[int] = None,
|
max_tokens: Optional[int] = None,
|
||||||
min_tokens: Optional[int] = None,
|
min_tokens: Optional[int] = None,
|
||||||
|
n: Optional[int] = None,
|
||||||
stop: Optional[Union[str, List[str]]] = None,
|
stop: Optional[Union[str, List[str]]] = None,
|
||||||
stop_token_ids: Optional[List[int]] = None,
|
stop_token_ids: Optional[List[int]] = None,
|
||||||
temperature: Optional[float] = None,
|
temperature: Optional[float] = None,
|
||||||
@@ -115,6 +116,7 @@ def gen(
|
|||||||
name,
|
name,
|
||||||
max_tokens,
|
max_tokens,
|
||||||
min_tokens,
|
min_tokens,
|
||||||
|
n,
|
||||||
stop,
|
stop,
|
||||||
stop_token_ids,
|
stop_token_ids,
|
||||||
temperature,
|
temperature,
|
||||||
@@ -137,6 +139,7 @@ def gen(
|
|||||||
def gen_int(
|
def gen_int(
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
max_tokens: Optional[int] = None,
|
max_tokens: Optional[int] = None,
|
||||||
|
n: Optional[int] = None,
|
||||||
stop: Optional[Union[str, List[str]]] = None,
|
stop: Optional[Union[str, List[str]]] = None,
|
||||||
stop_token_ids: Optional[List[int]] = None,
|
stop_token_ids: Optional[List[int]] = None,
|
||||||
temperature: Optional[float] = None,
|
temperature: Optional[float] = None,
|
||||||
@@ -155,6 +158,7 @@ def gen_int(
|
|||||||
name,
|
name,
|
||||||
max_tokens,
|
max_tokens,
|
||||||
None,
|
None,
|
||||||
|
n,
|
||||||
stop,
|
stop,
|
||||||
stop_token_ids,
|
stop_token_ids,
|
||||||
temperature,
|
temperature,
|
||||||
@@ -176,6 +180,7 @@ def gen_int(
|
|||||||
def gen_string(
|
def gen_string(
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
max_tokens: Optional[int] = None,
|
max_tokens: Optional[int] = None,
|
||||||
|
n: Optional[int] = None,
|
||||||
stop: Optional[Union[str, List[str]]] = None,
|
stop: Optional[Union[str, List[str]]] = None,
|
||||||
stop_token_ids: Optional[List[int]] = None,
|
stop_token_ids: Optional[List[int]] = None,
|
||||||
temperature: Optional[float] = None,
|
temperature: Optional[float] = None,
|
||||||
@@ -194,6 +199,7 @@ def gen_string(
|
|||||||
name,
|
name,
|
||||||
max_tokens,
|
max_tokens,
|
||||||
None,
|
None,
|
||||||
|
n,
|
||||||
stop,
|
stop,
|
||||||
stop_token_ids,
|
stop_token_ids,
|
||||||
temperature,
|
temperature,
|
||||||
|
|||||||
@@ -165,6 +165,7 @@ class OpenAI(BaseBackend):
|
|||||||
kwargs.pop("max_tokens", None)
|
kwargs.pop("max_tokens", None)
|
||||||
else:
|
else:
|
||||||
kwargs.pop("max_completion_tokens", None)
|
kwargs.pop("max_completion_tokens", None)
|
||||||
|
|
||||||
comp = openai_completion(
|
comp = openai_completion(
|
||||||
client=self.client,
|
client=self.client,
|
||||||
token_usage=self.token_usage,
|
token_usage=self.token_usage,
|
||||||
@@ -173,13 +174,13 @@ class OpenAI(BaseBackend):
|
|||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
# Keep the returned list (or string) as is.
|
||||||
elif sampling_params.dtype in [str, "str", "string"]:
|
elif sampling_params.dtype in [str, "str", "string"]:
|
||||||
assert (
|
assert (
|
||||||
not self.is_chat_model
|
not self.is_chat_model
|
||||||
), "constrained type not supported on chat model"
|
), "constrained type not supported on chat model"
|
||||||
kwargs = sampling_params.to_openai_kwargs()
|
kwargs = sampling_params.to_openai_kwargs()
|
||||||
kwargs.pop("stop")
|
kwargs.pop("stop")
|
||||||
|
|
||||||
comp = openai_completion(
|
comp = openai_completion(
|
||||||
client=self.client,
|
client=self.client,
|
||||||
token_usage=self.token_usage,
|
token_usage=self.token_usage,
|
||||||
@@ -189,7 +190,11 @@ class OpenAI(BaseBackend):
|
|||||||
stop='"',
|
stop='"',
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
comp = '"' + comp + '"'
|
# Wrap each element in quotes if we have a list.
|
||||||
|
if isinstance(comp, list):
|
||||||
|
comp = ['"' + x + '"' for x in comp]
|
||||||
|
else:
|
||||||
|
comp = '"' + comp + '"'
|
||||||
elif sampling_params.dtype in [int, "int"]:
|
elif sampling_params.dtype in [int, "int"]:
|
||||||
assert (
|
assert (
|
||||||
not self.is_chat_model
|
not self.is_chat_model
|
||||||
@@ -206,6 +211,7 @@ class OpenAI(BaseBackend):
|
|||||||
stop=[" "],
|
stop=[" "],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
# Leave as a list if that's what is returned.
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown dtype: {sampling_params.dtype}")
|
raise ValueError(f"Unknown dtype: {sampling_params.dtype}")
|
||||||
|
|
||||||
@@ -254,7 +260,9 @@ class OpenAI(BaseBackend):
|
|||||||
prompt=s.messages_,
|
prompt=s.messages_,
|
||||||
**self.spec_kwargs,
|
**self.spec_kwargs,
|
||||||
)
|
)
|
||||||
if self.spec_pattern_match(comp):
|
# Use a string for pattern matching.
|
||||||
|
comp_for_match = comp[0] if isinstance(comp, list) else comp
|
||||||
|
if self.spec_pattern_match(comp_for_match):
|
||||||
break
|
break
|
||||||
|
|
||||||
for term in self.spec_format:
|
for term in self.spec_format:
|
||||||
@@ -370,7 +378,7 @@ class OpenAI(BaseBackend):
|
|||||||
|
|
||||||
def openai_completion(
|
def openai_completion(
|
||||||
client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
|
client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
|
||||||
):
|
) -> Union[str, List[str]]:
|
||||||
# if "ebnf" is in kwargs, warn and remove
|
# if "ebnf" is in kwargs, warn and remove
|
||||||
if "ebnf" in kwargs:
|
if "ebnf" in kwargs:
|
||||||
warnings.warn("EBNF is not officially supported by OpenAI endpoints. Ignoring.")
|
warnings.warn("EBNF is not officially supported by OpenAI endpoints. Ignoring.")
|
||||||
@@ -382,13 +390,18 @@ def openai_completion(
|
|||||||
if "stop" in kwargs and kwargs["stop"] is None:
|
if "stop" in kwargs and kwargs["stop"] is None:
|
||||||
kwargs.pop("stop")
|
kwargs.pop("stop")
|
||||||
ret = client.chat.completions.create(messages=prompt, **kwargs)
|
ret = client.chat.completions.create(messages=prompt, **kwargs)
|
||||||
comp = ret.choices[0].message.content
|
if len(ret.choices) == 1:
|
||||||
|
comp = ret.choices[0].message.content
|
||||||
|
else:
|
||||||
|
comp = [c.message.content for c in ret.choices]
|
||||||
else:
|
else:
|
||||||
ret = client.completions.create(prompt=prompt, **kwargs)
|
ret = client.completions.create(prompt=prompt, **kwargs)
|
||||||
if isinstance(prompt, (list, tuple)):
|
if isinstance(prompt, (list, tuple)):
|
||||||
comp = [c.text for c in ret.choices]
|
comp = [c.text for c in ret.choices]
|
||||||
else:
|
else:
|
||||||
comp = ret.choices[0].text
|
comp = ret.choices[0].text
|
||||||
|
if len(ret.choices) > 1:
|
||||||
|
comp = [c.text for c in ret.choices]
|
||||||
|
|
||||||
token_usage.prompt_tokens += ret.usage.prompt_tokens
|
token_usage.prompt_tokens += ret.usage.prompt_tokens
|
||||||
token_usage.completion_tokens += ret.usage.completion_tokens
|
token_usage.completion_tokens += ret.usage.completion_tokens
|
||||||
|
|||||||
@@ -566,13 +566,13 @@ class StreamExecutor:
|
|||||||
def _execute_gen(self, expr: SglGen):
|
def _execute_gen(self, expr: SglGen):
|
||||||
sampling_params = self._resolve_sampling_params(expr.sampling_params)
|
sampling_params = self._resolve_sampling_params(expr.sampling_params)
|
||||||
name = expr.name
|
name = expr.name
|
||||||
|
|
||||||
if not self.stream:
|
if not self.stream:
|
||||||
if self.num_api_spec_tokens is None:
|
if self.num_api_spec_tokens is None:
|
||||||
comp, meta_info = self.backend.generate(
|
comp, meta_info = self.backend.generate(
|
||||||
self,
|
self,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if self.backend.is_chat_model:
|
if self.backend.is_chat_model:
|
||||||
# Speculative execution on models with only chat interface.
|
# Speculative execution on models with only chat interface.
|
||||||
@@ -587,8 +587,11 @@ class StreamExecutor:
|
|||||||
|
|
||||||
else: # Speculative execution on models with completion interface
|
else: # Speculative execution on models with completion interface
|
||||||
comp, meta_info = self._spec_gen(sampling_params)
|
comp, meta_info = self._spec_gen(sampling_params)
|
||||||
|
if isinstance(comp, list):
|
||||||
self.text_ += comp
|
self.text_ += comp[0]
|
||||||
|
else:
|
||||||
|
assert isinstance(comp, str)
|
||||||
|
self.text_ += comp
|
||||||
|
|
||||||
self.variables[name] = comp
|
self.variables[name] = comp
|
||||||
self.meta_info[name] = meta_info
|
self.meta_info[name] = meta_info
|
||||||
@@ -747,6 +750,7 @@ class StreamExecutor:
|
|||||||
for item in [
|
for item in [
|
||||||
"max_new_tokens",
|
"max_new_tokens",
|
||||||
"min_new_tokens",
|
"min_new_tokens",
|
||||||
|
"n",
|
||||||
"stop",
|
"stop",
|
||||||
"stop_token_ids",
|
"stop_token_ids",
|
||||||
"temperature",
|
"temperature",
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ REGEX_STR = r"\"[\w\d\s]*\"" # bugs with regex r"\".*\"" in interegular pkg
|
|||||||
class SglSamplingParams:
|
class SglSamplingParams:
|
||||||
max_new_tokens: int = 128
|
max_new_tokens: int = 128
|
||||||
min_new_tokens: int = 0
|
min_new_tokens: int = 0
|
||||||
|
n: int = 1
|
||||||
stop: Union[str, List[str]] = ()
|
stop: Union[str, List[str]] = ()
|
||||||
stop_token_ids: Optional[List[int]] = ()
|
stop_token_ids: Optional[List[int]] = ()
|
||||||
temperature: float = 1.0
|
temperature: float = 1.0
|
||||||
@@ -41,6 +42,7 @@ class SglSamplingParams:
|
|||||||
return SglSamplingParams(
|
return SglSamplingParams(
|
||||||
self.max_new_tokens,
|
self.max_new_tokens,
|
||||||
self.min_new_tokens,
|
self.min_new_tokens,
|
||||||
|
self.n,
|
||||||
self.stop,
|
self.stop,
|
||||||
self.stop_token_ids,
|
self.stop_token_ids,
|
||||||
self.temperature,
|
self.temperature,
|
||||||
@@ -64,6 +66,7 @@ class SglSamplingParams:
|
|||||||
return {
|
return {
|
||||||
"max_tokens": self.max_new_tokens,
|
"max_tokens": self.max_new_tokens,
|
||||||
"max_completion_tokens": self.max_new_tokens,
|
"max_completion_tokens": self.max_new_tokens,
|
||||||
|
"n": self.n,
|
||||||
"stop": self.stop or None,
|
"stop": self.stop or None,
|
||||||
"temperature": self.temperature,
|
"temperature": self.temperature,
|
||||||
"top_p": self.top_p,
|
"top_p": self.top_p,
|
||||||
@@ -117,6 +120,7 @@ class SglSamplingParams:
|
|||||||
return {
|
return {
|
||||||
"max_new_tokens": self.max_new_tokens,
|
"max_new_tokens": self.max_new_tokens,
|
||||||
"min_new_tokens": self.min_new_tokens,
|
"min_new_tokens": self.min_new_tokens,
|
||||||
|
"n": self.n,
|
||||||
"stop": self.stop,
|
"stop": self.stop,
|
||||||
"stop_token_ids": self.stop_token_ids,
|
"stop_token_ids": self.stop_token_ids,
|
||||||
"temperature": self.temperature,
|
"temperature": self.temperature,
|
||||||
@@ -154,6 +158,7 @@ class SglFunction:
|
|||||||
self,
|
self,
|
||||||
*args,
|
*args,
|
||||||
max_new_tokens: int = 128,
|
max_new_tokens: int = 128,
|
||||||
|
n: int = 1,
|
||||||
stop: Optional[Union[str, List[str]]] = None,
|
stop: Optional[Union[str, List[str]]] = None,
|
||||||
stop_token_ids: Optional[List[int]] = None,
|
stop_token_ids: Optional[List[int]] = None,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
@@ -182,6 +187,7 @@ class SglFunction:
|
|||||||
|
|
||||||
default_sampling_para = SglSamplingParams(
|
default_sampling_para = SglSamplingParams(
|
||||||
max_new_tokens=max_new_tokens,
|
max_new_tokens=max_new_tokens,
|
||||||
|
n=n,
|
||||||
stop=stop,
|
stop=stop,
|
||||||
stop_token_ids=stop_token_ids,
|
stop_token_ids=stop_token_ids,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
@@ -212,6 +218,7 @@ class SglFunction:
|
|||||||
batch_kwargs,
|
batch_kwargs,
|
||||||
*,
|
*,
|
||||||
max_new_tokens: int = 128,
|
max_new_tokens: int = 128,
|
||||||
|
n: int = 1,
|
||||||
stop: Optional[Union[str, List[str]]] = None,
|
stop: Optional[Union[str, List[str]]] = None,
|
||||||
stop_token_ids: Optional[List[int]] = None,
|
stop_token_ids: Optional[List[int]] = None,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
@@ -257,6 +264,7 @@ class SglFunction:
|
|||||||
|
|
||||||
default_sampling_para = SglSamplingParams(
|
default_sampling_para = SglSamplingParams(
|
||||||
max_new_tokens=max_new_tokens,
|
max_new_tokens=max_new_tokens,
|
||||||
|
n=n,
|
||||||
stop=stop,
|
stop=stop,
|
||||||
stop_token_ids=stop_token_ids,
|
stop_token_ids=stop_token_ids,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
@@ -440,6 +448,7 @@ class SglGen(SglExpr):
|
|||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
max_new_tokens: Optional[int] = None,
|
max_new_tokens: Optional[int] = None,
|
||||||
min_new_tokens: Optional[int] = None,
|
min_new_tokens: Optional[int] = None,
|
||||||
|
n: Optional[int] = None,
|
||||||
stop: Optional[Union[str, List[str]]] = None,
|
stop: Optional[Union[str, List[str]]] = None,
|
||||||
stop_token_ids: Optional[List[int]] = None,
|
stop_token_ids: Optional[List[int]] = None,
|
||||||
temperature: Optional[float] = None,
|
temperature: Optional[float] = None,
|
||||||
@@ -463,6 +472,7 @@ class SglGen(SglExpr):
|
|||||||
self.sampling_params = SglSamplingParams(
|
self.sampling_params = SglSamplingParams(
|
||||||
max_new_tokens=max_new_tokens,
|
max_new_tokens=max_new_tokens,
|
||||||
min_new_tokens=min_new_tokens,
|
min_new_tokens=min_new_tokens,
|
||||||
|
n=n,
|
||||||
stop=stop,
|
stop=stop,
|
||||||
stop_token_ids=stop_token_ids,
|
stop_token_ids=stop_token_ids,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
|
|||||||
Reference in New Issue
Block a user