Test regex in vision api (#926)
This commit is contained in:
@@ -139,10 +139,12 @@ class TestOpenAIServer(unittest.TestCase):
|
|||||||
model=self.model,
|
model=self.model,
|
||||||
messages=[
|
messages=[
|
||||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||||
{"role": "user", "content": "What is the capital of France?"},
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "What is the capital of France? Answer in a few words.",
|
||||||
|
},
|
||||||
],
|
],
|
||||||
temperature=0,
|
temperature=0,
|
||||||
max_tokens=32,
|
|
||||||
logprobs=logprobs is not None and logprobs > 0,
|
logprobs=logprobs is not None and logprobs > 0,
|
||||||
top_logprobs=logprobs,
|
top_logprobs=logprobs,
|
||||||
n=parallel_sample_num,
|
n=parallel_sample_num,
|
||||||
@@ -178,7 +180,6 @@ class TestOpenAIServer(unittest.TestCase):
|
|||||||
{"role": "user", "content": "What is the capital of France?"},
|
{"role": "user", "content": "What is the capital of France?"},
|
||||||
],
|
],
|
||||||
temperature=0,
|
temperature=0,
|
||||||
max_tokens=32,
|
|
||||||
logprobs=logprobs is not None and logprobs > 0,
|
logprobs=logprobs is not None and logprobs > 0,
|
||||||
top_logprobs=logprobs,
|
top_logprobs=logprobs,
|
||||||
stream=True,
|
stream=True,
|
||||||
|
|||||||
@@ -46,25 +46,71 @@ class TestOpenAIVisionServer(unittest.TestCase):
|
|||||||
{
|
{
|
||||||
"type": "image_url",
|
"type": "image_url",
|
||||||
"image_url": {
|
"image_url": {
|
||||||
"url": "https://github.com/sgl-project/sglang/blob/main/assets/logo.png?raw=true"
|
"url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{"type": "text", "text": "Describe this image"},
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "Describe this image in a very short sentence.",
|
||||||
|
},
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
temperature=0,
|
temperature=0,
|
||||||
max_tokens=32,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert response.choices[0].message.role == "assistant"
|
assert response.choices[0].message.role == "assistant"
|
||||||
assert isinstance(response.choices[0].message.content, str)
|
text = response.choices[0].message.content
|
||||||
|
assert isinstance(text, str)
|
||||||
|
assert "car" in text or "taxi" in text, text
|
||||||
assert response.id
|
assert response.id
|
||||||
assert response.created
|
assert response.created
|
||||||
assert response.usage.prompt_tokens > 0
|
assert response.usage.prompt_tokens > 0
|
||||||
assert response.usage.completion_tokens > 0
|
assert response.usage.completion_tokens > 0
|
||||||
assert response.usage.total_tokens > 0
|
assert response.usage.total_tokens > 0
|
||||||
|
|
||||||
|
def test_regex(self):
|
||||||
|
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||||
|
|
||||||
|
regex = (
|
||||||
|
r"""\{\n"""
|
||||||
|
+ r""" "color": "[\w]+",\n"""
|
||||||
|
+ r""" "number_of_cars": [\d]+\n"""
|
||||||
|
+ r"""\}"""
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model="default",
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "Describe this image in the JSON format.",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
temperature=0,
|
||||||
|
extra_body={"regex": regex},
|
||||||
|
)
|
||||||
|
text = response.choices[0].message.content
|
||||||
|
|
||||||
|
try:
|
||||||
|
js_obj = json.loads(text)
|
||||||
|
except (TypeError, json.decoder.JSONDecodeError):
|
||||||
|
print("JSONDecodeError", text)
|
||||||
|
raise
|
||||||
|
assert isinstance(js_obj["color"], str)
|
||||||
|
assert isinstance(js_obj["number_of_cars"], int)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main(warnings="ignore")
|
unittest.main(warnings="ignore")
|
||||||
|
|||||||
Reference in New Issue
Block a user