193 lines
6.4 KiB
Python
193 lines
6.4 KiB
Python
from torch.utils.data import DataLoader
|
|
from sentence_transformers import LoggingHandler
|
|
from sentence_transformers.cross_encoder import CrossEncoder
|
|
from sentence_transformers.cross_encoder.evaluation import CEBinaryClassificationEvaluator
|
|
from sentence_transformers import InputExample
|
|
import logging
|
|
from datetime import datetime
|
|
import gzip
|
|
import sys
|
|
import numpy as np
|
|
import os
|
|
from shutil import copyfile
|
|
import csv
|
|
import tqdm
|
|
|
|
#### Print debug information to stdout
# All log records go through sentence-transformers' LoggingHandler,
# formatted as "<timestamp> - <message>", at INFO level and above.
log_handlers = [LoggingHandler()]
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    handlers=log_handlers,
)
#### /print debug information to stdout
|
|
|
|
|
|
# Define our Cross-Encoder.
# The model name/path is the first command-line argument,
# e.g. 'google/electra-small-discriminator'.
if len(sys.argv) < 2:
    # Fail with a usage hint instead of a bare IndexError on sys.argv[1].
    sys.exit("Usage: python {} <model_name>  (e.g. google/electra-small-discriminator)".format(sys.argv[0]))
model_name = sys.argv[1]

train_batch_size = 32
num_epochs = 1

# One output directory per run: model name ('/' replaced so it is a single
# path component) plus the start timestamp.
model_save_path = 'output/training_ms-marco_cross-encoder-' + model_name.replace("/", "-") + '-' + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

# We set num_labels=1, which predicts a continuous score between 0 and 1
model = CrossEncoder(model_name, num_labels=1, max_length=512)
|
|
|
|
|
|
# Keep a copy of this script in the output directory and append the exact
# command line it was invoked with, so the run can be reproduced.
os.makedirs(model_save_path, exist_ok=True)
train_script_path = os.path.join(model_save_path, 'train_script.py')
copyfile(__file__, train_script_path)

invocation = " ".join(sys.argv)
with open(train_script_path, 'a') as script_file:
    script_file.write("\n\n# Script was called via:\n#python " + invocation)
|
|
|
|
|
|
# Load the passage collection (pid -> passage text) and the training
# queries (qid -> query text). Each line is "<id>\t<text>".
corpus = {}
queries = {}

# Decode explicitly as UTF-8 rather than the platform/locale-dependent
# default text encoding.
with gzip.open('../data/collection.tsv.gz', 'rt', encoding='utf8') as fIn:
    for line in fIn:
        # maxsplit=1: only the first tab separates id from text.
        pid, passage = line.strip().split("\t", 1)
        corpus[pid] = passage

with open('../data/queries.train.tsv', 'r', encoding='utf8') as fIn:
    for line in fIn:
        qid, query = line.strip().split("\t", 1)
        queries[qid] = query
|
|
|
|
|
|
|
|
# Training data: 1 positive per (4+1) training samples, the rest negatives.
# `cnt` counts the training samples seen so far.
pos_neg_ration = 4 + 1
cnt = 0
train_samples = []

# Dev data: at most `num_dev_queries` queries, each keeping at most
# `num_max_dev_negatives` negative passages.
dev_samples = {}
num_dev_queries, num_max_dev_negatives = 125, 200
|
|
|
# Dev set: take the first `num_dev_queries` distinct query ids that appear in
# the train-eval triples. For each of them collect every positive passage and
# up to `num_max_dev_negatives` distinct negative passages.
with gzip.open('../data/qidpidtriples.rnd-shuf.train-eval.tsv.gz', 'rt') as fIn:
    for line in fIn:
        qid, pos_id, neg_id = line.strip().split()

        sample = dev_samples.get(qid)
        if sample is None:
            if len(dev_samples) >= num_dev_queries:
                continue  # dev set is full; ignore queries not already in it
            sample = {'query': queries[qid], 'positive': set(), 'negative': set()}
            dev_samples[qid] = sample

        sample['positive'].add(corpus[pos_id])
        if len(sample['negative']) < num_max_dev_negatives:
            sample['negative'].add(corpus[neg_id])
|
|
|
|
# Build the training set from the shuffled training triples.
# Queries chosen for the dev set are skipped. Otherwise the evaluator would
# measure (and select the best checkpoint on) queries the model was trained on.
dev_qids = set(dev_samples)
max_train_samples = int(2e7)

with gzip.open('../data/qidpidtriples.rnd-shuf.train.tsv.gz', 'rt') as fIn:
    for line in tqdm.tqdm(fIn, unit_scale=True):
        qid, pos_id, neg_id = line.strip().split()
        if qid in dev_qids:  # Skip queries from our dev dataset
            continue

        cnt += 1
        query = queries[qid]
        # Every pos_neg_ration-th sample uses the positive passage (label 1);
        # all others use the negative passage (label 0).
        if cnt % pos_neg_ration == 0:
            passage, label = corpus[pos_id], 1
        else:
            passage, label = corpus[neg_id], 0

        train_samples.append(InputExample(texts=[query, passage], label=label))

        if len(train_samples) >= max_train_samples:
            break

train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size)
|
|
|
|
# We add an evaluator, which evaluates the performance during training
|
|
|
|
class CERerankingEvaluator:
    """Evaluate a cross-encoder by re-ranking candidate passages and computing MRR@k.

    Each sample is a dict with the keys 'query' (the query text) and
    'positive' / 'negative' (collections of passage texts). For every sample
    the model scores all (query, passage) pairs and the passages are sorted by
    descending score. The reciprocal rank of the first relevant passage within
    the top ``mrr_at_k`` results is recorded, or 0 if none is in the top k.

    :param samples: list of sample dicts, or a dict whose values are sample dicts
    :param mrr_at_k: cut-off k for MRR@k
    :param name: optional dataset name used in log messages and in the CSV file name
    """

    def __init__(self, samples, mrr_at_k: int = 10, name: str = ''):
        self.samples = samples
        self.name = name
        self.mrr_at_k = mrr_at_k

        # Accept a {qid: sample} mapping as well as a plain list of samples.
        if isinstance(self.samples, dict):
            self.samples = list(self.samples.values())

        self.csv_file = "CERerankingEvaluator" + ("_" + name if name else '') + "_results.csv"
        self.csv_headers = ["epoch", "steps", "MRR@{}".format(mrr_at_k)]

    def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
        """Score ``model`` on the samples and return the mean MRR@k.

        Samples without any positive or without any negative passage are
        skipped. If every sample is skipped, 0.0 is returned. When
        ``output_path`` is given, a row ``[epoch, steps, mean_mrr]`` is
        appended to ``self.csv_file`` in that directory. The header row is
        written first when the file does not exist yet.

        :param model: object with a ``predict(pairs, convert_to_numpy=..., show_progress_bar=...)``
            method returning one score per (query, passage) pair
        :param output_path: directory for the CSV results file, or None to skip writing
        :param epoch: current epoch (-1 if unknown), used for logging/CSV only
        :param steps: current step within the epoch (-1 if unknown), used for logging/CSV only
        :return: mean MRR@k over the evaluated samples
        """
        if epoch != -1:
            if steps == -1:
                out_txt = " after epoch {}:".format(epoch)
            else:
                out_txt = " in epoch {} after {} steps:".format(epoch, steps)
        else:
            out_txt = ":"

        logging.info("CERerankingEvaluator: Evaluating the model on " + self.name + " dataset" + out_txt)

        all_mrr_scores = []
        num_positives = []
        num_negatives = []
        for instance in self.samples:
            query = instance['query']
            positive = list(instance['positive'])
            negative = list(instance['negative'])

            # Without a positive MRR is undefined; without negatives it is trivially 1.
            if len(positive) == 0 or len(negative) == 0:
                continue

            num_positives.append(len(positive))
            num_negatives.append(len(negative))

            docs = positive + negative
            is_relevant = [True] * len(positive) + [False] * len(negative)

            model_input = [[query, doc] for doc in docs]
            pred_scores = model.predict(model_input, convert_to_numpy=True, show_progress_bar=False)
            pred_scores_argsort = np.argsort(-pred_scores)  # Sort in decreasing order

            mrr_score = 0
            for rank, index in enumerate(pred_scores_argsort[0:self.mrr_at_k]):
                if is_relevant[index]:
                    # MRR uses the rank of the FIRST relevant hit. Stop here so a
                    # lower-ranked positive cannot overwrite it.
                    mrr_score = 1 / (rank + 1)
                    break

            all_mrr_scores.append(mrr_score)

        num_queries = len(all_mrr_scores)
        if num_queries == 0:
            # np.min/np.mean on empty lists would raise or yield nan.
            logging.warning("CERerankingEvaluator: no samples with both positive and negative passages, MRR set to 0")
            mean_mrr = 0.0
        else:
            mean_mrr = float(np.mean(all_mrr_scores))
            logging.info("Queries: {} \t Positives: Min {:.1f}, Mean {:.1f}, Max {:.1f} \t Negatives: Min {:.1f}, Mean {:.1f}, Max {:.1f}".format(num_queries, np.min(num_positives), np.mean(num_positives), np.max(num_positives), np.min(num_negatives), np.mean(num_negatives), np.max(num_negatives)))
        logging.info("MRR@{}: {:.2f}".format(self.mrr_at_k, mean_mrr * 100))

        if output_path is not None:
            csv_path = os.path.join(output_path, self.csv_file)
            output_file_exists = os.path.isfile(csv_path)
            # newline='' as required by the csv module, to avoid blank rows on Windows.
            with open(csv_path, mode="a" if output_file_exists else 'w', encoding="utf-8", newline='') as f:
                writer = csv.writer(f)
                if not output_file_exists:
                    writer.writerow(self.csv_headers)

                writer.writerow([epoch, steps, mean_mrr])

        return mean_mrr
|
|
|
|
|
|
# Re-ranking evaluator over the dev samples, passed to fit() below.
evaluator = CERerankingEvaluator(dev_samples)

# Configure the training
warmup_steps = 5000
logging.info("Warmup-steps: {}".format(warmup_steps))

# Train the model (mixed precision enabled via use_amp).
fit_args = dict(
    train_dataloader=train_dataloader,
    evaluator=evaluator,
    epochs=num_epochs,
    evaluation_steps=5000,
    warmup_steps=warmup_steps,
    output_path=model_save_path,
    use_amp=True,
)
model.fit(**fit_args)

# Save latest model into its own '-latest' directory next to model_save_path.
model.save(model_save_path + '-latest')
|
|
|
|
|
|
# Script was called via:
|
|
#python train_cross-encoder.py nreimers/TinyBERT_L-6_H-768_v2 |