Support extra field regex in OpenAI API (#172)
This commit is contained in:
@@ -36,6 +36,9 @@ class CompletionRequest(BaseModel):
|
||||
logit_bias: Optional[Dict[str, float]] = None
|
||||
user: Optional[str] = None
|
||||
|
||||
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
|
||||
regex: Optional[str] = None
|
||||
|
||||
|
||||
class CompletionResponseChoice(BaseModel):
|
||||
index: int
|
||||
@@ -119,6 +122,9 @@ class ChatCompletionRequest(BaseModel):
|
||||
user: Optional[str] = None
|
||||
best_of: Optional[int] = None
|
||||
|
||||
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
|
||||
regex: Optional[str] = None
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: Optional[str] = None
|
||||
|
||||
@@ -151,6 +151,7 @@ async def v1_completions(raw_request: Request):
|
||||
"top_p": request.top_p,
|
||||
"presence_penalty": request.presence_penalty,
|
||||
"frequency_penalty": request.frequency_penalty,
|
||||
"regex": request.regex,
|
||||
},
|
||||
return_logprob=request.logprobs is not None,
|
||||
stream=request.stream,
|
||||
@@ -304,6 +305,7 @@ async def v1_chat_completions(raw_request: Request):
|
||||
"top_p": request.top_p,
|
||||
"presence_penalty": request.presence_penalty,
|
||||
"frequency_penalty": request.frequency_penalty,
|
||||
"regex": request.regex,
|
||||
},
|
||||
stream=request.stream,
|
||||
)
|
||||
|
||||
@@ -14,6 +14,7 @@ The capital of Japan is Tokyo
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
import openai
|
||||
|
||||
@@ -151,6 +152,29 @@ def test_chat_completion_stream(args):
|
||||
print()
|
||||
|
||||
|
||||
def test_regex(args):
|
||||
client = openai.Client(api_key="EMPTY", base_url=args.base_url)
|
||||
|
||||
regex = (r"""\{\n"""
|
||||
+ r""" "name": "[\w]+",\n"""
|
||||
+ r""" "population": "[\w\d\s]+"\n"""
|
||||
+ r"""\}"""
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="default",
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||
{"role": "user", "content": "Introduce the capital of France."},
|
||||
],
|
||||
temperature=0,
|
||||
max_tokens=128,
|
||||
extra_body={"regex": regex},
|
||||
)
|
||||
text = response.choices[0].message.content
|
||||
print(json.loads(text))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--base-url", type=str, default="http://127.0.0.1:30000/v1")
|
||||
@@ -169,5 +193,6 @@ if __name__ == "__main__":
|
||||
test_completion_stream(args, echo=True, logprobs=True)
|
||||
test_chat_completion(args)
|
||||
test_chat_completion_stream(args)
|
||||
test_regex(args)
|
||||
if args.test_image:
|
||||
test_chat_completion_image(args)
|
||||
|
||||
Reference in New Issue
Block a user