From 50afed4eaafeec6c87a4f120ec95742846b4130f Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Sat, 10 Feb 2024 17:21:33 -0800 Subject: [PATCH] Support extra field regex in OpenAI API (#172) --- python/sglang/srt/managers/openai_protocol.py | 6 +++++ python/sglang/srt/server.py | 2 ++ test/srt/test_openai_server.py | 25 +++++++++++++++++++ 3 files changed, 33 insertions(+) diff --git a/python/sglang/srt/managers/openai_protocol.py b/python/sglang/srt/managers/openai_protocol.py index 320eab42b..1cf1fed73 100644 --- a/python/sglang/srt/managers/openai_protocol.py +++ b/python/sglang/srt/managers/openai_protocol.py @@ -36,6 +36,9 @@ class CompletionRequest(BaseModel): logit_bias: Optional[Dict[str, float]] = None user: Optional[str] = None + # Extra parameters for the SRT backend only; they will be ignored by OpenAI models. + regex: Optional[str] = None + class CompletionResponseChoice(BaseModel): index: int @@ -119,6 +122,9 @@ class ChatCompletionRequest(BaseModel): user: Optional[str] = None best_of: Optional[int] = None + # Extra parameters for the SRT backend only; they will be ignored by OpenAI models. 
+ regex: Optional[str] = None + class ChatMessage(BaseModel): role: Optional[str] = None diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 0a9e6d24b..e5b066769 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -151,6 +151,7 @@ async def v1_completions(raw_request: Request): "top_p": request.top_p, "presence_penalty": request.presence_penalty, "frequency_penalty": request.frequency_penalty, + "regex": request.regex, }, return_logprob=request.logprobs is not None, stream=request.stream, @@ -304,6 +305,7 @@ async def v1_chat_completions(raw_request: Request): "top_p": request.top_p, "presence_penalty": request.presence_penalty, "frequency_penalty": request.frequency_penalty, + "regex": request.regex, }, stream=request.stream, ) diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py index 01aa53e5b..2bff16960 100644 --- a/test/srt/test_openai_server.py +++ b/test/srt/test_openai_server.py @@ -14,6 +14,7 @@ The capital of Japan is Tokyo """ import argparse +import json import openai @@ -151,6 +152,29 @@ def test_chat_completion_stream(args): print() +def test_regex(args): + client = openai.Client(api_key="EMPTY", base_url=args.base_url) + + regex = (r"""\{\n""" + + r""" "name": "[\w]+",\n""" + + r""" "population": "[\w\d\s]+"\n""" + + r"""\}""" + ) + + response = client.chat.completions.create( + model="default", + messages=[ + {"role": "system", "content": "You are a helpful AI assistant"}, + {"role": "user", "content": "Introduce the capital of France."}, + ], + temperature=0, + max_tokens=128, + extra_body={"regex": regex}, + ) + text = response.choices[0].message.content + print(json.loads(text)) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--base-url", type=str, default="http://127.0.0.1:30000/v1") @@ -169,5 +193,6 @@ if __name__ == "__main__": test_completion_stream(args, echo=True, logprobs=True) test_chat_completion(args) 
test_chat_completion_stream(args) + test_regex(args) if args.test_image: test_chat_completion_image(args)