Feat/support rerank (#6058)
@@ -42,6 +42,21 @@ DEFAULT_PROMPTS = [
     # the output of gemma-2-2b from SRT is unstable on the commented prompt
     # "The capital of France is",
 ]
+TEST_RERANK_QUERY_DOCS = [
+    {
+        "query": "How many people live in Berlin?",
+        "documents": [
+            "Berlin is well known for its museums.",
+        ],
+    },
+    {
+        "query": "How many people live in Berlin?",
+        "documents": [
+            "Berlin had a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.",
+            "Berlin is well known for its museums.",
+        ],
+    },
+]
 
 dirpath = os.path.dirname(__file__)
 with open(os.path.join(dirpath, "long_prompt.txt"), "r") as f:
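Note: the fixture deliberately pairs one query with document lists of different lengths, so both the single-score and multi-score code paths below get exercised. For a cross-encoder, each entry expands into flat [query, document] pairs; a minimal sketch of that expansion (the helper name is illustrative, not part of this diff):

# Illustrative helper, not part of the diff: flatten the fixture into
# [query, document] pairs matching the widened prompts type below.
def expand_query_docs(query_docs):
    pairs = []
    for item in query_docs:
        for doc in item["documents"]:
            pairs.append([item["query"], doc])
    return pairs

# expand_query_docs(TEST_RERANK_QUERY_DOCS)[0]
# -> ["How many people live in Berlin?", "Berlin is well known for its museums."]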
@@ -241,7 +256,7 @@ class HFRunner:
             self.model = _get_sentence_transformer_embedding_model(
                 model_path, torch_dtype
             )
-        elif self.model_type == "reward":
+        elif self.model_type == "reward" or self.model_type == "cross_encoder":
             from transformers import AutoModelForSequenceClassification
 
             self.model = AutoModelForSequenceClassification.from_pretrained(
@@ -303,6 +318,15 @@ class HFRunner:
                 else:
                     logits = self.model.encode(prompts).tolist()
                 out_queue.put(ModelOutput(embed_logits=logits))
+            elif self.model_type == "cross_encoder":
+                inputs = self.tokenizer(
+                    prompts, padding=True, return_tensors="pt"
+                ).to("cuda")
+                scores = self.model(**inputs).logits
+                scores = scores.squeeze().tolist()
+                if not isinstance(scores, list):
+                    scores = [scores]
+                out_queue.put(ModelOutput(scores=scores))
 
             elif self.model_type == "reward":
                 scores = []
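The cross-encoder branch above mirrors plain transformers scoring: tokenize the [query, document] pairs jointly, take the sequence-classification logits, and squeeze them to one float per pair, re-wrapping the scalar case in a list. A standalone sketch of the same computation on CPU, using the model added to the test constants below:

# Standalone sketch of the reference scoring path (CPU, eval mode).
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_name = "cross-encoder/ms-marco-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name).eval()

pairs = [("How many people live in Berlin?", "Berlin is well known for its museums.")]
inputs = tokenizer(pairs, padding=True, return_tensors="pt")
with torch.no_grad():
    scores = model(**inputs).logits.squeeze(-1).tolist()  # one relevance logit per pair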
@@ -322,7 +346,9 @@ class HFRunner:
 
     def forward(
         self,
-        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
+        prompts: Union[
+            List[List[str]], List[str], List[torch.Tensor]
+        ] = DEFAULT_PROMPTS,
         image_data: Optional[List[str]] = None,
         max_new_tokens: int = 8,
         lora_paths: Optional[List[str]] = None,
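The widened annotation admits List[List[str]], i.e. [query, document] pairs for the cross-encoder path, alongside the existing plain-string and tensor prompts. A hypothetical call (the constructor arguments here are assumptions, not shown in this diff):

# Hypothetical usage; the HFRunner constructor arguments are assumptions.
with HFRunner(model_path, torch_dtype=torch.float16, model_type="cross_encoder") as hf_runner:
    outputs = hf_runner.forward(
        [["How many people live in Berlin?", "Berlin is well known for its museums."]]
    )
# outputs.scores holds one relevance logit per [query, document] pair.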
@@ -526,7 +552,9 @@ class SRTRunner:
 
     def forward(
         self,
-        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
+        prompts: Union[
+            List[List[str]], List[str], List[torch.Tensor]
+        ] = DEFAULT_PROMPTS,
         image_data: Optional[List[str]] = None,
         max_new_tokens: int = 8,
         lora_paths: Optional[List[str]] = None,
@@ -552,6 +580,13 @@ class SRTRunner:
             else:
                 logits = [response["embedding"]]
             return ModelOutput(embed_logits=logits)
+        # cross encoder model
+        elif self.model_type == "cross_encoder":
+            response = self.engine.rerank(prompts)
+            if not isinstance(response, list):
+                response = [response]
+            scores = [x["embedding"] for x in response]
+            return ModelOutput(scores=scores)
         # reward model
         else:
             response = self.engine.encode(prompts)
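Note that the rerank response reports each score under the "embedding" key, which the runner re-labels as scores. That makes an HF-vs-SRT comparison straightforward; a sketch of such a test (the function name, runner arguments, and tolerance are illustrative, not from this diff):

# Illustrative comparison test; names, arguments, and tolerance are assumptions.
def test_cross_encoder_matches_hf(model_path):
    pairs = [
        [item["query"], doc]
        for item in TEST_RERANK_QUERY_DOCS
        for doc in item["documents"]
    ]
    with HFRunner(model_path, torch_dtype=torch.float16, model_type="cross_encoder") as hf:
        hf_scores = hf.forward(pairs).scores
    with SRTRunner(model_path, torch_dtype=torch.float16, model_type="cross_encoder") as srt:
        srt_scores = srt.forward(pairs).scores
    for hf_score, srt_score in zip(hf_scores, srt_scores):
        assert abs(hf_score - srt_score) < 1e-2  # scores should closely agree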
@@ -41,6 +41,8 @@ DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1"
 DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST = "Qwen/Qwen1.5-MoE-A2.7B"
 
 # MLA test models
+DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST = "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
+DEFAULT_SMALL_CROSS_ENCODER_MODEL_NAME_FOR_TEST = "cross-encoder/ms-marco-MiniLM-L6-v2"
 DEFAULT_MLA_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
 DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
 DEFAULT_MODEL_NAME_FOR_TEST_MLA = "lmsys/sglang-ci-dsv3-test"
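As a sanity check on what the fixture encodes: for the Berlin-population query, the DEFAULT_SMALL_CROSS_ENCODER_MODEL_NAME_FOR_TEST model should rank the population sentence above the museums sentence. A quick independent sketch with the sentence-transformers wrapper (not used by this diff; shown only for illustration):

# Illustration only; the diff scores pairs via transformers/SRT, not this wrapper.
from sentence_transformers import CrossEncoder

model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L6-v2")
scores = model.predict([
    ("How many people live in Berlin?",
     "Berlin had a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers."),
    ("How many people live in Berlin?",
     "Berlin is well known for its museums."),
])
assert scores[0] > scores[1]  # the population sentence should score higher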