diff --git a/examples/frontend_language/quick_start/openai_example_n.py b/examples/frontend_language/quick_start/openai_example_n.py new file mode 100644 index 000000000..06f533003 --- /dev/null +++ b/examples/frontend_language/quick_start/openai_example_n.py @@ -0,0 +1,73 @@ +""" +Usage: +export OPENAI_API_KEY=sk-****** +python3 openai_example_n.py +""" + +import json + +import sglang as sgl + + +@sgl.function +def multi_turn_question(s, question_1, question_2): + s += sgl.system("You are a helpful assistant.") + s += sgl.user(question_1) + s += sgl.assistant(sgl.gen("answer_1", max_tokens=1024, n=2)) + s += sgl.user(question_2) + s += sgl.assistant( + sgl.gen( + "answer_2", + max_tokens=1024, + ) + ) + + +def single(): + state = multi_turn_question.run( + question_1="What is the capital of the United States?", + question_2="List two local attractions.", + ) + + for m in state.messages(): + print(m["role"], ":", m["content"]) + + print("\n-- answer_1 --\n", state["answer_1"]) + print("\n-- answer_2 --\n", state["answer_2"]) + assert isinstance(state["answer_1"], list) + assert len(state["answer_1"]) == 2 + assert isinstance(state["answer_2"], str) + + +def batch(): + states = multi_turn_question.run_batch( + [ + { + "question_1": "What is the capital of the United States?", + "question_2": "List two local attractions.", + }, + { + "question_1": "What is the capital of France?", + "question_2": "What is the population of this city?", + }, + ] + ) + + for s in states: + print(s.messages()) + print("\n-- answer_1 --\n", s["answer_1"]) + print("\n-- answer_2 --\n", s["answer_2"]) + assert isinstance(s["answer_1"], list) + assert len(s["answer_1"]) == 2 + assert isinstance(s["answer_2"], str) + + +if __name__ == "__main__": + sgl.set_default_backend(sgl.OpenAI("o1")) + + # Run a single request + print("\n========== single ==========\n") + single() + # Run a batch of requests + print("\n========== batch ==========\n") + batch() diff --git a/python/sglang/api.py 
b/python/sglang/api.py index 2bd39d5ee..50319c451 100644 --- a/python/sglang/api.py +++ b/python/sglang/api.py @@ -75,6 +75,7 @@ def gen( name: Optional[str] = None, max_tokens: Optional[int] = None, min_tokens: Optional[int] = None, + n: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, stop_token_ids: Optional[List[int]] = None, temperature: Optional[float] = None, @@ -115,6 +116,7 @@ def gen( name, max_tokens, min_tokens, + n, stop, stop_token_ids, temperature, @@ -137,6 +139,7 @@ def gen( def gen_int( name: Optional[str] = None, max_tokens: Optional[int] = None, + n: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, stop_token_ids: Optional[List[int]] = None, temperature: Optional[float] = None, @@ -155,6 +158,7 @@ def gen_int( name, max_tokens, None, + n, stop, stop_token_ids, temperature, @@ -176,6 +180,7 @@ def gen_int( def gen_string( name: Optional[str] = None, max_tokens: Optional[int] = None, + n: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, stop_token_ids: Optional[List[int]] = None, temperature: Optional[float] = None, @@ -194,6 +199,7 @@ def gen_string( name, max_tokens, None, + n, stop, stop_token_ids, temperature, diff --git a/python/sglang/lang/backend/openai.py b/python/sglang/lang/backend/openai.py index 147437622..c6e31f73d 100644 --- a/python/sglang/lang/backend/openai.py +++ b/python/sglang/lang/backend/openai.py @@ -165,6 +165,7 @@ class OpenAI(BaseBackend): kwargs.pop("max_tokens", None) else: kwargs.pop("max_completion_tokens", None) + comp = openai_completion( client=self.client, token_usage=self.token_usage, @@ -173,13 +174,13 @@ class OpenAI(BaseBackend): prompt=prompt, **kwargs, ) + # Keep the returned list (or string) as is. 
elif sampling_params.dtype in [str, "str", "string"]: assert ( not self.is_chat_model ), "constrained type not supported on chat model" kwargs = sampling_params.to_openai_kwargs() kwargs.pop("stop") - comp = openai_completion( client=self.client, token_usage=self.token_usage, @@ -189,7 +190,11 @@ class OpenAI(BaseBackend): stop='"', **kwargs, ) - comp = '"' + comp + '"' + # Wrap each element in quotes if we have a list. + if isinstance(comp, list): + comp = ['"' + x + '"' for x in comp] + else: + comp = '"' + comp + '"' elif sampling_params.dtype in [int, "int"]: assert ( not self.is_chat_model @@ -206,6 +211,7 @@ class OpenAI(BaseBackend): stop=[" "], **kwargs, ) + # Leave as a list if that's what is returned. else: raise ValueError(f"Unknown dtype: {sampling_params.dtype}") @@ -254,7 +260,9 @@ class OpenAI(BaseBackend): prompt=s.messages_, **self.spec_kwargs, ) - if self.spec_pattern_match(comp): + # Use a string for pattern matching. + comp_for_match = comp[0] if isinstance(comp, list) else comp + if self.spec_pattern_match(comp_for_match): break for term in self.spec_format: @@ -370,7 +378,7 @@ class OpenAI(BaseBackend): def openai_completion( client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs -): +) -> Union[str, List[str]]: # if "ebnf" is in kwargs, warn and remove if "ebnf" in kwargs: warnings.warn("EBNF is not officially supported by OpenAI endpoints. 
Ignoring.") @@ -382,13 +390,18 @@ def openai_completion( if "stop" in kwargs and kwargs["stop"] is None: kwargs.pop("stop") ret = client.chat.completions.create(messages=prompt, **kwargs) - comp = ret.choices[0].message.content + if len(ret.choices) == 1: + comp = ret.choices[0].message.content + else: + comp = [c.message.content for c in ret.choices] else: ret = client.completions.create(prompt=prompt, **kwargs) if isinstance(prompt, (list, tuple)): comp = [c.text for c in ret.choices] else: comp = ret.choices[0].text + if len(ret.choices) > 1: + comp = [c.text for c in ret.choices] token_usage.prompt_tokens += ret.usage.prompt_tokens token_usage.completion_tokens += ret.usage.completion_tokens diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py index 4c294781c..1ac91642d 100644 --- a/python/sglang/lang/interpreter.py +++ b/python/sglang/lang/interpreter.py @@ -566,13 +566,13 @@ class StreamExecutor: def _execute_gen(self, expr: SglGen): sampling_params = self._resolve_sampling_params(expr.sampling_params) name = expr.name - if not self.stream: if self.num_api_spec_tokens is None: comp, meta_info = self.backend.generate( self, sampling_params=sampling_params, ) + else: if self.backend.is_chat_model: # Speculative execution on models with only chat interface. 
@@ -587,8 +587,11 @@ class StreamExecutor: else: # Speculative execution on models with completion interface comp, meta_info = self._spec_gen(sampling_params) - - self.text_ += comp + if isinstance(comp, list): + self.text_ += comp[0] + else: + assert isinstance(comp, str) + self.text_ += comp self.variables[name] = comp self.meta_info[name] = meta_info @@ -747,6 +750,7 @@ class StreamExecutor: for item in [ "max_new_tokens", "min_new_tokens", + "n", "stop", "stop_token_ids", "temperature", diff --git a/python/sglang/lang/ir.py b/python/sglang/lang/ir.py index 0431d2c6b..f1775bce1 100644 --- a/python/sglang/lang/ir.py +++ b/python/sglang/lang/ir.py @@ -18,6 +18,7 @@ REGEX_STR = r"\"[\w\d\s]*\"" # bugs with regex r"\".*\"" in interegular pkg class SglSamplingParams: max_new_tokens: int = 128 min_new_tokens: int = 0 + n: int = 1 stop: Union[str, List[str]] = () stop_token_ids: Optional[List[int]] = () temperature: float = 1.0 @@ -41,6 +42,7 @@ class SglSamplingParams: return SglSamplingParams( self.max_new_tokens, self.min_new_tokens, + self.n, self.stop, self.stop_token_ids, self.temperature, @@ -64,6 +66,7 @@ class SglSamplingParams: return { "max_tokens": self.max_new_tokens, "max_completion_tokens": self.max_new_tokens, + "n": self.n, "stop": self.stop or None, "temperature": self.temperature, "top_p": self.top_p, @@ -117,6 +120,7 @@ class SglSamplingParams: return { "max_new_tokens": self.max_new_tokens, "min_new_tokens": self.min_new_tokens, + "n": self.n, "stop": self.stop, "stop_token_ids": self.stop_token_ids, "temperature": self.temperature, @@ -154,6 +158,7 @@ class SglFunction: self, *args, max_new_tokens: int = 128, + n: int = 1, stop: Optional[Union[str, List[str]]] = None, stop_token_ids: Optional[List[int]] = None, temperature: float = 1.0, @@ -182,6 +187,7 @@ class SglFunction: default_sampling_para = SglSamplingParams( max_new_tokens=max_new_tokens, + n=n, stop=stop, stop_token_ids=stop_token_ids, temperature=temperature, @@ -212,6 +218,7 @@ class 
SglFunction: batch_kwargs, *, max_new_tokens: int = 128, + n: int = 1, stop: Optional[Union[str, List[str]]] = None, stop_token_ids: Optional[List[int]] = None, temperature: float = 1.0, @@ -257,6 +264,7 @@ class SglFunction: default_sampling_para = SglSamplingParams( max_new_tokens=max_new_tokens, + n=n, stop=stop, stop_token_ids=stop_token_ids, temperature=temperature, @@ -440,6 +448,7 @@ class SglGen(SglExpr): name: Optional[str] = None, max_new_tokens: Optional[int] = None, min_new_tokens: Optional[int] = None, + n: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, stop_token_ids: Optional[List[int]] = None, temperature: Optional[float] = None, @@ -463,6 +472,7 @@ class SglGen(SglExpr): self.sampling_params = SglSamplingParams( max_new_tokens=max_new_tokens, min_new_tokens=min_new_tokens, + n=n, stop=stop, stop_token_ids=stop_token_ids, temperature=temperature,