From 6cc309557a5ef59e6f944cd257077da860c0d984 Mon Sep 17 00:00:00 2001
From: Chuyue Sun <33578456+ChuyueSun@users.noreply.github.com>
Date: Thu, 13 Feb 2025 19:43:00 -0800
Subject: [PATCH] Add support for OpenAI API o1 model (#3363)

Co-authored-by: Shan Yu
---
 .../quick_start/openai_example_o1.py          | 57 +++++++++++++++++++
 python/sglang/lang/backend/openai.py          |  5 ++
 python/sglang/lang/ir.py                      |  1 +
 3 files changed, 63 insertions(+)
 create mode 100644 examples/frontend_language/quick_start/openai_example_o1.py

diff --git a/examples/frontend_language/quick_start/openai_example_o1.py b/examples/frontend_language/quick_start/openai_example_o1.py
new file mode 100644
index 000000000..2e5c14002
--- /dev/null
+++ b/examples/frontend_language/quick_start/openai_example_o1.py
@@ -0,0 +1,57 @@
+"""
+Usage:
+export OPENAI_API_KEY=sk-******
+python3 openai_example_o1.py
+"""
+
+import sglang as sgl
+
+
+@sgl.function
+def multi_turn_question(s, question_1, question_2):
+    s += sgl.system("You are a helpful assistant.")
+    s += sgl.user(question_1)
+    s += sgl.assistant(sgl.gen("answer_1", max_tokens=100))
+    s += sgl.user(question_2)
+    s += sgl.assistant(sgl.gen("answer_2"))
+
+
+def single():
+    state = multi_turn_question.run(
+        question_1="What is the capital of the United States?",
+        question_2="List two local attractions.",
+    )
+
+    for m in state.messages():
+        print(m["role"], ":", m["content"])
+
+    print("\n-- answer_1 --\n", state["answer_1"])
+
+
+def batch():
+    states = multi_turn_question.run_batch(
+        [
+            {
+                "question_1": "What is the capital of the United States?",
+                "question_2": "List two local attractions.",
+            },
+            {
+                "question_1": "What is the capital of France?",
+                "question_2": "What is the population of this city?",
+            },
+        ]
+    )
+
+    for s in states:
+        print(s.messages())
+
+
+if __name__ == "__main__":
+    sgl.set_default_backend(sgl.OpenAI("o1"))
+
+    # Run a single request
+    print("\n========== single ==========\n")
+    single()
+    # Run a batch of requests
+    print("\n========== batch ==========\n")
+    batch()
diff --git a/python/sglang/lang/backend/openai.py b/python/sglang/lang/backend/openai.py
index 4f37da79b..147437622 100644
--- a/python/sglang/lang/backend/openai.py
+++ b/python/sglang/lang/backend/openai.py
@@ -161,6 +161,10 @@ class OpenAI(BaseBackend):
             prompt = s.text_
 
         kwargs = sampling_params.to_openai_kwargs()
+        if self.model_name.startswith("o1") or self.model_name.startswith("o3"):
+            kwargs.pop("max_tokens", None)
+        else:
+            kwargs.pop("max_completion_tokens", None)
         comp = openai_completion(
             client=self.client,
             token_usage=self.token_usage,
@@ -175,6 +179,7 @@ class OpenAI(BaseBackend):
         ), "constrained type not supported on chat model"
         kwargs = sampling_params.to_openai_kwargs()
         kwargs.pop("stop")
+
         comp = openai_completion(
             client=self.client,
             token_usage=self.token_usage,
diff --git a/python/sglang/lang/ir.py b/python/sglang/lang/ir.py
index 1ae5ac106..d3a7430a8 100644
--- a/python/sglang/lang/ir.py
+++ b/python/sglang/lang/ir.py
@@ -63,6 +63,7 @@ class SglSamplingParams:
         warnings.warn("Regular expression is not supported in the OpenAI backend.")
         return {
             "max_tokens": self.max_new_tokens,
+            "max_completion_tokens": self.max_new_tokens,
             "stop": self.stop or None,
             "temperature": self.temperature,
             "top_p": self.top_p,