Feat/support rerank (#6058)
This commit is contained in:
91
test/srt/models/test_cross_encoder_models.py
Normal file
91
test/srt/models/test_cross_encoder_models.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import multiprocessing as mp
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoTokenizer
|
||||
|
||||
from sglang.test.runners import TEST_RERANK_QUERY_DOCS, HFRunner, SRTRunner
|
||||
from sglang.test.test_utils import CustomTestCase, is_in_ci
|
||||
|
||||
# Cross-encoder checkpoints under test. test_prefill_logits unpacks each
# tuple as (model path, tp_size, score tolerance).
MODELS = [
|
||||
("cross-encoder/ms-marco-MiniLM-L6-v2", 1, 1e-2),
|
||||
("BAAI/bge-reranker-v2-m3", 1, 1e-2),
|
||||
]
|
||||
# Values passed one at a time as SRTRunner's `attention_backend` argument.
ATTENTION_BACKEND = ["torch_native", "triton"]
|
||||
|
||||
# Dtypes passed as `torch_dtype` to both HFRunner and SRTRunner.
TORCH_DTYPES = [torch.float32]
|
||||
|
||||
|
||||
class TestCrossEncoderModels(CustomTestCase):
    """Check that SRT cross-encoder scores match the HuggingFace reference.

    Each (query, document) pair from TEST_RERANK_QUERY_DOCS is scored by
    both HFRunner and SRTRunner in ``cross_encoder`` mode, and the scores
    must agree within the per-model tolerance listed in MODELS.
    """

    @classmethod
    def setUpClass(cls):
        # Select the "spawn" start method for multiprocessing; force=True
        # replaces any start method that was already set.
        mp.set_start_method("spawn", force=True)

    def assert_close_prefill_logits(
        self,
        prompts,
        model_path,
        tp_size,
        torch_dtype,
        score_tolerance,
        attention_backend,
    ) -> None:
        """Score ``prompts`` with HF and SRT and require close agreement.

        Args:
            prompts: Input handed unchanged to both runners' ``forward``.
            model_path: Model identifier given to both runners.
            tp_size: Forwarded to SRTRunner only.
            torch_dtype: Forwarded to both runners.
            score_tolerance: Strict upper bound on the absolute difference
                between corresponding HF and SRT scores.
            attention_backend: Forwarded to SRTRunner only.

        Raises:
            AssertionError: If the runners return a different number of
                scores, or any pair of scores differs by at least
                ``score_tolerance``.
        """
        with HFRunner(
            model_path,
            torch_dtype=torch_dtype,
            model_type="cross_encoder",
        ) as hf_runner:
            hf_scores = hf_runner.forward(prompts).scores

        with SRTRunner(
            model_path,
            tp_size=tp_size,
            torch_dtype=torch_dtype,
            model_type="cross_encoder",
            attention_backend=attention_backend,
            chunked_prefill_size=-1,
            disable_radix_cache=True,
        ) as srt_runner:
            srt_scores = srt_runner.forward(prompts).scores

        # Without this check a short or empty SRT result would make the
        # comparison loop below pass without comparing every score.
        self.assertEqual(
            len(hf_scores),
            len(srt_scores),
            f"score count mismatch for {model_path} "
            f"(attention_backend={attention_backend})",
        )

        for i, (hf_score, srt_score) in enumerate(zip(hf_scores, srt_scores)):
            score_difference = abs(hf_score - srt_score)
            # unittest assertions are not stripped under `python -O`, and
            # the message pinpoints the failing pair.
            self.assertLess(
                score_difference,
                score_tolerance,
                "cross encoder scores are not all close: "
                f"model={model_path}, attention_backend={attention_backend}, "
                f"index={i}, hf={hf_score}, srt={srt_score}",
            )

    def preprocess_prompts(self, prompt):
        """Expand one query/documents record into ``[query, document]`` pairs.

        Args:
            prompt: Mapping with a ``"query"`` entry and an iterable
                ``"documents"`` entry.

        Returns:
            A list with one ``[query, document]`` list per document, in
            the original document order.
        """
        query = prompt["query"]
        return [[query, document] for document in prompt["documents"]]

    def test_prefill_logits(self):
        """Run the HF-vs-SRT comparison over models, backends and inputs."""
        models_to_test = MODELS

        if is_in_ci():
            # In CI only a single, randomly chosen model is exercised.
            models_to_test = [random.choice(MODELS)]

        for model, tp_size, prefill_tolerance in models_to_test:
            for attention_backend in ATTENTION_BACKEND:
                for query_docs in TEST_RERANK_QUERY_DOCS:
                    prompts = self.preprocess_prompts(query_docs)
                    for torch_dtype in TORCH_DTYPES:
                        self.assert_close_prefill_logits(
                            prompts,
                            model,
                            tp_size,
                            torch_dtype,
                            prefill_tolerance,
                            attention_backend,
                        )
|
||||
|
||||
|
||||
# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -19,6 +19,8 @@ suites = {
|
||||
TestFile("models/lora/test_lora_cuda_graph.py", 250),
|
||||
TestFile("models/test_embedding_models.py", 73),
|
||||
# TestFile("models/test_clip_models.py", 52),
|
||||
TestFile("models/test_encoder_embedding_models.py", 100),
|
||||
TestFile("models/test_cross_encoder_models.py", 100),
|
||||
TestFile("models/test_compressed_tensors_models.py", 42),
|
||||
TestFile("models/test_generation_models.py", 103),
|
||||
# TestFile("models/test_gme_qwen_models.py", 45),
|
||||
|
||||
@@ -17,7 +17,9 @@ import requests
|
||||
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.runners import TEST_RERANK_QUERY_DOCS
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_SMALL_CROSS_ENCODER_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
@@ -699,6 +701,77 @@ class TestOpenAIEmbedding(CustomTestCase):
|
||||
self.assertEqual(cm.exception.status_code, 400)
|
||||
|
||||
|
||||
class TestOpenAIV1Rerank(CustomTestCase):
    """End-to-end tests for the ``/v1/rerank`` HTTP endpoint.

    A server is launched once for the class with the small cross-encoder
    test model; each test posts a query plus documents and validates the
    shape of the returned list.
    """

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_CROSS_ENCODER_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        cls.score_tolerance = 1e-2

        # Launch flags for the cross-encoder server; note that it is
        # started with --is-embedding.
        other_args = [
            "--is-embedding",
            "--enable-metrics",
            "--disable-radix-cache",
            "--chunked-prefill-size",
            "-1",
            "--attention-backend",
            "torch_native",
        ]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=other_args,
        )
        # The server is launched with the bare URL; afterwards base_url
        # is repointed at the rerank route used by run_rerank.
        cls.base_url += "/v1/rerank"

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def run_rerank(self, query, docs):
        """POST a rerank request and return the decoded JSON body.

        Args:
            query: Value sent as the ``"query"`` field.
            docs: Value sent as the ``"documents"`` field.

        Returns:
            The parsed JSON payload of the HTTP response.
        """
        response = requests.post(
            self.base_url,
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            },
            json={"query": query, "documents": docs},
        )

        return response.json()

    def _assert_rerank_response(self, response, expected_count):
        """Validate that ``response`` is a list of well-typed rerank items.

        Each item must carry a float ``"score"``, a str ``"document"``
        and an int ``"index"``.
        """
        # An error payload is a JSON object; checking the container type
        # first gives a clear failure instead of a KeyError on
        # ``response[0]``.
        self.assertIsInstance(response, list)
        self.assertEqual(len(response), expected_count)
        for item in response:
            self.assertIsInstance(item["score"], float)
            self.assertIsInstance(item["document"], str)
            self.assertIsInstance(item["index"], int)

    def test_rerank_single(self):
        """Test single rerank request"""
        query = TEST_RERANK_QUERY_DOCS[0]["query"]
        docs = TEST_RERANK_QUERY_DOCS[0]["documents"]

        response = self.run_rerank(query, docs)

        self._assert_rerank_response(response, 1)

    def test_rerank_batch(self):
        """Test batch rerank request"""
        query = TEST_RERANK_QUERY_DOCS[1]["query"]
        docs = TEST_RERANK_QUERY_DOCS[1]["documents"]

        response = self.run_rerank(query, docs)

        self._assert_rerank_response(response, 2)
|
||||
|
||||
|
||||
class TestOpenAIServerIgnoreEOS(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
|
||||
Reference in New Issue
Block a user