Feat/support rerank (#6058)
This commit is contained in:
@@ -17,7 +17,9 @@ import requests
|
||||
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.runners import TEST_RERANK_QUERY_DOCS
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_SMALL_CROSS_ENCODER_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
@@ -699,6 +701,77 @@ class TestOpenAIEmbedding(CustomTestCase):
|
||||
self.assertEqual(cm.exception.status_code, 400)
|
||||
|
||||
|
||||
class TestOpenAIV1Rerank(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_SMALL_CROSS_ENCODER_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-123456"
|
||||
cls.score_tolerance = 1e-2
|
||||
|
||||
# Configure embedding-specific args
|
||||
other_args = [
|
||||
"--is-embedding",
|
||||
"--enable-metrics",
|
||||
"--disable-radix-cache",
|
||||
"--chunked-prefill-size",
|
||||
"-1",
|
||||
"--attention-backend",
|
||||
"torch_native",
|
||||
]
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
api_key=cls.api_key,
|
||||
other_args=other_args,
|
||||
)
|
||||
cls.base_url += "/v1/rerank"
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def run_rerank(self, query, docs):
|
||||
response = requests.post(
|
||||
self.base_url,
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={"query": query, "documents": docs},
|
||||
)
|
||||
|
||||
return response.json()
|
||||
|
||||
def test_rerank_single(self):
|
||||
"""Test single rerank request"""
|
||||
query = TEST_RERANK_QUERY_DOCS[0]["query"]
|
||||
docs = TEST_RERANK_QUERY_DOCS[0]["documents"]
|
||||
|
||||
response = self.run_rerank(query, docs)
|
||||
|
||||
self.assertEqual(len(response), 1)
|
||||
self.assertTrue(isinstance(response[0]["score"], float))
|
||||
self.assertTrue(isinstance(response[0]["document"], str))
|
||||
self.assertTrue(isinstance(response[0]["index"], int))
|
||||
|
||||
def test_rerank_batch(self):
|
||||
"""Test batch rerank request"""
|
||||
query = TEST_RERANK_QUERY_DOCS[1]["query"]
|
||||
docs = TEST_RERANK_QUERY_DOCS[1]["documents"]
|
||||
|
||||
response = self.run_rerank(query, docs)
|
||||
|
||||
self.assertEqual(len(response), 2)
|
||||
self.assertTrue(isinstance(response[0]["score"], float))
|
||||
self.assertTrue(isinstance(response[1]["score"], float))
|
||||
self.assertTrue(isinstance(response[0]["document"], str))
|
||||
self.assertTrue(isinstance(response[1]["document"], str))
|
||||
self.assertTrue(isinstance(response[0]["index"], int))
|
||||
self.assertTrue(isinstance(response[1]["index"], int))
|
||||
|
||||
|
||||
class TestOpenAIServerIgnoreEOS(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
|
||||
Reference in New Issue
Block a user