Feat/support rerank (#6058)

This commit is contained in:
woodx
2025-06-17 01:50:01 +08:00
committed by GitHub
parent 91a066ec6a
commit e30ef368ab
20 changed files with 684 additions and 30 deletions

View File

@@ -17,7 +17,9 @@ import requests
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_process_tree
from sglang.test.runners import TEST_RERANK_QUERY_DOCS
from sglang.test.test_utils import (
DEFAULT_SMALL_CROSS_ENCODER_MODEL_NAME_FOR_TEST,
DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST,
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
@@ -699,6 +701,77 @@ class TestOpenAIEmbedding(CustomTestCase):
self.assertEqual(cm.exception.status_code, 400)
class TestOpenAIV1Rerank(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_CROSS_ENCODER_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
cls.score_tolerance = 1e-2
# Configure embedding-specific args
other_args = [
"--is-embedding",
"--enable-metrics",
"--disable-radix-cache",
"--chunked-prefill-size",
"-1",
"--attention-backend",
"torch_native",
]
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
other_args=other_args,
)
cls.base_url += "/v1/rerank"
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def run_rerank(self, query, docs):
response = requests.post(
self.base_url,
headers={
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
},
json={"query": query, "documents": docs},
)
return response.json()
def test_rerank_single(self):
"""Test single rerank request"""
query = TEST_RERANK_QUERY_DOCS[0]["query"]
docs = TEST_RERANK_QUERY_DOCS[0]["documents"]
response = self.run_rerank(query, docs)
self.assertEqual(len(response), 1)
self.assertTrue(isinstance(response[0]["score"], float))
self.assertTrue(isinstance(response[0]["document"], str))
self.assertTrue(isinstance(response[0]["index"], int))
def test_rerank_batch(self):
"""Test batch rerank request"""
query = TEST_RERANK_QUERY_DOCS[1]["query"]
docs = TEST_RERANK_QUERY_DOCS[1]["documents"]
response = self.run_rerank(query, docs)
self.assertEqual(len(response), 2)
self.assertTrue(isinstance(response[0]["score"], float))
self.assertTrue(isinstance(response[1]["score"], float))
self.assertTrue(isinstance(response[0]["document"], str))
self.assertTrue(isinstance(response[1]["document"], str))
self.assertTrue(isinstance(response[0]["index"], int))
self.assertTrue(isinstance(response[1]["index"], int))
class TestOpenAIServerIgnoreEOS(CustomTestCase):
@classmethod
def setUpClass(cls):