初始化项目,由ModelHub XC社区提供模型
Model: ByteDance/ListConRanker Source: Original Platform
This commit is contained in:
287
modules/Reranking.py
Normal file
287
modules/Reranking.py
Normal file
@@ -0,0 +1,287 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software
|
||||
# and associated documentation files (the “Software”), to deal in the Software without
|
||||
# restriction, including without limitation the rights to use, copy, modify, merge, publish,
|
||||
# distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
|
||||
# Software is furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all copies or
|
||||
# substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR
|
||||
# OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
|
||||
# ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
|
||||
# OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
import logging
|
||||
import numpy as np
|
||||
from mteb import RerankingEvaluator, AbsTaskReranking
|
||||
from tqdm import tqdm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChineseRerankingEvaluator(RerankingEvaluator):
    """
    Evaluator for re-ranking tasks.

    For each sample the candidate documents (positives followed by negatives)
    are scored against the query and ranked by decreasing score. MAP and MRR
    are then averaged over the samples.

    :param samples: Must be a list and each element is of the form:
        - {'query': '', 'positive': [], 'negative': []}. Query is the search query, positive is a list of
          positive (relevant) documents, negative is a list of negative (irrelevant) documents.
        - {'query': [], 'positive': [], 'negative': []}. Where query is a list of strings (sub-queries)
          that are all encoded for the sample.
    """

    def __call__(self, model):
        """Return the metric dictionary computed for ``model``."""
        return self.compute_metrics(model)

    def compute_metrics(self, model):
        """Choose batched or per-sample evaluation from ``self.use_batched_encoding``."""
        if self.use_batched_encoding:
            return self.compute_metrics_batched(model)
        return self.compute_metrics_individual(model)

    def compute_metrics_batched(self, model):
        """
        Compute the metrics in a batched way.

        Models exposing ``compute_score`` take the cross-encoder path; all
        other models take the bi-encoder (embedding) path.
        """
        if hasattr(model, 'compute_score'):
            return self.compute_metrics_batched_from_crossencoder(model)
        return self.compute_metrics_batched_from_biencoder(model)

    def compute_metrics_batched_from_crossencoder(self, model):
        """
        Score every sample with ``model.compute_score`` and return MAP and MRR@1/5/10.

        Each sample is turned into one list ``[query, *positives, *negatives]``.
        The lists are sent to the model four at a time, and the returned
        per-list scores are concatenated and sliced back per sample.
        """
        group_size = 4  # number of [query, docs...] lists per compute_score call

        flat_scores = []
        pending = []
        for sample in tqdm(self.samples, desc="Evaluating"):
            pending.append([sample['query'], *sample['positive'], *sample['negative']])
            if len(pending) == group_size:
                for per_sample_scores in model.compute_score(pending):
                    flat_scores.extend(per_sample_scores)
                pending = []
        if pending:
            # Flush the last, partially filled group.
            for per_sample_scores in model.compute_score(pending):
                flat_scores.extend(per_sample_scores)
        flat_scores = np.array(flat_scores)

        ap_values = []
        mrr_1_values = []
        mrr_5_values = []
        mrr_10_values = []
        offset = 0
        for sample in tqdm(self.samples, desc="Evaluating"):
            is_relevant = [True] * len(sample['positive']) + [False] * len(sample['negative'])
            end = offset + len(is_relevant)
            pred_scores = flat_scores[offset:end]
            offset = end
            ranking = np.argsort(-pred_scores)  # indices by decreasing score

            ap_values.append(self.ap_score(is_relevant, pred_scores))
            mrr_1_values.append(self.mrr_at_k_score(is_relevant, ranking, 1))
            mrr_5_values.append(self.mrr_at_k_score(is_relevant, ranking, 5))
            mrr_10_values.append(self.mrr_at_k_score(is_relevant, ranking, 10))

        return {
            "map": np.mean(ap_values),
            "mrr_1": np.mean(mrr_1_values),
            'mrr_5': np.mean(mrr_5_values),
            'mrr_10': np.mean(mrr_10_values),
        }

    def compute_metrics_batched_from_biencoder(self, model):
        """
        Encode all queries and documents in bulk, then return MAP and MRR.

        Queries go through ``model.encode_queries`` when the model has it and
        through ``model.encode`` otherwise; documents always use
        ``model.encode``. Samples without positives or without negatives are
        skipped.

        :raises ValueError: if the first sample's query is neither a string nor a list.
        """
        logger.info("Encoding queries...")
        first_query = self.samples[0]["query"]
        if isinstance(first_query, str):
            query_texts = [sample["query"] for sample in self.samples]
        elif isinstance(first_query, list):
            # Sub-queries of all samples are flattened into one list; they are
            # regrouped per sample below via ``num_subqueries``.
            query_texts = [q for sample in self.samples for q in sample["query"]]
        else:
            raise ValueError(f"Query must be a string or a list of strings but is {type(self.samples[0]['query'])}")

        if hasattr(model, 'encode_queries'):
            encode_queries = model.encode_queries
        else:
            encode_queries = model.encode
        all_query_embs = encode_queries(query_texts, convert_to_tensor=True, batch_size=self.batch_size)

        logger.info("Encoding candidates...")
        all_docs = [
            doc
            for sample in self.samples
            for doc in (*sample["positive"], *sample["negative"])
        ]
        all_docs_embs = model.encode(all_docs, convert_to_tensor=True, batch_size=self.batch_size)

        logger.info("Evaluating...")
        mrr_values = []
        ap_values = []
        query_cursor = 0
        doc_cursor = 0
        for instance in self.samples:
            query = instance["query"]
            num_subqueries = len(query) if isinstance(query, list) else 1
            query_end = query_cursor + num_subqueries
            query_emb = all_query_embs[query_cursor:query_end]
            query_cursor = query_end

            num_pos = len(instance["positive"])
            num_neg = len(instance["negative"])
            doc_end = doc_cursor + num_pos + num_neg
            docs_emb = all_docs_embs[doc_cursor:doc_end]
            doc_cursor = doc_end

            # The cursors are advanced before this check so later samples stay aligned.
            if not num_pos or not num_neg:
                continue

            is_relevant = [True] * num_pos + [False] * num_neg
            scores = self._compute_metrics_instance(query_emb, docs_emb, is_relevant)
            mrr_values.append(scores["mrr"])
            ap_values.append(scores["ap"])

        return {"map": np.mean(ap_values), "mrr": np.mean(mrr_values)}
|
||||
|
||||
|
||||
def evaluate(self, model, split="test", **kwargs):
    """Evaluate ``model`` on one split of this task's dataset.

    The data is loaded on first use. A ``ChineseRerankingEvaluator`` is built
    over ``self.dataset[split]`` (``kwargs`` are forwarded to it) and its
    scores are returned as a plain ``dict``.
    """
    if not self.data_loaded:
        self.load_data()

    evaluator = ChineseRerankingEvaluator(self.dataset[split], **kwargs)
    return dict(evaluator(model))


# Install the function above as the ``evaluate`` method of ``AbsTaskReranking``.
AbsTaskReranking.evaluate = evaluate
|
||||
|
||||
|
||||
class T2Reranking(AbsTaskReranking):
    """Reranking task over the ``C-MTEB/T2Reranking`` dataset (``dev`` split, Chinese)."""

    @property
    def description(self):
        """Return the metadata dictionary describing this task."""
        return dict(
            name='T2Reranking',
            hf_hub_name="C-MTEB/T2Reranking",
            description='T2Ranking: A large-scale Chinese Benchmark for Passage Ranking',
            reference="https://arxiv.org/abs/2304.03679",
            type='Reranking',
            category='s2p',
            eval_splits=['dev'],
            eval_langs=['zh'],
            main_score='map',
        )
|
||||
|
||||
|
||||
class T2RerankingZh2En(AbsTaskReranking):
    """Reranking task over the ``C-MTEB/T2Reranking_zh2en`` dataset (``dev`` split, ``zh2en``)."""

    @property
    def description(self):
        """Return the metadata dictionary describing this task."""
        return dict(
            name='T2RerankingZh2En',
            hf_hub_name="C-MTEB/T2Reranking_zh2en",
            description='T2Ranking: A large-scale Chinese Benchmark for Passage Ranking',
            reference="https://arxiv.org/abs/2304.03679",
            type='Reranking',
            category='s2p',
            eval_splits=['dev'],
            eval_langs=['zh2en'],
            main_score='map',
        )
|
||||
|
||||
|
||||
class T2RerankingEn2Zh(AbsTaskReranking):
    """Reranking task over the ``C-MTEB/T2Reranking_en2zh`` dataset (``dev`` split, ``en2zh``)."""

    @property
    def description(self):
        """Return the metadata dictionary describing this task."""
        return dict(
            name='T2RerankingEn2Zh',
            hf_hub_name="C-MTEB/T2Reranking_en2zh",
            description='T2Ranking: A large-scale Chinese Benchmark for Passage Ranking',
            reference="https://arxiv.org/abs/2304.03679",
            type='Reranking',
            category='s2p',
            eval_splits=['dev'],
            eval_langs=['en2zh'],
            main_score='map',
        )
|
||||
|
||||
|
||||
class MMarcoReranking(AbsTaskReranking):
    """Reranking task over the ``C-MTEB/Mmarco-reranking`` dataset (``dev`` split, Chinese)."""

    @property
    def description(self):
        """Return the metadata dictionary describing this task."""
        return dict(
            name='MMarcoReranking',
            hf_hub_name="C-MTEB/Mmarco-reranking",
            description='mMARCO is a multilingual version of the MS MARCO passage ranking dataset',
            reference="https://github.com/unicamp-dl/mMARCO",
            type='Reranking',
            category='s2p',
            eval_splits=['dev'],
            eval_langs=['zh'],
            main_score='map',
        )
|
||||
|
||||
|
||||
class CMedQAv1(AbsTaskReranking):
    """Reranking task over the ``C-MTEB/CMedQAv1-reranking`` dataset (``test`` split, Chinese)."""

    @property
    def description(self):
        """Return the metadata dictionary describing this task."""
        return dict(
            name='CMedQAv1',
            hf_hub_name="C-MTEB/CMedQAv1-reranking",
            description='Chinese community medical question answering',
            reference="https://github.com/zhangsheng93/cMedQA",
            type='Reranking',
            category='s2p',
            eval_splits=['test'],
            eval_langs=['zh'],
            main_score='map',
        )
|
||||
|
||||
|
||||
class CMedQAv2(AbsTaskReranking):
    """Reranking task over the ``C-MTEB/CMedQAv2-reranking`` dataset (``test`` split, Chinese)."""

    @property
    def description(self):
        """Return the metadata dictionary describing this task."""
        return dict(
            name='CMedQAv2',
            hf_hub_name="C-MTEB/CMedQAv2-reranking",
            description='Chinese community medical question answering',
            reference="https://github.com/zhangsheng93/cMedQA2",
            type='Reranking',
            category='s2p',
            eval_splits=['test'],
            eval_langs=['zh'],
            main_score='map',
        )
|
||||
325
modules/Reranking_loop.py
Normal file
325
modules/Reranking_loop.py
Normal file
@@ -0,0 +1,325 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software
|
||||
# and associated documentation files (the “Software”), to deal in the Software without
|
||||
# restriction, including without limitation the rights to use, copy, modify, merge, publish,
|
||||
# distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
|
||||
# Software is furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all copies or
|
||||
# substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR
|
||||
# OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
|
||||
# ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
|
||||
# OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
import logging
|
||||
import numpy as np
|
||||
from mteb import RerankingEvaluator, AbsTaskReranking
|
||||
from tqdm import tqdm
|
||||
import math
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChineseRerankingEvaluator(RerankingEvaluator):
    """
    Evaluator for re-ranking tasks that scores cross-encoders iteratively.

    For each sample the candidate documents are scored against the query and
    ranked by decreasing score. MAP and MRR are then averaged over the
    samples. Models exposing ``compute_score`` are evaluated with iterative
    filtering (see ``_iterative_scores``).

    :param samples: Must be a list and each element is of the form:
        - {'query': '', 'positive': [], 'negative': []}. Query is the search query, positive is a list of
          positive (relevant) documents, negative is a list of negative (irrelevant) documents.
        - {'query': [], 'positive': [], 'negative': []}. Where query is a list of strings (sub-queries)
          that are all encoded for the sample.
    """

    def __call__(self, model):
        """Return the metric dictionary computed for ``model``."""
        return self.compute_metrics(model)

    def compute_metrics(self, model):
        """Choose batched or per-sample evaluation from ``self.use_batched_encoding``."""
        if self.use_batched_encoding:
            return self.compute_metrics_batched(model)
        return self.compute_metrics_individual(model)

    def compute_metrics_batched(self, model):
        """
        Compute the metrics in a batched way.

        Models exposing ``compute_score`` take the cross-encoder path; all
        other models take the bi-encoder (embedding) path.
        """
        if hasattr(model, 'compute_score'):
            return self.compute_metrics_batched_from_crossencoder(model)
        return self.compute_metrics_batched_from_biencoder(model)

    @staticmethod
    def _iterative_scores(model, query, passage, max_list_size=20, filter_ratio=0.2, min_filter=10):
        """
        Score ``passage`` against ``query`` by iterative filtering.

        While more than ``max_list_size`` passages remain, the whole list is
        scored and the lowest-scoring ``max(ceil(filter_ratio * n), min_filter)``
        passages are frozen with their current score. Passages tied with the
        last frozen one are frozen too. The survivors are re-scored in the
        next round. Each recorded score has the round index added to it, so a
        passage that survives more rounds gets a larger offset (presumably the
        raw scores lie in [0, 1) -- TODO confirm against the model).

        :return: dict mapping passage text to the list of scores recorded for
            it (more than one entry when the same text occurs several times).
        """
        passage2score = {}
        filter_times = 0
        while len(passage) > max_list_size:
            pred_scores = model.compute_score([[query] + passage])[0]
            # Indices sorted by increasing score: the weakest passages come first.
            order = np.argsort(pred_scores).tolist()
            passage_len = len(passage)
            cut = max(math.ceil(passage_len * filter_ratio), min_filter)
            # Extend the cut over ties. The bound check prevents the IndexError
            # the unbounded loop raised when the tie run reached the list end.
            while cut < passage_len and pred_scores[order[cut - 1]] == pred_scores[order[cut]]:
                cut += 1
            for idx in order[:cut]:
                passage2score.setdefault(passage[idx], []).append(pred_scores[idx] + filter_times)
            passage = [passage[idx] for idx in order[cut:]]
            filter_times += 1

        # Ties can filter out every passage; do not call the model with an empty list.
        if passage:
            pred_scores = model.compute_score([[query] + passage])[0]
            for text, score in zip(passage, pred_scores):
                passage2score.setdefault(text, []).append(score + filter_times)
        return passage2score

    def compute_metrics_batched_from_crossencoder(self, model):
        """
        Score every sample iteratively with ``model.compute_score`` and return MAP and MRR@1/5/10.

        Labels and scores are keyed by passage text. A text that appears in
        both ``positive`` and ``negative`` is labelled negative, because the
        negatives are written last.
        """
        all_ap_scores = []
        all_mrr_1_scores = []
        all_mrr_5_scores = []
        all_mrr_10_scores = []

        for sample in tqdm(self.samples, desc="Evaluating"):
            query = sample['query']
            pos = sample['positive']
            neg = sample['negative']

            passage2label = {}
            for p in pos:
                passage2label[p] = True
            for p in neg:
                passage2label[p] = False

            passage2score = self._iterative_scores(model, query, pos + neg)

            is_relevant = []
            final_score = []
            # dict.fromkeys de-duplicates in first-seen order. A set would make
            # the order, and so the tie-breaking in argsort below, depend on
            # the string hash seed.
            for p in dict.fromkeys(pos + neg):
                scores = passage2score[p]
                is_relevant += [passage2label[p]] * len(scores)
                final_score += scores

            ap = self.ap_score(is_relevant, final_score)

            pred_scores_argsort = np.argsort(-(np.array(final_score)))
            mrr_1 = self.mrr_at_k_score(is_relevant, pred_scores_argsort, 1)
            mrr_5 = self.mrr_at_k_score(is_relevant, pred_scores_argsort, 5)
            mrr_10 = self.mrr_at_k_score(is_relevant, pred_scores_argsort, 10)

            all_ap_scores.append(ap)
            all_mrr_1_scores.append(mrr_1)
            all_mrr_5_scores.append(mrr_5)
            all_mrr_10_scores.append(mrr_10)

        mean_ap = np.mean(all_ap_scores)
        mean_mrr_1 = np.mean(all_mrr_1_scores)
        mean_mrr_5 = np.mean(all_mrr_5_scores)
        mean_mrr_10 = np.mean(all_mrr_10_scores)

        return {"map": mean_ap, "mrr_1": mean_mrr_1, 'mrr_5': mean_mrr_5, 'mrr_10': mean_mrr_10}

    def compute_metrics_batched_from_biencoder(self, model):
        """
        Encode all queries and documents in bulk, then return MAP and MRR.

        Queries go through ``model.encode_queries`` when the model has it and
        through ``model.encode`` otherwise; documents always use
        ``model.encode``. Samples without positives or without negatives are
        skipped.

        :raises ValueError: if the first sample's query is neither a string nor a list.
        """
        all_mrr_scores = []
        all_ap_scores = []
        logger.info("Encoding queries...")
        if isinstance(self.samples[0]["query"], str):
            all_queries = [sample["query"] for sample in self.samples]
        elif isinstance(self.samples[0]["query"], list):
            # Sub-queries of all samples are flattened into one list; they are
            # regrouped per sample below via ``num_subqueries``.
            all_queries = [q for sample in self.samples for q in sample["query"]]
        else:
            raise ValueError(f"Query must be a string or a list of strings but is {type(self.samples[0]['query'])}")

        if hasattr(model, 'encode_queries'):
            all_query_embs = model.encode_queries(all_queries, convert_to_tensor=True, batch_size=self.batch_size)
        else:
            all_query_embs = model.encode(all_queries, convert_to_tensor=True, batch_size=self.batch_size)

        logger.info("Encoding candidates...")
        all_docs = []
        for sample in self.samples:
            all_docs.extend(sample["positive"])
            all_docs.extend(sample["negative"])

        all_docs_embs = model.encode(all_docs, convert_to_tensor=True, batch_size=self.batch_size)

        # Compute scores
        logger.info("Evaluating...")
        query_idx, docs_idx = 0, 0
        for instance in self.samples:
            num_subqueries = len(instance["query"]) if isinstance(instance["query"], list) else 1
            query_emb = all_query_embs[query_idx: query_idx + num_subqueries]
            query_idx += num_subqueries

            num_pos = len(instance["positive"])
            num_neg = len(instance["negative"])
            docs_emb = all_docs_embs[docs_idx: docs_idx + num_pos + num_neg]
            docs_idx += num_pos + num_neg

            # The indices are advanced before this check so later samples stay aligned.
            if num_pos == 0 or num_neg == 0:
                continue

            is_relevant = [True] * num_pos + [False] * num_neg

            scores = self._compute_metrics_instance(query_emb, docs_emb, is_relevant)
            all_mrr_scores.append(scores["mrr"])
            all_ap_scores.append(scores["ap"])

        mean_ap = np.mean(all_ap_scores)
        mean_mrr = np.mean(all_mrr_scores)

        return {"map": mean_ap, "mrr": mean_mrr}
|
||||
|
||||
|
||||
def evaluate(self, model, split="test", **kwargs):
    """Evaluate ``model`` on one split of this task's dataset.

    The data is loaded on first use. A ``ChineseRerankingEvaluator`` is built
    over ``self.dataset[split]`` (``kwargs`` are forwarded to it) and its
    scores are returned as a plain ``dict``.
    """
    if not self.data_loaded:
        self.load_data()

    evaluator = ChineseRerankingEvaluator(self.dataset[split], **kwargs)
    return dict(evaluator(model))


# Install the function above as the ``evaluate`` method of ``AbsTaskReranking``.
AbsTaskReranking.evaluate = evaluate
|
||||
|
||||
|
||||
class T2Reranking(AbsTaskReranking):
    """Reranking task over the ``C-MTEB/T2Reranking`` dataset (``dev`` split, Chinese)."""

    @property
    def description(self):
        """Return the metadata dictionary describing this task."""
        return dict(
            name='T2Reranking',
            hf_hub_name="C-MTEB/T2Reranking",
            description='T2Ranking: A large-scale Chinese Benchmark for Passage Ranking',
            reference="https://arxiv.org/abs/2304.03679",
            type='Reranking',
            category='s2p',
            eval_splits=['dev'],
            eval_langs=['zh'],
            main_score='map',
        )
|
||||
|
||||
|
||||
class T2RerankingZh2En(AbsTaskReranking):
    """Reranking task over the ``C-MTEB/T2Reranking_zh2en`` dataset (``dev`` split, ``zh2en``)."""

    @property
    def description(self):
        """Return the metadata dictionary describing this task."""
        return dict(
            name='T2RerankingZh2En',
            hf_hub_name="C-MTEB/T2Reranking_zh2en",
            description='T2Ranking: A large-scale Chinese Benchmark for Passage Ranking',
            reference="https://arxiv.org/abs/2304.03679",
            type='Reranking',
            category='s2p',
            eval_splits=['dev'],
            eval_langs=['zh2en'],
            main_score='map',
        )
|
||||
|
||||
|
||||
class T2RerankingEn2Zh(AbsTaskReranking):
    """Reranking task over the ``C-MTEB/T2Reranking_en2zh`` dataset (``dev`` split, ``en2zh``)."""

    @property
    def description(self):
        """Return the metadata dictionary describing this task."""
        return dict(
            name='T2RerankingEn2Zh',
            hf_hub_name="C-MTEB/T2Reranking_en2zh",
            description='T2Ranking: A large-scale Chinese Benchmark for Passage Ranking',
            reference="https://arxiv.org/abs/2304.03679",
            type='Reranking',
            category='s2p',
            eval_splits=['dev'],
            eval_langs=['en2zh'],
            main_score='map',
        )
|
||||
|
||||
|
||||
class MMarcoReranking(AbsTaskReranking):
    """Reranking task over the ``C-MTEB/Mmarco-reranking`` dataset (``dev`` split, Chinese)."""

    @property
    def description(self):
        """Return the metadata dictionary describing this task."""
        return dict(
            name='MMarcoReranking',
            hf_hub_name="C-MTEB/Mmarco-reranking",
            description='mMARCO is a multilingual version of the MS MARCO passage ranking dataset',
            reference="https://github.com/unicamp-dl/mMARCO",
            type='Reranking',
            category='s2p',
            eval_splits=['dev'],
            eval_langs=['zh'],
            main_score='map',
        )
|
||||
|
||||
|
||||
class CMedQAv1(AbsTaskReranking):
    """Reranking task over the ``C-MTEB/CMedQAv1-reranking`` dataset (``test`` split, Chinese)."""

    @property
    def description(self):
        """Return the metadata dictionary describing this task."""
        return dict(
            name='CMedQAv1',
            hf_hub_name="C-MTEB/CMedQAv1-reranking",
            description='Chinese community medical question answering',
            reference="https://github.com/zhangsheng93/cMedQA",
            type='Reranking',
            category='s2p',
            eval_splits=['test'],
            eval_langs=['zh'],
            main_score='map',
        )
|
||||
|
||||
|
||||
class CMedQAv2(AbsTaskReranking):
    """Reranking task over the ``C-MTEB/CMedQAv2-reranking`` dataset (``test`` split, Chinese)."""

    @property
    def description(self):
        """Return the metadata dictionary describing this task."""
        return dict(
            name='CMedQAv2',
            hf_hub_name="C-MTEB/CMedQAv2-reranking",
            description='Chinese community medical question answering',
            reference="https://github.com/zhangsheng93/cMedQA2",
            type='Reranking',
            category='s2p',
            eval_splits=['test'],
            eval_langs=['zh'],
            main_score='map',
        )
|
||||
161
modules/listconranker.py
Normal file
161
modules/listconranker.py
Normal file
@@ -0,0 +1,161 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software
|
||||
# and associated documentation files (the “Software”), to deal in the Software without
|
||||
# restriction, including without limitation the rights to use, copy, modify, merge, publish,
|
||||
# distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
|
||||
# Software is furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all copies or
|
||||
# substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR
|
||||
# OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
|
||||
# ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
|
||||
# OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
import math
|
||||
import torch
|
||||
import numpy as np
|
||||
from transformers import AutoTokenizer, is_torch_npu_available
|
||||
from typing import Union, List
|
||||
from .modeling import CrossEncoder
|
||||
|
||||
import os
|
||||
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
|
||||
|
||||
|
||||
def sigmoid(x):
    """Return the logistic sigmoid ``1 / (1 + exp(-x))`` of ``x``.

    The value is computed as ``exp(-log(1 + exp(-x)))`` through
    ``np.logaddexp``. Large negative inputs therefore no longer overflow
    ``exp`` (which produced a RuntimeWarning in the direct formula).

    :param x: real scalar or array-like.
    :return: NumPy scalar or array of the same shape with values in [0, 1].
    """
    return np.exp(-np.logaddexp(0.0, np.negative(x)))
|
||||
|
||||
|
||||
class ListConRanker:
    """Inference wrapper around the list-wise ``CrossEncoder`` reranker.

    Loads the tokenizer and model, moves the model to a device, and offers
    two entry points: ``compute_score`` scores lists of
    ``[query, passage, ...]`` in one forward pass, and ``iterative_inference``
    scores one long list by repeatedly filtering out the weakest passages.
    """

    def __init__(
        self,
        model_name_or_path: str = None,
        use_fp16: bool = False,
        cache_dir: str = None,
        device: Union[str, int] = None,
        list_transformer_layer = None
    ) -> None:
        """Load tokenizer and model and place the model on a device.

        :param model_name_or_path: passed to ``AutoTokenizer.from_pretrained``
            and ``CrossEncoder.from_pretrained_for_eval``.
        :param use_fp16: convert the model to half precision; forced off when
            the resolved device is the CPU.
        :param cache_dir: cache directory for the tokenizer download.
        :param device: device string (e.g. ``'cpu'``, ``'cuda:1'``), a CUDA
            index, or ``None`` to pick CUDA, MPS, NPU or CPU in that order.
        :param list_transformer_layer: forwarded unchanged to
            ``CrossEncoder.from_pretrained_for_eval``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, cache_dir=cache_dir)
        self.model = CrossEncoder.from_pretrained_for_eval(model_name_or_path, list_transformer_layer)

        if device and isinstance(device, str):
            self.device = torch.device(device)
            if device == 'cpu':
                use_fp16 = False
        else:
            # ``device`` is None or an integer CUDA index here.
            if torch.cuda.is_available():
                if device is not None:
                    self.device = torch.device(f"cuda:{device}")
                else:
                    self.device = torch.device("cuda")
            elif torch.backends.mps.is_available():
                self.device = torch.device("mps")
            elif is_torch_npu_available():
                self.device = torch.device("npu")
            else:
                self.device = torch.device("cpu")
                use_fp16 = False
        if use_fp16:
            self.model.half()

        self.model = self.model.to(self.device)

        self.model.eval()

        if device is None:
            # Without an explicit device, spread over all visible GPUs.
            self.num_gpus = torch.cuda.device_count()
            if self.num_gpus > 1:
                print(f"----------using {self.num_gpus}*GPUs----------")
                self.model = torch.nn.DataParallel(self.model)
        else:
            self.num_gpus = 1

    @torch.no_grad()
    def compute_score(self, sentence_pairs: List[List[str]], max_length: int = 512) -> List[List[float]]:
        """Score each passage of each list against that list's query.

        :param sentence_pairs: list of lists, each ``[query, passage_1, ...]``.
        :param max_length: tokenizer truncation length.
        :return: one list of scores per input list, in passage order.
        """
        # Number of passages per list (the first element is the query).
        pair_nums = [len(pairs) - 1 for pairs in sentence_pairs]
        sentences_batch = [sentence for pairs in sentence_pairs for sentence in pairs]
        inputs = self.tokenizer(
            sentences_batch,
            padding=True,
            truncation=True,
            return_tensors='pt',
            max_length=max_length,
        ).to(self.device)
        inputs['pair_num'] = torch.LongTensor(pair_nums)
        scores = self.model(inputs).float()
        all_scores = scores.cpu().numpy().tolist()

        if isinstance(all_scores, float):
            # A 0-dim model output means one score for one passage. Keep the
            # list-of-lists shape so callers can index the result uniformly.
            return [[all_scores]]

        result = []
        curr_idx = 0
        for pair_num in pair_nums:
            result.append(all_scores[curr_idx: curr_idx + pair_num])
            curr_idx += pair_num
        return result

    @torch.no_grad()
    def iterative_inference(self, sentence_pairs: List[str], max_length: int = 512) -> List[float]:
        """Score one ``[query, passage_1, ...]`` list by iterative filtering.

        While more than 20 passages remain, the list is scored and the
        lowest-scoring ``max(ceil(0.2 * n), 10)`` passages are frozen with
        their current score. Passages tied with the last frozen one are frozen
        too. Survivors are re-scored in the next round. The round index is
        added to every recorded score, so passages that survive longer get a
        larger offset (presumably raw scores lie in [0, 1) -- TODO confirm
        against the model).

        Passages are tracked by position, so the result always holds exactly
        one score per input passage, even when the same text appears twice.

        :param sentence_pairs: ``[query, passage_1, passage_2, ...]``.
        :param max_length: tokenizer truncation length.
        :return: scores in the order the passages were given.
        """
        query = sentence_pairs[0]
        passages = sentence_pairs[1:]

        final_score = [None] * len(passages)
        remaining = list(range(len(passages)))  # positions of passages still in play
        filter_times = 0
        while len(remaining) > 20:
            batch = [[query] + [passages[i] for i in remaining]]
            pred_scores = self.compute_score(batch, max_length)[0]
            # Indices sorted by increasing score: the weakest passages come first.
            order = np.argsort(pred_scores).tolist()
            cut = max(math.ceil(len(remaining) * 0.2), 10)
            # Extend the cut over ties. The bound check prevents the IndexError
            # raised when the tie run reaches the end of the list.
            while cut < len(order) and pred_scores[order[cut - 1]] == pred_scores[order[cut]]:
                cut += 1
            for pos in order[:cut]:
                final_score[remaining[pos]] = pred_scores[pos] + filter_times
            remaining = [remaining[pos] for pos in order[cut:]]
            filter_times += 1

        # Ties can filter out every passage; do not call the model with an empty list.
        if remaining:
            batch = [[query] + [passages[i] for i in remaining]]
            pred_scores = self.compute_score(batch, max_length)[0]
            for idx, score in zip(remaining, pred_scores):
                final_score[idx] = score + filter_times

        return final_score
|
||||
174
modules/modeling.py
Normal file
174
modules/modeling.py
Normal file
@@ -0,0 +1,174 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software
|
||||
# and associated documentation files (the “Software”), to deal in the Software without
|
||||
# restriction, including without limitation the rights to use, copy, modify, merge, publish,
|
||||
# distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
|
||||
# Software is furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all copies or
|
||||
# substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR
|
||||
# OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
|
||||
# ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
|
||||
# OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
import logging
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import AutoModel, PreTrainedModel
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ListTransformer(nn.Module):
|
||||
def __init__(self, num_layer, config, device) -> None:
    """Build the list-level transformer encoder and the scoring layers.

    :param num_layer: number of stacked encoder layers.
    :param config: configuration object; ``num_attention_heads`` is read
        from it and the object is also handed to ``QueryEmbedding``.
    :param device: stored on the module and handed to ``QueryEmbedding``.
    """
    super().__init__()
    hidden_size = 1792  # feature width used by every layer built here

    self.config = config
    self.device = device
    self.list_transformer_layer = nn.TransformerEncoderLayer(
        d_model=hidden_size,
        nhead=config.num_attention_heads,
        batch_first=True,
        activation=F.gelu,
        norm_first=False,
    )
    self.list_transformer = nn.TransformerEncoder(self.list_transformer_layer, num_layer)
    self.relu = nn.ReLU()
    self.query_embedding = QueryEmbedding(config, device)

    # Each scoring layer takes two concatenated feature vectors as input.
    concat_size = hidden_size * 2
    self.linear_score3 = nn.Linear(concat_size, hidden_size)
    self.linear_score2 = nn.Linear(concat_size, hidden_size)
    self.linear_score1 = nn.Linear(concat_size, 1)
|
||||
|
||||
def forward(self, pair_features, pair_nums):
|
||||
pair_nums = [x + 1 for x in pair_nums]
|
||||
batch_pair_features = pair_features.split(pair_nums)
|
||||
|
||||
pair_feature_query_passage_concat_list = []
|
||||
for i in range(len(batch_pair_features)):
|
||||
pair_feature_query = batch_pair_features[i][0].unsqueeze(0).repeat(pair_nums[i] - 1, 1)
|
||||
pair_feature_passage = batch_pair_features[i][1:]
|
||||
pair_feature_query_passage_concat_list.append(torch.cat([pair_feature_query, pair_feature_passage], dim=1))
|
||||
pair_feature_query_passage_concat = torch.cat(pair_feature_query_passage_concat_list, dim=0)
|
||||
|
||||
batch_pair_features = nn.utils.rnn.pad_sequence(batch_pair_features, batch_first=True)
|
||||
|
||||
query_embedding_tags = torch.zeros(batch_pair_features.size(0), batch_pair_features.size(1), dtype=torch.long, device=self.device)
|
||||
query_embedding_tags[:, 0] = 1
|
||||
batch_pair_features = self.query_embedding(batch_pair_features, query_embedding_tags)
|
||||
|
||||
mask = self.generate_attention_mask(pair_nums)
|
||||
query_mask = self.generate_attention_mask_custom(pair_nums)
|
||||
pair_list_features = self.list_transformer(batch_pair_features, src_key_padding_mask=mask, mask=query_mask)
|
||||
|
||||
output_pair_list_features = []
|
||||
output_query_list_features = []
|
||||
pair_features_after_transformer_list = []
|
||||
for idx, pair_num in enumerate(pair_nums):
|
||||
output_pair_list_features.append(pair_list_features[idx, 1:pair_num, :])
|
||||
output_query_list_features.append(pair_list_features[idx, 0, :])
|
||||
pair_features_after_transformer_list.append(pair_list_features[idx, :pair_num, :])
|
||||
|
||||
pair_features_after_transformer_cat_query_list = []
|
||||
for idx, pair_num in enumerate(pair_nums):
|
||||
query_ft = output_query_list_features[idx].unsqueeze(0).repeat(pair_num - 1, 1)
|
||||
pair_features_after_transformer_cat_query = torch.cat([query_ft, output_pair_list_features[idx]], dim=1)
|
||||
pair_features_after_transformer_cat_query_list.append(pair_features_after_transformer_cat_query)
|
||||
pair_features_after_transformer_cat_query = torch.cat(pair_features_after_transformer_cat_query_list, dim=0)
|
||||
|
||||
pair_feature_query_passage_concat = self.relu(self.linear_score2(pair_feature_query_passage_concat))
|
||||
pair_features_after_transformer_cat_query = self.relu(self.linear_score3(pair_features_after_transformer_cat_query))
|
||||
final_ft = torch.cat([pair_feature_query_passage_concat, pair_features_after_transformer_cat_query], dim=1)
|
||||
logits = self.linear_score1(final_ft).squeeze()
|
||||
|
||||
return logits, torch.cat(pair_features_after_transformer_list, dim=0)
|
||||
|
||||
def generate_attention_mask(self, pair_num):
|
||||
max_len = max(pair_num)
|
||||
batch_size = len(pair_num)
|
||||
mask = torch.zeros(batch_size, max_len, dtype=torch.bool, device=self.device)
|
||||
for i, length in enumerate(pair_num):
|
||||
mask[i, length:] = True
|
||||
return mask
|
||||
|
||||
def generate_attention_mask_custom(self, pair_num):
|
||||
max_len = max(pair_num)
|
||||
|
||||
mask = torch.zeros(max_len, max_len, dtype=torch.bool, device=self.device)
|
||||
mask[0, 1:] = True
|
||||
|
||||
return mask
|
||||
|
||||
|
||||
class QueryEmbedding(nn.Module):
    """Add a learned two-entry tag embedding to the features, then apply LayerNorm."""

    def __init__(self, config, device) -> None:
        """Create the embedding table and the normalisation layer.

        Args:
            config: Unused here; kept so existing call sites keep working.
            device: Unused here; kept so existing call sites keep working.
        """
        super().__init__()
        self.query_embedding = nn.Embedding(2, 1792)
        self.layerNorm = nn.LayerNorm(1792)

    def forward(self, x, tags):
        """Return ``LayerNorm(x + embedding(tags))``.

        Args:
            x: Feature tensor whose last dimension is 1792.
            tags: Long tensor of indices in ``{0, 1}`` matching the leading
                dimensions of ``x``.

        Returns:
            A new tensor; ``x`` itself is left untouched.
        """
        # Out-of-place add: the previous in-place ``x += ...`` silently
        # modified the caller's tensor and is rejected by autograd when ``x``
        # is a leaf that requires grad.
        return self.layerNorm(x + self.query_embedding(tags))
|
||||
|
||||
|
||||
class CrossEncoder(nn.Module):
    """Encoder backbone + mean pooling + ``ListTransformer`` scoring head.

    Only the evaluation path is implemented in ``forward``.
    """

    def __init__(self, hf_model: "PreTrainedModel", list_transformer_layer_4eval: int = None):
        """Wrap ``hf_model`` and build the projection and list-wise head.

        Args:
            hf_model: Backbone whose ``config`` is reused for the list
                transformer. Its config is switched to
                ``output_hidden_states = True``.
            list_transformer_layer_4eval: Number of layers of the list
                transformer.
        """
        super().__init__()
        self.hf_model = hf_model
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.sigmoid = nn.Sigmoid()

        self.config = self.hf_model.config
        self.config.output_hidden_states = True

        self.linear_in_embedding = nn.Linear(1024, 1792)
        self.list_transformer_layer = list_transformer_layer_4eval
        self.list_transformer = ListTransformer(self.list_transformer_layer, self.config, self.device)

    def forward(self, batch):
        """Score every passage in ``batch`` (evaluation mode only).

        Args:
            batch: Mapping with ``'input_ids'``, ``'attention_mask'``,
                ``'pair_num'`` and any further backbone inputs. ``'pair_num'``
                is removed from the mapping (as before) because the remaining
                entries are forwarded to the backbone as keyword arguments.

        Returns:
            Sigmoid of the list transformer's logits.

        Raises:
            KeyError: If ``'pair_num'`` is missing.
            NotImplementedError: If the module is in training mode.
        """
        if 'pair_num' not in batch:
            raise KeyError("batch must contain 'pair_num' (passage count per query)")
        pair_nums = batch.pop('pair_num').tolist()

        if self.training:
            # The training branch was an empty stub that fell through to an
            # UnboundLocalError; fail with an explicit message instead.
            raise NotImplementedError(
                "CrossEncoder.forward only implements the evaluation path; call .eval() first"
            )

        split_batch = 400
        attention_mask = batch['attention_mask']
        if sum(pair_nums) > split_batch:
            last_hidden_state = self._encode_in_chunks(batch, split_batch)
        else:
            last_hidden_state = self.hf_model(**batch, return_dict=True).last_hidden_state

        pair_features = self.average_pooling(last_hidden_state, attention_mask)
        pair_features = self.linear_in_embedding(pair_features)

        logits, _ = self.list_transformer(pair_features, pair_nums)
        return self.sigmoid(logits)

    def _encode_in_chunks(self, batch, chunk_size):
        """Run the backbone over ``batch`` in row chunks and concatenate the outputs.

        Every tensor whose first dimension equals the number of input rows is
        sliced, so the backbone receives the same set of inputs as in the
        single-pass path, and the same output field (``last_hidden_state``)
        is read.
        NOTE(review): the previous chunked code read ``hidden_states[-1]``;
        this is presumably the same tensor for the backbone used here —
        confirm if a different architecture is plugged in.
        """
        total_rows = batch['input_ids'].size(0)
        outputs = []
        for start in range(0, total_rows, chunk_size):
            stop = start + chunk_size
            chunk = {
                key: (
                    value[start:stop]
                    if torch.is_tensor(value) and value.dim() > 0 and value.size(0) == total_rows
                    else value
                )
                for key, value in batch.items()
            }
            outputs.append(self.hf_model(**chunk, return_dict=True).last_hidden_state)
        return torch.cat(outputs, dim=0)

    @classmethod
    def from_pretrained_for_eval(cls, model_name_or_path, list_transformer_layer):
        """Load the backbone plus the two extra weight files stored next to it.

        Args:
            model_name_or_path: Directory passed to
                ``AutoModel.from_pretrained`` that also contains
                ``linear_in_embedding.pt`` and ``list_transformer.pt``.
            list_transformer_layer: Number of list-transformer layers.

        Returns:
            A ``CrossEncoder`` with all weights loaded.
        """
        hf_model = AutoModel.from_pretrained(model_name_or_path)
        reranker = cls(hf_model, list_transformer_layer)
        # map_location='cpu' lets checkpoints saved from a GPU load on hosts
        # without CUDA; load_state_dict copies the values into the existing
        # parameters, so the module's own placement is unaffected.
        # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
        linear_state = torch.load(f"{model_name_or_path}/linear_in_embedding.pt", map_location="cpu")
        reranker.linear_in_embedding.load_state_dict(linear_state)
        list_state = torch.load(f"{model_name_or_path}/list_transformer.pt", map_location="cpu")
        reranker.list_transformer.load_state_dict(list_state)
        return reranker

    def average_pooling(self, hidden_state, attention_mask):
        """Mean of ``hidden_state`` over dim 1, counting only positions where the mask is non-zero.

        Args:
            hidden_state: 3-D tensor pooled over dim 1.
            attention_mask: 2-D tensor matching the first two dims of
                ``hidden_state``; used as a multiplicative weight.

        Returns:
            2-D tensor of pooled features. A row whose mask is entirely zero
            yields zeros instead of NaN.
        """
        weights = attention_mask.unsqueeze(-1).expand(hidden_state.size()).to(dtype=hidden_state.dtype)
        summed = torch.sum(hidden_state * weights, dim=1)
        # Clamp the denominator so an all-padding row does not divide by zero.
        counts = weights.sum(dim=1).clamp(min=1)
        return summed / counts
|
||||
Reference in New Issue
Block a user