146 lines
5.6 KiB
Python
146 lines
5.6 KiB
Python
from torch.nn import functional as F
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
from utils import prompt_template, truncate, hybrid_scores
|
|
|
|
class ERank_Transformer:
    """Pointwise reranker built on a Hugging Face causal language model.

    For every document the model generates a completion; the integer that
    follows the last ``<answer>`` tag in that completion is read as the
    relevance score and weighted by the probability of its digit tokens.
    """

    def __init__(self, model_name_or_path: str):
        """
        Initializes the ERank_Transformer reranker.

        Args:
            model_name_or_path (str): The name or path of the model to be loaded.
                This can be a Hugging Face model ID or a local path.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        # Batched generation with a decoder-only model continues from the last
        # position of every row, so shorter prompts must be padded on the left.
        self.tokenizer.padding_side = "left"
        self.reranker = AutoModelForCausalLM.from_pretrained(model_name_or_path).eval()
        self.reranker.to("cuda")

    @staticmethod
    def _find_last_subsequence(sequence: list, pattern: list) -> int:
        """Return the index just past the last occurrence of ``pattern`` in ``sequence``, or -1."""
        width = len(pattern)
        if width == 0:
            return -1
        # The last admissible start is len(sequence) - width (inclusive).
        for start in range(len(sequence) - width, -1, -1):
            if sequence[start:start + width] == pattern:
                return start + width
        return -1

    @staticmethod
    def _parse_answer(text: str) -> int:
        """Convert the collected digit string to an int, or -1 if it is invalid or greater than 10."""
        try:
            value = int(text)
        except ValueError:
            # Empty string, or characters that pass str.isdigit() but that int() rejects.
            return -1
        return value if value <= 10 else -1

    def rerank(self, query: str, docs: list, instruction: str, truncate_length: int = None) -> list:
        """
        Reranks a list of documents based on a query and a specific instruction.

        Args:
            query (str): The search query provided by the user.
            docs (list): A list of dictionaries, where each dictionary represents a document
                and must contain a "content" key.
            instruction (str): The instruction for the model, guiding it on how to evaluate the documents.
            truncate_length (int, optional): The maximum length to truncate the query and document
                content to. Defaults to None (no truncation).

        Returns:
            list: A new list of document dictionaries, each extended with a "rank_score" key and
                sorted by that score in descending order. The score is the parsed integer answer
                (-1 when no valid answer of at most 10 was generated) multiplied by the product of
                the probabilities of the generated digit tokens. An empty ``docs`` yields ``[]``.
        """
        # Nothing to score; the tokenizer cannot encode an empty batch.
        if not docs:
            return []

        # The query is identical for every document, so truncate it once.
        if truncate_length:
            query = truncate(self.tokenizer, query, length=truncate_length)

        # prepare messages: one single-turn conversation per document
        messages = []
        for doc in docs:
            content = doc["content"]
            if truncate_length:
                content = truncate(self.tokenizer, content, length=truncate_length)
            messages.append([{
                "role": "user",
                "content": prompt_template.format(
                    query=query,
                    doc=content,
                    instruction=instruction,
                ),
            }])

        # encode tokens
        texts = [
            self.tokenizer.apply_chat_template(
                conversation,
                tokenize=False,
                add_generation_prompt=True,
            )
            for conversation in messages
        ]
        inputs = self.tokenizer(texts, padding=True, return_tensors="pt").to(self.reranker.device)

        # LLM completion
        outputs = self.reranker.generate(
            **inputs,
            max_new_tokens=8192,
            output_scores=True,
            return_dict_in_generate=True,
        )

        # extract and organize results
        prompt_len = inputs.input_ids.size(1)
        step_scores = outputs.scores  # one score tensor per generated step
        answer_token_ids = self.tokenizer.encode("<answer>", add_special_tokens=False)

        results = []
        for idx, doc in enumerate(docs):
            # Search only the generated part: a tag that occurs in the prompt must
            # not be mistaken for the model's answer, and indices into
            # `step_scores` are relative to the first generated token.
            completion_ids = outputs.sequences[idx, prompt_len:].tolist()
            start = self._find_last_subsequence(completion_ids, answer_token_ids)

            # Collect the run of digit tokens right after <answer>, together
            # with the joint probability the model assigned to them.
            answer_text = ""
            prob = 1.0
            if start != -1:
                for step in range(start, len(step_scores)):
                    token_id = completion_ids[step]
                    token = self.tokenizer.decode(token_id)
                    if not token.isdigit():
                        break
                    token_probs = F.softmax(step_scores[step][idx], dim=-1)
                    prob *= token_probs[token_id].item()
                    answer_text += token

            # in case the answer is not a digit or exceeds 10
            answer = self._parse_answer(answer_text)

            # append to the final results
            results.append({
                **doc,
                "rank_score": answer * prob,
            })

        # sort the reranking results for the query
        results.sort(key=lambda item: item["rank_score"], reverse=True)
        return results
|
|
|
|
|
|
if __name__ == "__main__":
    # Demo: rerank three toy documents for one query, then blend the result
    # with the first-stage retrieval scores.

    # Checkpoint to load. Larger alternatives:
    #   "Ucreate/ERank-14B", "Ucreate/ERank-32B"
    model_name_or_path = "Ucreate/ERank-4B"
    ranker = ERank_Transformer(model_name_or_path)

    # Inputs: task instruction, user query, and candidates that carry the
    # score assigned by the first-stage retriever.
    instruction = "Retrieve relevant documents for the query."
    query = "I am happy"
    candidates = [
        {"content": text, "first_stage_score": first_stage}
        for text, first_stage in (
            ("excited", 46.7),
            ("sad", 1.5),
            ("peaceful", 2.3),
        )
    ]

    # Rerank, truncating query and document content to 2048.
    reranked = ranker.rerank(query, candidates, instruction, truncate_length=2048)
    print(reranked)
    # Sample output:
    # [
    #     {'content': 'excited', 'first_stage_score': 46.7, 'rank_score': 4.84},
    #     {'content': 'peaceful', 'first_stage_score': 2.3, 'rank_score': 2.98},
    #     {'content': 'sad', 'first_stage_score': 1.5, 'rank_score': 0.0},
    # ]

    # Optional: hybrid with first-stage scores
    alpha = 0.2
    blended = hybrid_scores(reranked, alpha)
    print(blended)
    # Sample output:
    # [
    #     {'content': 'excited', 'first_stage_score': 46.7, 'rank_score': 4.84, 'hybrid_score': 1.18},
    #     {'content': 'peaceful', 'first_stage_score': 2.3, 'rank_score': 2.98, 'hybrid_score': 0.01},
    #     {'content': 'sad', 'first_stage_score': 1.5, 'rank_score': 0.0, 'hybrid_score': -1.19},
    # ]