Test regex in vision api (#926)
This commit is contained in:
@@ -139,10 +139,12 @@ class TestOpenAIServer(unittest.TestCase):
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||
{"role": "user", "content": "What is the capital of France?"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the capital of France? Answer in a few words.",
|
||||
},
|
||||
],
|
||||
temperature=0,
|
||||
max_tokens=32,
|
||||
logprobs=logprobs is not None and logprobs > 0,
|
||||
top_logprobs=logprobs,
|
||||
n=parallel_sample_num,
|
||||
@@ -178,7 +180,6 @@ class TestOpenAIServer(unittest.TestCase):
|
||||
{"role": "user", "content": "What is the capital of France?"},
|
||||
],
|
||||
temperature=0,
|
||||
max_tokens=32,
|
||||
logprobs=logprobs is not None and logprobs > 0,
|
||||
top_logprobs=logprobs,
|
||||
stream=True,
|
||||
|
||||
@@ -46,25 +46,71 @@ class TestOpenAIVisionServer(unittest.TestCase):
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://github.com/sgl-project/sglang/blob/main/assets/logo.png?raw=true"
|
||||
"url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
|
||||
},
|
||||
},
|
||||
{"type": "text", "text": "Describe this image"},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Describe this image in a very short sentence.",
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
temperature=0,
|
||||
max_tokens=32,
|
||||
)
|
||||
|
||||
assert response.choices[0].message.role == "assistant"
|
||||
assert isinstance(response.choices[0].message.content, str)
|
||||
text = response.choices[0].message.content
|
||||
assert isinstance(text, str)
|
||||
assert "car" in text or "taxi" in text, text
|
||||
assert response.id
|
||||
assert response.created
|
||||
assert response.usage.prompt_tokens > 0
|
||||
assert response.usage.completion_tokens > 0
|
||||
assert response.usage.total_tokens > 0
|
||||
|
||||
def test_regex(self):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
regex = (
|
||||
r"""\{\n"""
|
||||
+ r""" "color": "[\w]+",\n"""
|
||||
+ r""" "number_of_cars": [\d]+\n"""
|
||||
+ r"""\}"""
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="default",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Describe this image in the JSON format.",
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
temperature=0,
|
||||
extra_body={"regex": regex},
|
||||
)
|
||||
text = response.choices[0].message.content
|
||||
|
||||
try:
|
||||
js_obj = json.loads(text)
|
||||
except (TypeError, json.decoder.JSONDecodeError):
|
||||
print("JSONDecodeError", text)
|
||||
raise
|
||||
assert isinstance(js_obj["color"], str)
|
||||
assert isinstance(js_obj["number_of_cars"], int)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(warnings="ignore")
|
||||
|
||||
Reference in New Issue
Block a user