refactor(test): reorganize OpenAI test file structure (#7408)
This commit is contained in:
0
test/srt/openai_server/__init__.py
Normal file
0
test/srt/openai_server/__init__.py
Normal file
0
test/srt/openai_server/basic/__init__.py
Normal file
0
test/srt/openai_server/basic/__init__.py
Normal file
97
test/srt/openai_server/basic/test_openai_embedding.py
Normal file
97
test/srt/openai_server/basic/test_openai_embedding.py
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
import openai
|
||||||
|
|
||||||
|
from sglang.srt.utils import kill_process_tree
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST,
|
||||||
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
CustomTestCase,
|
||||||
|
popen_launch_server,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestOpenAIEmbedding(CustomTestCase):
    """Tests for the embeddings endpoint, driven through the ``openai`` client.

    One server process is launched per class with ``--is-embedding`` and
    ``--enable-metrics``; every test sends requests to it using
    ``cls.base_url`` (suffixed with ``/v1``) and ``cls.api_key``.
    """

    @classmethod
    def setUpClass(cls):
        """Launch the embedding server shared by all tests in this class."""
        cls.model = DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"

        # Configure embedding-specific args
        other_args = ["--is-embedding", "--enable-metrics"]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=other_args,
        )
        # The server is launched on the bare URL; the tests hand the
        # "/v1"-suffixed URL to openai.Client as its base_url.
        cls.base_url += "/v1"

    @classmethod
    def tearDownClass(cls):
        """Kill the server process (and its children) started in setUpClass."""
        kill_process_tree(cls.process.pid)

    def _client(self):
        """Return an ``openai.Client`` pointed at the test server."""
        return openai.Client(api_key=self.api_key, base_url=self.base_url)

    def _assert_embeddings(self, response, expected_count, msg=None):
        """Assert ``response`` holds ``expected_count`` non-empty embeddings.

        ``assertGreater`` is used instead of ``assertTrue(len(...) > 0)`` so a
        failure reports the offending length rather than "False is not true".
        """
        self.assertEqual(len(response.data), expected_count, msg)
        for item in response.data:
            self.assertGreater(len(item.embedding), 0, msg)

    def test_embedding_single(self):
        """Test single embedding request"""
        response = self._client().embeddings.create(
            model=self.model, input="Hello world"
        )
        self._assert_embeddings(response, 1)

    def test_embedding_batch(self):
        """Test batch embedding request"""
        response = self._client().embeddings.create(
            model=self.model, input=["Hello world", "Test text"]
        )
        self._assert_embeddings(response, 2)

    def test_embedding_single_batch_str(self):
        """Test embedding with a List[str] whose length equals 1"""
        response = self._client().embeddings.create(
            model=self.model, input=["Hello world"]
        )
        self._assert_embeddings(response, 1)

    def test_embedding_single_int_list(self):
        """Test embedding with a List[int] or List[List[int]]"""
        token_ids = [15339, 314, 703, 284, 612, 262, 10658, 10188, 286, 2061]
        # Same order as before: nested list first, then the flat list. Each
        # form must yield exactly one embedding.
        for token_input in ([token_ids], token_ids):
            response = self._client().embeddings.create(
                model=self.model,
                input=token_input,
            )
            self._assert_embeddings(response, 1, msg=f"input={token_input!r}")

    def test_empty_string_embedding(self):
        """Test embedding an empty string."""
        client = self._client()

        # Text embedding example with empty string
        text = ""
        # Expect a BadRequestError for empty input
        with self.assertRaises(openai.BadRequestError) as cm:
            client.embeddings.create(
                model=self.model,
                input=text,
            )
        # check the status code
        self.assertEqual(cm.exception.status_code, 400)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
@@ -1,14 +1,12 @@
|
|||||||
"""
|
"""
|
||||||
python3 -m unittest test_openai_server.TestOpenAIServer.test_batch
|
python3 -m unittest openai_server.basic.test_openai_server.TestOpenAIServer.test_completion
|
||||||
python3 -m unittest test_openai_server.TestOpenAIServer.test_completion
|
python3 -m unittest openai_server.basic.test_openai_server.TestOpenAIServer.test_completion_stream
|
||||||
python3 -m unittest test_openai_server.TestOpenAIServer.test_completion_stream
|
python3 -m unittest openai_server.basic.test_openai_server.TestOpenAIServer.test_chat_completion
|
||||||
python3 -m unittest test_openai_server.TestOpenAIServer.test_chat_completion
|
python3 -m unittest openai_server.basic.test_openai_server.TestOpenAIServer.test_chat_completion_stream
|
||||||
python3 -m unittest test_openai_server.TestOpenAIServer.test_chat_completion_stream
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
import time
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -20,7 +18,6 @@ from sglang.srt.utils import kill_process_tree
|
|||||||
from sglang.test.runners import TEST_RERANK_QUERY_DOCS
|
from sglang.test.runners import TEST_RERANK_QUERY_DOCS
|
||||||
from sglang.test.test_utils import (
|
from sglang.test.test_utils import (
|
||||||
DEFAULT_SMALL_CROSS_ENCODER_MODEL_NAME_FOR_TEST,
|
DEFAULT_SMALL_CROSS_ENCODER_MODEL_NAME_FOR_TEST,
|
||||||
DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST,
|
|
||||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
DEFAULT_URL_FOR_TEST,
|
DEFAULT_URL_FOR_TEST,
|
||||||
@@ -508,87 +505,6 @@ class TestOpenAIServerEBNF(CustomTestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestOpenAIEmbedding(CustomTestCase):
|
|
||||||
@classmethod
|
|
||||||
def setUpClass(cls):
|
|
||||||
cls.model = DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST
|
|
||||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
|
||||||
cls.api_key = "sk-123456"
|
|
||||||
|
|
||||||
# Configure embedding-specific args
|
|
||||||
other_args = ["--is-embedding", "--enable-metrics"]
|
|
||||||
cls.process = popen_launch_server(
|
|
||||||
cls.model,
|
|
||||||
cls.base_url,
|
|
||||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
|
||||||
api_key=cls.api_key,
|
|
||||||
other_args=other_args,
|
|
||||||
)
|
|
||||||
cls.base_url += "/v1"
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def tearDownClass(cls):
|
|
||||||
kill_process_tree(cls.process.pid)
|
|
||||||
|
|
||||||
def test_embedding_single(self):
|
|
||||||
"""Test single embedding request"""
|
|
||||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
|
||||||
response = client.embeddings.create(model=self.model, input="Hello world")
|
|
||||||
self.assertEqual(len(response.data), 1)
|
|
||||||
self.assertTrue(len(response.data[0].embedding) > 0)
|
|
||||||
|
|
||||||
def test_embedding_batch(self):
|
|
||||||
"""Test batch embedding request"""
|
|
||||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
|
||||||
response = client.embeddings.create(
|
|
||||||
model=self.model, input=["Hello world", "Test text"]
|
|
||||||
)
|
|
||||||
self.assertEqual(len(response.data), 2)
|
|
||||||
self.assertTrue(len(response.data[0].embedding) > 0)
|
|
||||||
self.assertTrue(len(response.data[1].embedding) > 0)
|
|
||||||
|
|
||||||
def test_embedding_single_batch_str(self):
|
|
||||||
"""Test embedding with a List[str] and length equals to 1"""
|
|
||||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
|
||||||
response = client.embeddings.create(model=self.model, input=["Hello world"])
|
|
||||||
self.assertEqual(len(response.data), 1)
|
|
||||||
self.assertTrue(len(response.data[0].embedding) > 0)
|
|
||||||
|
|
||||||
def test_embedding_single_int_list(self):
|
|
||||||
"""Test embedding with a List[int] or List[List[int]]]"""
|
|
||||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
|
||||||
response = client.embeddings.create(
|
|
||||||
model=self.model,
|
|
||||||
input=[[15339, 314, 703, 284, 612, 262, 10658, 10188, 286, 2061]],
|
|
||||||
)
|
|
||||||
self.assertEqual(len(response.data), 1)
|
|
||||||
self.assertTrue(len(response.data[0].embedding) > 0)
|
|
||||||
|
|
||||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
|
||||||
response = client.embeddings.create(
|
|
||||||
model=self.model,
|
|
||||||
input=[15339, 314, 703, 284, 612, 262, 10658, 10188, 286, 2061],
|
|
||||||
)
|
|
||||||
self.assertEqual(len(response.data), 1)
|
|
||||||
self.assertTrue(len(response.data[0].embedding) > 0)
|
|
||||||
|
|
||||||
def test_empty_string_embedding(self):
|
|
||||||
"""Test embedding an empty string."""
|
|
||||||
|
|
||||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
|
||||||
|
|
||||||
# Text embedding example with empty string
|
|
||||||
text = ""
|
|
||||||
# Expect a BadRequestError for empty input
|
|
||||||
with self.assertRaises(openai.BadRequestError) as cm:
|
|
||||||
client.embeddings.create(
|
|
||||||
model=self.model,
|
|
||||||
input=text,
|
|
||||||
)
|
|
||||||
# check the status code
|
|
||||||
self.assertEqual(cm.exception.status_code, 400)
|
|
||||||
|
|
||||||
|
|
||||||
class TestOpenAIV1Rerank(CustomTestCase):
|
class TestOpenAIV1Rerank(CustomTestCase):
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
@@ -660,79 +576,6 @@ class TestOpenAIV1Rerank(CustomTestCase):
|
|||||||
self.assertTrue(isinstance(response[1]["index"], int))
|
self.assertTrue(isinstance(response[1]["index"], int))
|
||||||
|
|
||||||
|
|
||||||
class TestOpenAIServerIgnoreEOS(CustomTestCase):
|
|
||||||
@classmethod
|
|
||||||
def setUpClass(cls):
|
|
||||||
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
|
||||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
|
||||||
cls.api_key = "sk-123456"
|
|
||||||
cls.process = popen_launch_server(
|
|
||||||
cls.model,
|
|
||||||
cls.base_url,
|
|
||||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
|
||||||
api_key=cls.api_key,
|
|
||||||
)
|
|
||||||
cls.base_url += "/v1"
|
|
||||||
cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def tearDownClass(cls):
|
|
||||||
kill_process_tree(cls.process.pid)
|
|
||||||
|
|
||||||
def test_ignore_eos(self):
|
|
||||||
"""
|
|
||||||
Test that ignore_eos=True allows generation to continue beyond EOS token
|
|
||||||
and reach the max_tokens limit.
|
|
||||||
"""
|
|
||||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
|
||||||
|
|
||||||
max_tokens = 200
|
|
||||||
|
|
||||||
response_default = client.chat.completions.create(
|
|
||||||
model=self.model,
|
|
||||||
messages=[
|
|
||||||
{"role": "system", "content": "You are a helpful assistant."},
|
|
||||||
{"role": "user", "content": "Count from 1 to 20."},
|
|
||||||
],
|
|
||||||
temperature=0,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
extra_body={"ignore_eos": False},
|
|
||||||
)
|
|
||||||
|
|
||||||
response_ignore_eos = client.chat.completions.create(
|
|
||||||
model=self.model,
|
|
||||||
messages=[
|
|
||||||
{"role": "system", "content": "You are a helpful assistant."},
|
|
||||||
{"role": "user", "content": "Count from 1 to 20."},
|
|
||||||
],
|
|
||||||
temperature=0,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
extra_body={"ignore_eos": True},
|
|
||||||
)
|
|
||||||
|
|
||||||
default_tokens = len(
|
|
||||||
self.tokenizer.encode(response_default.choices[0].message.content)
|
|
||||||
)
|
|
||||||
ignore_eos_tokens = len(
|
|
||||||
self.tokenizer.encode(response_ignore_eos.choices[0].message.content)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check if ignore_eos resulted in more tokens or exactly max_tokens
|
|
||||||
# The ignore_eos response should either:
|
|
||||||
# 1. Have more tokens than the default response (if default stopped at EOS before max_tokens)
|
|
||||||
# 2. Have exactly max_tokens (if it reached the max_tokens limit)
|
|
||||||
self.assertTrue(
|
|
||||||
ignore_eos_tokens > default_tokens or ignore_eos_tokens >= max_tokens,
|
|
||||||
f"ignore_eos did not generate more tokens: {ignore_eos_tokens} vs {default_tokens}",
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertEqual(
|
|
||||||
response_ignore_eos.choices[0].finish_reason,
|
|
||||||
"length",
|
|
||||||
f"Expected finish_reason='length' for ignore_eos=True, got {response_ignore_eos.choices[0].finish_reason}",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestOpenAIV1Score(CustomTestCase):
|
class TestOpenAIV1Score(CustomTestCase):
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
@@ -1,8 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Unit tests for the OpenAIServingEmbedding class from serving_embedding.py.
|
Unit tests for the OpenAIServingEmbedding class from serving_embedding.py.
|
||||||
|
|
||||||
These tests ensure that the embedding serving implementation maintains compatibility
|
|
||||||
with the original adapter.py functionality and follows OpenAI API specifications.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
0
test/srt/openai_server/features/__init__.py
Normal file
0
test/srt/openai_server/features/__init__.py
Normal file
@@ -97,7 +97,7 @@ class TestCacheReport(CustomTestCase):
|
|||||||
)
|
)
|
||||||
first_cached_tokens = int(response.usage.prompt_tokens_details.cached_tokens)
|
first_cached_tokens = int(response.usage.prompt_tokens_details.cached_tokens)
|
||||||
# assert int(response.usage.cached_tokens) == 0
|
# assert int(response.usage.cached_tokens) == 0
|
||||||
assert first_cached_tokens < self.min_cached
|
assert first_cached_tokens <= self.min_cached
|
||||||
response = self.run_openai(message)
|
response = self.run_openai(message)
|
||||||
cached_tokens = int(response.usage.prompt_tokens_details.cached_tokens)
|
cached_tokens = int(response.usage.prompt_tokens_details.cached_tokens)
|
||||||
print(f"openai second request cached_tokens: {cached_tokens}")
|
print(f"openai second request cached_tokens: {cached_tokens}")
|
||||||
@@ -1,9 +1,9 @@
|
|||||||
"""
|
"""
|
||||||
Usage:
|
Usage:
|
||||||
python3 -m unittest test_enable_thinking.TestEnableThinking.test_chat_completion_with_reasoning
|
python3 -m unittest openai_server.features.test_enable_thinking.TestEnableThinking.test_chat_completion_with_reasoning
|
||||||
python3 -m unittest test_enable_thinking.TestEnableThinking.test_chat_completion_without_reasoning
|
python3 -m unittest openai_server.features.test_enable_thinking.TestEnableThinking.test_chat_completion_without_reasoning
|
||||||
python3 -m unittest test_enable_thinking.TestEnableThinking.test_stream_chat_completion_with_reasoning
|
python3 -m unittest openai_server.features.test_enable_thinking.TestEnableThinking.test_stream_chat_completion_with_reasoning
|
||||||
python3 -m unittest test_enable_thinking.TestEnableThinking.test_stream_chat_completion_without_reasoning
|
python3 -m unittest openai_server.features.test_enable_thinking.TestEnableThinking.test_stream_chat_completion_without_reasoning
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
@@ -13,8 +13,10 @@ import sys
|
|||||||
import time
|
import time
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import openai
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||||
from sglang.srt.utils import kill_process_tree
|
from sglang.srt.utils import kill_process_tree
|
||||||
from sglang.test.test_utils import (
|
from sglang.test.test_utils import (
|
||||||
DEFAULT_ENABLE_THINKING_MODEL_NAME_FOR_TEST,
|
DEFAULT_ENABLE_THINKING_MODEL_NAME_FOR_TEST,
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
python3 -m unittest test_json_constrained.TestJSONConstrainedOutlinesBackend.test_json_generate
|
python3 -m unittest openai_server.features.test_json_constrained.TestJSONConstrainedOutlinesBackend.test_json_generate
|
||||||
python3 -m unittest test_json_constrained.TestJSONConstrainedXGrammarBackend.test_json_generate
|
python3 -m unittest openai_server.features.test_json_constrained.TestJSONConstrainedXGrammarBackend.test_json_generate
|
||||||
python3 -m unittest test_json_constrained.TestJSONConstrainedLLGuidanceBackend.test_json_generate
|
python3 -m unittest openai_server.features.test_json_constrained.TestJSONConstrainedLLGuidanceBackend.test_json_generate
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
@@ -1,12 +1,12 @@
|
|||||||
"""
|
"""
|
||||||
python3 -m unittest test_json_mode.TestJSONModeOutlines.test_json_mode_response
|
python3 -m unittest openai_server.features.test_json_mode.TestJSONModeOutlines.test_json_mode_response
|
||||||
python3 -m unittest test_json_mode.TestJSONModeOutlines.test_json_mode_with_streaming
|
python3 -m unittest openai_server.features.test_json_mode.TestJSONModeOutlines.test_json_mode_with_streaming
|
||||||
|
|
||||||
python3 -m unittest test_json_mode.TestJSONModeXGrammar.test_json_mode_response
|
python3 -m unittest openai_server.features.test_json_mode.TestJSONModeXGrammar.test_json_mode_response
|
||||||
python3 -m unittest test_json_mode.TestJSONModeXGrammar.test_json_mode_with_streaming
|
python3 -m unittest openai_server.features.test_json_mode.TestJSONModeXGrammar.test_json_mode_with_streaming
|
||||||
|
|
||||||
python3 -m unittest test_json_mode.TestJSONModeLLGuidance.test_json_mode_response
|
python3 -m unittest openai_server.features.test_json_mode.TestJSONModeLLGuidance.test_json_mode_response
|
||||||
python3 -m unittest test_json_mode.TestJSONModeLLGuidance.test_json_mode_with_streaming
|
python3 -m unittest openai_server.features.test_json_mode.TestJSONModeLLGuidance.test_json_mode_with_streaming
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
98
test/srt/openai_server/features/test_openai_server_ebnf.py
Normal file
98
test/srt/openai_server/features/test_openai_server_ebnf.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
import re
|
||||||
|
|
||||||
|
import openai
|
||||||
|
|
||||||
|
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||||
|
from sglang.srt.utils import kill_process_tree
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||||
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
CustomTestCase,
|
||||||
|
popen_launch_server,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
# EBNF Test Class: TestOpenAIServerEBNF
|
||||||
|
# Launches the server with xgrammar, has only EBNF tests
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
class TestOpenAIServerEBNF(CustomTestCase):
    """EBNF-constrained chat completion tests against a local server.

    The server is launched once per class with ``--grammar-backend xgrammar``;
    each test passes an ``ebnf`` grammar through ``extra_body`` and checks the
    returned text against a regex.
    """

    @classmethod
    def setUpClass(cls):
        """Load the tokenizer, then launch the xgrammar-backed server."""
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"

        # Load the tokenizer BEFORE starting the server: unittest does not run
        # tearDownClass when setUpClass raises, so a tokenizer failure after
        # the launch would leave the server process running.
        cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)

        # passing xgrammar specifically
        other_args = ["--grammar-backend", "xgrammar"]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=other_args,
        )
        # Tests hand the "/v1"-suffixed URL to openai.Client as base_url.
        cls.base_url += "/v1"

    @classmethod
    def tearDownClass(cls):
        """Kill the server process (and its children) started in setUpClass."""
        kill_process_tree(cls.process.pid)

    def test_ebnf(self):
        """
        Ensure we can pass `ebnf` to the local openai server
        and that it enforces the grammar.
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        ebnf_grammar = r"""
        root ::= "Hello" | "Hi" | "Hey"
        """
        pattern = re.compile(r"^(Hello|Hi|Hey)[.!?]*\s*$")

        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful EBNF test bot."},
                {"role": "user", "content": "Say a greeting (Hello, Hi, or Hey)."},
            ],
            temperature=0,
            max_tokens=32,
            extra_body={"ebnf": ebnf_grammar},
        )
        text = response.choices[0].message.content.strip()
        self.assertGreater(len(text), 0, "Got empty text from EBNF generation")
        self.assertRegex(text, pattern, f"Text '{text}' doesn't match EBNF choices")

    def test_ebnf_strict_json(self):
        """
        A stricter EBNF that produces exactly {"name":"Alice"} format
        with no trailing punctuation or extra fields.
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        ebnf_grammar = r"""
        root ::= "{" pair "}"
        pair ::= "\"name\"" ":" string
        string ::= "\"" [A-Za-z]+ "\""
        """
        pattern = re.compile(r'^\{"name":"[A-Za-z]+"\}$')

        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "EBNF mini-JSON generator."},
                {
                    "role": "user",
                    "content": "Generate single key JSON with only letters.",
                },
            ],
            temperature=0,
            max_tokens=64,
            extra_body={"ebnf": ebnf_grammar},
        )
        text = response.choices[0].message.content.strip()
        self.assertGreater(len(text), 0, "Got empty text from EBNF strict JSON test")
        self.assertRegex(
            text, pattern, f"Text '{text}' not matching the EBNF strict JSON shape"
        )


# Without an entry point, executing this file directly only defines the class
# and runs no tests. The import is local because this module does not import
# unittest at the top.
if __name__ == "__main__":
    import unittest

    unittest.main()
|
||||||
@@ -1,12 +1,12 @@
|
|||||||
"""
|
"""
|
||||||
Usage:
|
Usage:
|
||||||
python3 -m unittest test_reasoning_content.TestReasoningContentAPI.test_streaming_separate_reasoning_false
|
python3 -m unittest openai_server.features.test_reasoning_content.TestReasoningContentAPI.test_streaming_separate_reasoning_false
|
||||||
python3 -m unittest test_reasoning_content.TestReasoningContentAPI.test_streaming_separate_reasoning_true
|
python3 -m unittest openai_server.features.test_reasoning_content.TestReasoningContentAPI.test_streaming_separate_reasoning_true
|
||||||
python3 -m unittest test_reasoning_content.TestReasoningContentAPI.test_streaming_separate_reasoning_true_stream_reasoning_false
|
python3 -m unittest openai_server.features.test_reasoning_content.TestReasoningContentAPI.test_streaming_separate_reasoning_true_stream_reasoning_false
|
||||||
python3 -m unittest test_reasoning_content.TestReasoningContentAPI.test_nonstreaming_separate_reasoning_false
|
python3 -m unittest openai_server.features.test_reasoning_content.TestReasoningContentAPI.test_nonstreaming_separate_reasoning_false
|
||||||
python3 -m unittest test_reasoning_content.TestReasoningContentAPI.test_nonstreaming_separate_reasoning_true
|
python3 -m unittest openai_server.features.test_reasoning_content.TestReasoningContentAPI.test_nonstreaming_separate_reasoning_true
|
||||||
python3 -m unittest test_reasoning_content.TestReasoningContentStartup.test_nonstreaming
|
python3 -m unittest openai_server.features.test_reasoning_content.TestReasoningContentStartup.test_nonstreaming
|
||||||
python3 -m unittest test_reasoning_content.TestReasoningContentStartup.test_streaming
|
python3 -m unittest openai_server.features.test_reasoning_content.TestReasoningContentStartup.test_streaming
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
0
test/srt/openai_server/function_call/__init__.py
Normal file
0
test/srt/openai_server/function_call/__init__.py
Normal file
@@ -2,9 +2,12 @@
|
|||||||
Test script for tool_choice functionality in SGLang
|
Test script for tool_choice functionality in SGLang
|
||||||
Tests: required, auto, and specific function choices in both streaming and non-streaming modes
|
Tests: required, auto, and specific function choices in both streaming and non-streaming modes
|
||||||
|
|
||||||
python3 -m unittest test_tool_choice.TestToolChoice
|
# To run the tests, use the following command:
|
||||||
|
#
|
||||||
|
# python3 -m unittest openai_server.function_call.test_tool_choice
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
0
test/srt/openai_server/validation/__init__.py
Normal file
0
test/srt/openai_server/validation/__init__.py
Normal file
@@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
python3 -m unittest test_large_max_new_tokens.TestLargeMaxNewTokens.test_chat_completion
|
python3 -m unittest openai_server.validation.test_large_max_new_tokens.TestLargeMaxNewTokens.test_chat_completion
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
@@ -0,0 +1,84 @@
|
|||||||
|
import openai
|
||||||
|
|
||||||
|
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||||
|
from sglang.srt.utils import kill_process_tree
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||||
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
CustomTestCase,
|
||||||
|
popen_launch_server,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestOpenAIServerIgnoreEOS(CustomTestCase):
    """Checks the ``ignore_eos`` sampling option on chat completions.

    One server is launched per class; the test sends the same prompt with
    ``ignore_eos`` off and on and compares the token counts of the replies.
    """

    @classmethod
    def setUpClass(cls):
        """Load the tokenizer, then launch the server shared by the tests."""
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"

        # Load the tokenizer BEFORE starting the server: unittest does not run
        # tearDownClass when setUpClass raises, so a tokenizer failure after
        # the launch would leave the server process running.
        cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)

        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
        )
        # Tests hand the "/v1"-suffixed URL to openai.Client as base_url.
        cls.base_url += "/v1"

    @classmethod
    def tearDownClass(cls):
        """Kill the server process (and its children) started in setUpClass."""
        kill_process_tree(cls.process.pid)

    def test_ignore_eos(self):
        """
        Test that ignore_eos=True allows generation to continue beyond EOS token
        and reach the max_tokens limit.
        """
        client = openai.Client(api_key=self.api_key, base_url=self.base_url)

        max_tokens = 200
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Count from 1 to 20."},
        ]

        def complete(ignore_eos):
            # Both requests are identical except for the ignore_eos flag.
            return client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=0,
                max_tokens=max_tokens,
                extra_body={"ignore_eos": ignore_eos},
            )

        response_default = complete(False)
        response_ignore_eos = complete(True)

        # Token counts are measured by re-encoding the returned text.
        default_tokens = len(
            self.tokenizer.encode(response_default.choices[0].message.content)
        )
        ignore_eos_tokens = len(
            self.tokenizer.encode(response_ignore_eos.choices[0].message.content)
        )

        # Check if ignore_eos resulted in more tokens or reached max_tokens
        # The ignore_eos response should either:
        # 1. Have more tokens than the default response (if default stopped at EOS before max_tokens)
        # 2. Have at least max_tokens (if it reached the max_tokens limit)
        self.assertTrue(
            ignore_eos_tokens > default_tokens or ignore_eos_tokens >= max_tokens,
            f"ignore_eos did not generate more tokens: {ignore_eos_tokens} vs {default_tokens}",
        )

        self.assertEqual(
            response_ignore_eos.choices[0].finish_reason,
            "length",
            f"Expected finish_reason='length' for ignore_eos=True, got {response_ignore_eos.choices[0].finish_reason}",
        )


# Without an entry point, executing this file directly only defines the class
# and runs no tests. The import is local because this module does not import
# unittest at the top.
if __name__ == "__main__":
    import unittest

    unittest.main()
|
||||||
@@ -29,10 +29,25 @@ suites = {
|
|||||||
TestFile("models/test_reward_models.py", 132),
|
TestFile("models/test_reward_models.py", 132),
|
||||||
TestFile("models/test_vlm_models.py", 437),
|
TestFile("models/test_vlm_models.py", 437),
|
||||||
TestFile("models/test_transformers_models.py", 320),
|
TestFile("models/test_transformers_models.py", 320),
|
||||||
TestFile("openai/test_protocol.py", 10),
|
TestFile("openai_server/basic/test_protocol.py", 10),
|
||||||
TestFile("openai/test_serving_chat.py", 10),
|
TestFile("openai_server/basic/test_serving_chat.py", 10),
|
||||||
TestFile("openai/test_serving_completions.py", 10),
|
TestFile("openai_server/basic/test_serving_completions.py", 10),
|
||||||
TestFile("openai/test_serving_embedding.py", 10),
|
TestFile("openai_server/basic/test_serving_embedding.py", 10),
|
||||||
|
TestFile("openai_server/basic/test_openai_embedding.py", 141),
|
||||||
|
TestFile("openai_server/basic/test_openai_server.py", 149),
|
||||||
|
TestFile("openai_server/features/test_cache_report.py", 100),
|
||||||
|
TestFile("openai_server/features/test_enable_thinking.py", 70),
|
||||||
|
TestFile("openai_server/features/test_json_constrained.py", 98),
|
||||||
|
TestFile("openai_server/features/test_json_mode.py", 90),
|
||||||
|
TestFile("openai_server/features/test_openai_server_ebnf.py", 95),
|
||||||
|
TestFile("openai_server/features/test_openai_server_hidden_states.py", 240),
|
||||||
|
TestFile("openai_server/features/test_reasoning_content.py", 89),
|
||||||
|
TestFile("openai_server/function_call/test_openai_function_calling.py", 60),
|
||||||
|
TestFile("openai_server/function_call/test_tool_choice.py", 226),
|
||||||
|
TestFile("openai_server/validation/test_large_max_new_tokens.py", 41),
|
||||||
|
TestFile("openai_server/validation/test_matched_stop.py", 60),
|
||||||
|
TestFile("openai_server/validation/test_openai_server_ignore_eos.py", 85),
|
||||||
|
TestFile("openai_server/validation/test_request_length_validation.py", 31),
|
||||||
TestFile("test_abort.py", 51),
|
TestFile("test_abort.py", 51),
|
||||||
TestFile("test_block_int8.py", 22),
|
TestFile("test_block_int8.py", 22),
|
||||||
TestFile("test_create_kvindices.py", 2),
|
TestFile("test_create_kvindices.py", 2),
|
||||||
@@ -40,8 +55,6 @@ suites = {
|
|||||||
TestFile("test_eagle_infer_a.py", 370),
|
TestFile("test_eagle_infer_a.py", 370),
|
||||||
TestFile("test_eagle_infer_b.py", 270),
|
TestFile("test_eagle_infer_b.py", 270),
|
||||||
TestFile("test_ebnf_constrained.py", 108),
|
TestFile("test_ebnf_constrained.py", 108),
|
||||||
TestFile("test_enable_thinking.py", 70),
|
|
||||||
TestFile("test_embedding_openai_server.py", 141),
|
|
||||||
TestFile("test_eval_fp8_accuracy.py", 303),
|
TestFile("test_eval_fp8_accuracy.py", 303),
|
||||||
TestFile("test_fa3.py", 376),
|
TestFile("test_fa3.py", 376),
|
||||||
# TestFile("test_flashmla.py", 352),
|
# TestFile("test_flashmla.py", 352),
|
||||||
@@ -54,8 +67,6 @@ suites = {
|
|||||||
TestFile("test_int8_kernel.py", 8),
|
TestFile("test_int8_kernel.py", 8),
|
||||||
TestFile("test_input_embeddings.py", 38),
|
TestFile("test_input_embeddings.py", 38),
|
||||||
TestFile("test_jinja_template_utils.py", 1),
|
TestFile("test_jinja_template_utils.py", 1),
|
||||||
TestFile("test_json_constrained.py", 98),
|
|
||||||
TestFile("test_large_max_new_tokens.py", 41),
|
|
||||||
TestFile("test_metrics.py", 32),
|
TestFile("test_metrics.py", 32),
|
||||||
TestFile("test_mla.py", 167),
|
TestFile("test_mla.py", 167),
|
||||||
TestFile("test_mla_deepseek_v3.py", 342),
|
TestFile("test_mla_deepseek_v3.py", 342),
|
||||||
@@ -64,22 +75,16 @@ suites = {
|
|||||||
TestFile("test_mla_fp8.py", 93),
|
TestFile("test_mla_fp8.py", 93),
|
||||||
TestFile("test_no_chunked_prefill.py", 108),
|
TestFile("test_no_chunked_prefill.py", 108),
|
||||||
TestFile("test_no_overlap_scheduler.py", 234),
|
TestFile("test_no_overlap_scheduler.py", 234),
|
||||||
TestFile("test_openai_function_calling.py", 60),
|
|
||||||
TestFile("test_openai_server.py", 149),
|
|
||||||
TestFile("test_openai_server_hidden_states.py", 240),
|
|
||||||
TestFile("test_penalty.py", 41),
|
TestFile("test_penalty.py", 41),
|
||||||
TestFile("test_page_size.py", 60),
|
TestFile("test_page_size.py", 60),
|
||||||
TestFile("test_pytorch_sampling_backend.py", 66),
|
TestFile("test_pytorch_sampling_backend.py", 66),
|
||||||
TestFile("test_radix_attention.py", 105),
|
TestFile("test_radix_attention.py", 105),
|
||||||
TestFile("test_reasoning_content.py", 89),
|
|
||||||
TestFile("test_regex_constrained.py", 64),
|
TestFile("test_regex_constrained.py", 64),
|
||||||
TestFile("test_request_length_validation.py", 31),
|
|
||||||
TestFile("test_retract_decode.py", 54),
|
TestFile("test_retract_decode.py", 54),
|
||||||
TestFile("test_server_args.py", 1),
|
TestFile("test_server_args.py", 1),
|
||||||
TestFile("test_skip_tokenizer_init.py", 117),
|
TestFile("test_skip_tokenizer_init.py", 117),
|
||||||
TestFile("test_srt_engine.py", 261),
|
TestFile("test_srt_engine.py", 261),
|
||||||
TestFile("test_srt_endpoint.py", 130),
|
TestFile("test_srt_endpoint.py", 130),
|
||||||
TestFile("test_tool_choice.py", 226),
|
|
||||||
TestFile("test_torch_compile.py", 76),
|
TestFile("test_torch_compile.py", 76),
|
||||||
TestFile("test_torch_compile_moe.py", 172),
|
TestFile("test_torch_compile_moe.py", 172),
|
||||||
TestFile("test_torch_native_attention_backend.py", 123),
|
TestFile("test_torch_native_attention_backend.py", 123),
|
||||||
@@ -107,15 +112,32 @@ suites = {
|
|||||||
TestFile("test_torch_compile_moe.py", 172),
|
TestFile("test_torch_compile_moe.py", 172),
|
||||||
TestFile("models/test_qwen_models.py", 82),
|
TestFile("models/test_qwen_models.py", 82),
|
||||||
TestFile("models/test_reward_models.py", 132),
|
TestFile("models/test_reward_models.py", 132),
|
||||||
|
TestFile("openai_server/basic/test_openai_embedding.py", 141),
|
||||||
|
TestFile("openai_server/basic/test_openai_server.py", 149),
|
||||||
|
TestFile("openai_server/basic/test_protocol.py", 10),
|
||||||
|
TestFile("openai_server/basic/test_serving_chat.py", 10),
|
||||||
|
TestFile("openai_server/basic/test_serving_completions.py", 10),
|
||||||
|
TestFile("openai_server/basic/test_serving_embedding.py", 10),
|
||||||
TestFile("test_abort.py", 51),
|
TestFile("test_abort.py", 51),
|
||||||
TestFile("test_block_int8.py", 22),
|
TestFile("test_block_int8.py", 22),
|
||||||
TestFile("test_create_kvindices.py", 2),
|
TestFile("test_create_kvindices.py", 2),
|
||||||
TestFile("test_chunked_prefill.py", 313),
|
TestFile("test_chunked_prefill.py", 313),
|
||||||
TestFile("test_embedding_openai_server.py", 141),
|
|
||||||
TestFile("test_eval_fp8_accuracy.py", 303),
|
TestFile("test_eval_fp8_accuracy.py", 303),
|
||||||
TestFile("test_function_call_parser.py", 10),
|
TestFile("test_function_call_parser.py", 10),
|
||||||
TestFile("test_input_embeddings.py", 38),
|
TestFile("test_input_embeddings.py", 38),
|
||||||
TestFile("test_large_max_new_tokens.py", 41),
|
TestFile("openai_server/features/test_cache_report.py", 100),
|
||||||
|
TestFile("openai_server/features/test_enable_thinking.py", 70),
|
||||||
|
TestFile("openai_server/features/test_json_constrained.py", 98),
|
||||||
|
TestFile("openai_server/features/test_json_mode.py", 90),
|
||||||
|
TestFile("openai_server/features/test_openai_server_ebnf.py", 95),
|
||||||
|
TestFile("openai_server/features/test_openai_server_hidden_states.py", 240),
|
||||||
|
TestFile("openai_server/features/test_reasoning_content.py", 89),
|
||||||
|
TestFile("openai_server/function_call/test_openai_function_calling.py", 60),
|
||||||
|
TestFile("openai_server/function_call/test_tool_choice.py", 226),
|
||||||
|
TestFile("openai_server/validation/test_large_max_new_tokens.py", 41),
|
||||||
|
TestFile("openai_server/validation/test_matched_stop.py", 60),
|
||||||
|
TestFile("openai_server/validation/test_openai_server_ignore_eos.py", 85),
|
||||||
|
TestFile("openai_server/validation/test_request_length_validation.py", 31),
|
||||||
TestFile("test_metrics.py", 32),
|
TestFile("test_metrics.py", 32),
|
||||||
TestFile("test_no_chunked_prefill.py", 108),
|
TestFile("test_no_chunked_prefill.py", 108),
|
||||||
TestFile("test_no_overlap_scheduler.py", 234),
|
TestFile("test_no_overlap_scheduler.py", 234),
|
||||||
@@ -123,9 +145,6 @@ suites = {
|
|||||||
TestFile("test_page_size.py", 60),
|
TestFile("test_page_size.py", 60),
|
||||||
TestFile("test_pytorch_sampling_backend.py", 66),
|
TestFile("test_pytorch_sampling_backend.py", 66),
|
||||||
TestFile("test_radix_attention.py", 105),
|
TestFile("test_radix_attention.py", 105),
|
||||||
TestFile("test_reasoning_content.py", 89),
|
|
||||||
TestFile("test_enable_thinking.py", 70),
|
|
||||||
TestFile("test_request_length_validation.py", 31),
|
|
||||||
TestFile("test_retract_decode.py", 54),
|
TestFile("test_retract_decode.py", 54),
|
||||||
TestFile("test_server_args.py", 1),
|
TestFile("test_server_args.py", 1),
|
||||||
TestFile("test_skip_tokenizer_init.py", 117),
|
TestFile("test_skip_tokenizer_init.py", 117),
|
||||||
|
|||||||
@@ -1,87 +0,0 @@
|
|||||||
import unittest
|
|
||||||
|
|
||||||
import openai
|
|
||||||
|
|
||||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
|
||||||
from sglang.srt.utils import kill_process_tree
|
|
||||||
from sglang.test.test_utils import (
|
|
||||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
|
||||||
DEFAULT_URL_FOR_TEST,
|
|
||||||
CustomTestCase,
|
|
||||||
popen_launch_server,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestOpenAIServer(CustomTestCase):
|
|
||||||
@classmethod
|
|
||||||
def setUpClass(cls):
|
|
||||||
cls.model = "intfloat/e5-mistral-7b-instruct"
|
|
||||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
|
||||||
cls.api_key = "sk-123456"
|
|
||||||
cls.process = popen_launch_server(
|
|
||||||
cls.model,
|
|
||||||
cls.base_url,
|
|
||||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
|
||||||
api_key=cls.api_key,
|
|
||||||
)
|
|
||||||
cls.base_url += "/v1"
|
|
||||||
cls.tokenizer = get_tokenizer(cls.model)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def tearDownClass(cls):
|
|
||||||
kill_process_tree(cls.process.pid)
|
|
||||||
|
|
||||||
def run_embedding(self, use_list_input, token_input):
|
|
||||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
|
||||||
prompt = "The capital of France is"
|
|
||||||
if token_input:
|
|
||||||
prompt_input = self.tokenizer.encode(prompt)
|
|
||||||
num_prompt_tokens = len(prompt_input)
|
|
||||||
else:
|
|
||||||
prompt_input = prompt
|
|
||||||
num_prompt_tokens = len(self.tokenizer.encode(prompt))
|
|
||||||
|
|
||||||
if use_list_input:
|
|
||||||
prompt_arg = [prompt_input] * 2
|
|
||||||
num_prompts = len(prompt_arg)
|
|
||||||
num_prompt_tokens *= num_prompts
|
|
||||||
else:
|
|
||||||
prompt_arg = prompt_input
|
|
||||||
num_prompts = 1
|
|
||||||
|
|
||||||
response = client.embeddings.create(
|
|
||||||
input=prompt_arg,
|
|
||||||
model=self.model,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(response.data) == num_prompts
|
|
||||||
assert isinstance(response.data, list)
|
|
||||||
assert response.data[0].embedding
|
|
||||||
assert response.data[0].index is not None
|
|
||||||
assert response.data[0].object == "embedding"
|
|
||||||
assert response.model == self.model
|
|
||||||
assert response.object == "list"
|
|
||||||
assert (
|
|
||||||
response.usage.prompt_tokens == num_prompt_tokens
|
|
||||||
), f"{response.usage.prompt_tokens} vs {num_prompt_tokens}"
|
|
||||||
assert (
|
|
||||||
response.usage.total_tokens == num_prompt_tokens
|
|
||||||
), f"{response.usage.total_tokens} vs {num_prompt_tokens}"
|
|
||||||
|
|
||||||
def run_batch(self):
|
|
||||||
# FIXME: not implemented
|
|
||||||
pass
|
|
||||||
|
|
||||||
def test_embedding(self):
|
|
||||||
# TODO: the fields of encoding_format, dimensions, user are skipped
|
|
||||||
# TODO: support use_list_input
|
|
||||||
for use_list_input in [False, True]:
|
|
||||||
for token_input in [False, True]:
|
|
||||||
self.run_embedding(use_list_input, token_input)
|
|
||||||
|
|
||||||
def test_batch(self):
|
|
||||||
self.run_batch()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Unit tests for OpenAI adapter utils.
|
Unit tests for Jinja chat template utils.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|||||||
Reference in New Issue
Block a user