Feat/support rerank (#6058)

This commit is contained in:
woodx
2025-06-17 01:50:01 +08:00
committed by GitHub
parent 91a066ec6a
commit e30ef368ab
20 changed files with 684 additions and 30 deletions

View File

@@ -42,6 +42,21 @@ DEFAULT_PROMPTS = [
# the output of gemma-2-2b from SRT is unstable on the commented prompt
# "The capital of France is",
]
# Query/document fixtures (named for the rerank tests). Each entry holds one
# "query" string and a list of candidate "documents": the first entry has a
# single document, the second has two documents for the same query.
TEST_RERANK_QUERY_DOCS = [
{
"query": "How many people live in Berlin?",
"documents": [
"Berlin is well known for its museums.",
],
},
{
"query": "How many people live in Berlin?",
"documents": [
"Berlin had a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.",
"Berlin is well known for its museums.",
],
},
]
dirpath = os.path.dirname(__file__)
with open(os.path.join(dirpath, "long_prompt.txt"), "r") as f:
@@ -241,7 +256,7 @@ class HFRunner:
self.model = _get_sentence_transformer_embedding_model(
model_path, torch_dtype
)
elif self.model_type == "reward":
elif self.model_type == "reward" or self.model_type == "cross_encoder":
from transformers import AutoModelForSequenceClassification
self.model = AutoModelForSequenceClassification.from_pretrained(
@@ -303,6 +318,15 @@ class HFRunner:
else:
logits = self.model.encode(prompts).tolist()
out_queue.put(ModelOutput(embed_logits=logits))
elif self.model_type == "cross_encoder":
inputs = self.tokenizer(
prompts, padding=True, return_tensors="pt"
).to("cuda")
scores = self.model(**inputs).logits
scores = scores.squeeze().tolist()
if not isinstance(scores, list):
scores = [scores]
out_queue.put(ModelOutput(scores=scores))
elif self.model_type == "reward":
scores = []
@@ -322,7 +346,9 @@ class HFRunner:
def forward(
self,
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
prompts: Union[
List[List[str]], List[str], List[torch.Tensor]
] = DEFAULT_PROMPTS,
image_data: Optional[List[str]] = None,
max_new_tokens: int = 8,
lora_paths: Optional[List[str]] = None,
@@ -526,7 +552,9 @@ class SRTRunner:
def forward(
self,
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
prompts: Union[
List[List[str]], List[str], List[torch.Tensor]
] = DEFAULT_PROMPTS,
image_data: Optional[List[str]] = None,
max_new_tokens: int = 8,
lora_paths: Optional[List[str]] = None,
@@ -552,6 +580,13 @@ class SRTRunner:
else:
logits = [response["embedding"]]
return ModelOutput(embed_logits=logits)
# cross encoder model
elif self.model_type == "cross_encoder":
response = self.engine.rerank(prompts)
if not isinstance(response, list):
response = [response]
scores = [x["embedding"] for x in response]
return ModelOutput(scores=scores)
# reward model
else:
response = self.engine.encode(prompts)

View File

@@ -41,6 +41,8 @@ DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1"
DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST = "Qwen/Qwen1.5-MoE-A2.7B"
# Embedding and cross-encoder test models
DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST = "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
DEFAULT_SMALL_CROSS_ENCODER_MODEL_NAME_FOR_TEST = "cross-encoder/ms-marco-MiniLM-L6-v2"
# MLA test models
DEFAULT_MLA_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
DEFAULT_MODEL_NAME_FOR_TEST_MLA = "lmsys/sglang-ci-dsv3-test"