273 lines
9.0 KiB
Python
Executable File
273 lines
9.0 KiB
Python
Executable File
import gzip
|
|
import random
|
|
|
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig, AdamW
|
|
import sys
|
|
import torch
|
|
import transformers
|
|
from torch.utils.data import Dataset, DataLoader
|
|
from torch.cuda.amp import autocast
|
|
import tqdm
|
|
from datetime import datetime
|
|
from shutil import copyfile
|
|
import os
|
|
####################################
|
|
|
|
import gzip
|
|
from collections import defaultdict
|
|
import logging
|
|
import tqdm
|
|
import numpy as np
|
|
import sys
|
|
import pytrec_eval
|
|
from sentence_transformers import SentenceTransformer, util, CrossEncoder
|
|
import torch
|
|
|
|
|
|
# Command line: <model_name> [num_layers]
# Fail with a usage message instead of an IndexError when the model is missing.
if len(sys.argv) < 2:
    sys.exit("Usage: python {} model_name [num_layers]".format(sys.argv[0]))

model_name = sys.argv[1]

# Maximum token length of a (query, passage) pair during training.
max_length = 350
|
|
|
|
######### Evaluation
# TREC-DL 2019 test queries, mapped qid -> query text.
# Only the first two tab-separated columns of each line are used.
queries_filepath = 'msmarco-data/trec2019/msmarco-test2019-queries.tsv.gz'
with gzip.open(queries_filepath, 'rt', encoding='utf8') as fIn:
    queries_eval = dict(row.strip().split("\t")[0:2] for row in fIn)
|
|
|
|
# Relevance judgments: rel[qid][pid] -> integer grade.
# Each whitespace-separated line is: qid, <unused column>, pid, grade.
# Only positive grades are stored.
rel = defaultdict(lambda: defaultdict(int))
with open('msmarco-data/trec2019/2019qrels-pass.txt') as fIn:
    for row in fIn:
        qid, _, pid, grade = row.strip().split()
        grade = int(grade)
        if grade <= 0:
            continue
        rel[qid][pid] = grade
|
|
|
|
# Evaluate only the queries that have at least one positively judged passage.
# Indexing `rel` (a defaultdict) inserts an empty entry for unjudged queries.
relevant_qid = [qid for qid in queries_eval if rel[qid]]
|
|
|
|
# Read top 1k: candidate passages per query in file order,
# passage_cand[qid] -> [[pid, passage], ...]. The query-text column is ignored.
passage_cand = {}
with gzip.open('msmarco-data/trec2019/msmarco-passagetest2019-top1000.tsv.gz', 'rt', encoding='utf8') as fIn:
    for row in fIn:
        qid, pid, _query, passage = row.strip().split("\t")
        passage_cand.setdefault(qid, []).append([pid, passage])
|
|
|
|
|
|
|
|
def eval_modal(model_path):
    """Re-rank the TREC-DL 2019 candidates with a CrossEncoder and report nDCG@10.

    Loads ``model_path`` as a ``CrossEncoder`` (max_length=512). For every query
    in ``relevant_qid`` it scores all (query, passage) pairs from
    ``passage_cand``, and evaluates the resulting run against ``rel`` with
    pytrec_eval.

    Returns:
        The mean nDCG@10 over the evaluated queries, as a fraction. The
        printed value is this number multiplied by 100.
    """
    cross_encoder = CrossEncoder(model_path, max_length=512)
    # With several output labels, score = softmax probability of label 1.
    # With a single output, score = the raw logit (identity activation).
    use_softmax = cross_encoder.config.num_labels > 1

    run = {}
    for qid in relevant_qid:
        query = queries_eval[qid]
        candidates = passage_cand[qid]
        pairs = [[query, passage] for _pid, passage in candidates]

        if use_softmax:
            predictions = cross_encoder.predict(pairs, apply_softmax=True)[:, 1]
        else:
            predictions = cross_encoder.predict(pairs, activation_fct=torch.nn.Identity())

        run[qid] = {
            pid: float(score)
            for (pid, _passage), score in zip(candidates, predictions.tolist())
        }

    evaluator = pytrec_eval.RelevanceEvaluator(rel, {'ndcg_cut.10'})
    per_query = evaluator.evaluate(run)
    scores_mean = np.mean([metrics["ndcg_cut_10"] for metrics in per_query.values()])

    print("NDCG@10: {:.2f}".format(scores_mean * 100))
    return scores_mean
|
|
|
|
################################
# Model under training: the pretrained checkpoint with a single-output
# classification head, i.e. one score per (query, passage) pair.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

config = AutoConfig.from_pretrained(model_name)
config.num_labels = 1  # one output logit = the relevance score
model = AutoModelForSequenceClassification.from_pretrained(model_name, config=config)
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
|
|
|
|
############# Remove layers
# Optional second CLI argument: shrink the encoder to 6, 4 or 2 layers by
# keeping the listed layer indices.
# NOTE(review): the indices presumably assume a 12-layer encoder; confirm for other models.
layer_subsets = {
    6: [0, 2, 4, 6, 8, 10],  # 6 Layers
    4: [1, 4, 7, 10],        # 4 Layers
    2: [3, 7],               # 2 Layers
}

if len(sys.argv) > 2:
    num_layers = int(sys.argv[2])
    if num_layers not in layer_subsets:
        # sys.exit(msg) writes to stderr and exits with a non-zero status.
        # The site-provided exit() exited with status 0 after printing to stdout.
        sys.exit("Unknown number of layers to keep: {}".format(num_layers))
    layers_to_keep = layer_subsets[num_layers]

    print("Reduce model to {} layers".format(len(layers_to_keep)))
    # base_model resolves to model.bert for BERT-style classifiers. Using it
    # also covers other architectures that expose encoder.layer (e.g. RoBERTa).
    base_model = model.base_model
    base_model.encoder.layer = torch.nn.ModuleList(
        [layer for i, layer in enumerate(base_model.encoder.layer) if i in layers_to_keep]
    )
    # Keep the saved config consistent with the reduced depth.
    base_model.config.num_hidden_layers = len(layers_to_keep)
    model_name += "_L-{}".format(len(layers_to_keep))
|
|
|
|
|
|
|
|
|
|
#######################
# NOTE(review): `queries` and `corpus` are not referenced anywhere else in this script.
queries = {}
corpus = {}

# Timestamped output directory plus a sibling "-latest" directory.
# The tokenizer is written into both; save_pretrained also creates the directories.
run_name = model_name.replace("/", "_")
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
output_save_path = 'output/train_cross-encoder_mse-{}-{}'.format(run_name, timestamp)
output_save_path_latest = output_save_path + "-latest"

for save_dir in (output_save_path, output_save_path_latest):
    tokenizer.save_pretrained(save_dir)
|
|
|
|
|
|
# Write self to path: store a copy of this script in both output directories.
# Each copy gets the exact command line appended, for reproducibility.
for save_dir in (output_save_path, output_save_path_latest):
    train_script_path = os.path.join(save_dir, 'train_script.py')
    copyfile(__file__, train_script_path)
    with open(train_script_path, 'a') as fOut:
        fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
|
|
|
|
|
|
|
|
#### Read train files
class MultilingualDataset(Dataset):
    """Id-keyed texts available in several languages / translations.

    ``add`` registers every ``id<TAB>text`` line of a file under a language
    code. Indexing with an id returns one random text for that id: first a
    random language is chosen, then a random variant within that language.
    """

    def __init__(self):
        self.examples = defaultdict(lambda: defaultdict(list))  # [id][lang] => [samples...]

    def add(self, lang, filepath):
        """Load ``id<TAB>text`` lines from ``filepath`` under ``lang``.

        The file is opened with gzip when its name ends in '.gz'.
        """
        open_method = gzip.open if filepath.endswith('.gz') else open
        # Decode explicitly as UTF-8 (as the evaluation files are read) rather
        # than relying on the platform's locale encoding.
        with open_method(filepath, 'rt', encoding='utf8') as fIn:
            for line in fIn:
                pid, passage = line.strip().split("\t")
                self.examples[pid][lang].append(passage)

    def __len__(self):
        """Number of distinct ids loaded so far."""
        return len(self.examples)

    def __getitem__(self, item):
        """Return one random text for id ``item``.

        Raises:
            KeyError: if no text was loaded for ``item``.
        """
        # .get() avoids the defaultdict inserting an empty entry for an unknown
        # id. That entry would inflate len() and then make random.choice fail
        # with an IndexError.
        all_examples = self.examples.get(item)  # all examples in all languages
        if not all_examples:
            raise KeyError(item)
        lang_examples = random.choice(list(all_examples.values()))  # examples in one specific language
        return random.choice(lang_examples)  # one random example
|
|
|
|
|
|
# Training passages and queries: the English originals plus two German
# machine translations (opus-mt and wmt19). Both are registered under 'de'.
train_corpus = MultilingualDataset()
for lang, path in (
    ('en', 'msmarco-data/collection.tsv'),
    ('de', 'msmarco-data/de/collection.de.opus-mt.tsv.gz'),
    ('de', 'msmarco-data/de/collection.de.wmt19.tsv.gz'),
):
    train_corpus.add(lang, path)

train_queries = MultilingualDataset()
for lang, path in (
    ('en', 'msmarco-data/queries.train.tsv'),
    ('de', 'msmarco-data/de/queries.train.de.opus-mt.tsv.gz'),
    ('de', 'msmarco-data/de/queries.train.de.wmt19.tsv.gz'),
):
    train_queries.add(lang, path)
|
|
|
|
############## MSE Dataset
class MSEDataset(Dataset):
    """Training triples labelled with a score margin.

    Each tab-separated input line holds ``pos_score, neg_score, qid, pid1,
    pid2``. It becomes the example ``[qid, pid1, pid2, pos_score - neg_score]``.
    """

    def __init__(self, filepath):
        super().__init__()
        with open(filepath) as reader:
            self.examples = [self._parse(row) for row in reader]

    @staticmethod
    def _parse(row):
        """Turn one input line into ``[qid, pid1, pid2, margin]``."""
        pos_score, neg_score, qid, pid1, pid2 = row.strip().split("\t")
        return [qid, pid1, pid2, float(pos_score) - float(neg_score)]

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, item):
        return self.examples[item]
|
|
|
|
# Shuffled mini-batches. drop_last=True keeps every batch at exactly
# train_batch_size examples.
train_batch_size = 16
train_dataset = MSEDataset('msmarco-data/bert_cat_ensemble_msmarcopassage_train_scores_ids.tsv')
train_dataloader = DataLoader(
    train_dataset,
    batch_size=train_batch_size,
    shuffle=True,
    drop_last=True,
)
|
|
|
|
|
|
############## Optimizer
# AdamW with weight decay on every parameter except biases and LayerNorm weights.
weight_decay = 0.01
max_grad_norm = 1  # gradient-norm clipping threshold used in the training loop
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

decay_params = []
no_decay_params = []
for param_name, param in param_optimizer:
    if any(nd in param_name for nd in no_decay):
        no_decay_params.append(param)
    else:
        decay_params.append(param)

optimizer_grouped_parameters = [
    {'params': decay_params, 'weight_decay': weight_decay},
    {'params': no_decay_params, 'weight_decay': 0.0},
]

# NOTE(review): transformers.AdamW is deprecated in recent transformers releases
# in favour of torch.optim.AdamW. Switching changes defaults (e.g. eps), so verify
# before migrating.
optimizer = AdamW(optimizer_grouped_parameters, lr=1e-5)
# 1000 warmup steps, then linear decay over one pass through train_dataloader.
scheduler = transformers.get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=1000, num_training_steps=len(train_dataloader)
)
# Loss scaling for the autocast (mixed-precision) training loop.
scaler = torch.cuda.amp.GradScaler()

loss_fct = torch.nn.MSELoss()
|
|
### Start training
model.to(device)
# from_pretrained returns the model in eval mode, which disables dropout.
# Switch to training mode before optimizing.
model.train()

auto_save = 10000  # evaluate and checkpoint every `auto_save` optimizer steps
best_ndcg_score = 0


def _save_and_evaluate(step):
    """Save the current weights to the '-latest' directory and evaluate them.

    The weights are also saved to ``output_save_path`` when nDCG@10 is not
    lower than the best value seen so far.
    """
    global best_ndcg_score
    print("Step:", step)
    model.save_pretrained(output_save_path_latest)
    ndcg_score = eval_modal(output_save_path_latest)

    if ndcg_score >= best_ndcg_score:
        best_ndcg_score = ndcg_score
        print("Save to:", output_save_path)
        model.save_pretrained(output_save_path)


for step_idx, batch in tqdm.tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
    # Batch columns from MSEDataset: qids, positive pids, negative pids, score margins.
    batch_queries = [train_queries[qid] for qid in batch[0]]
    batch_pos = [train_corpus[cid] for cid in batch[1]]
    batch_neg = [train_corpus[cid] for cid in batch[2]]
    scores = batch[3].float().to(device)

    with autocast():
        inp_pos = tokenizer(batch_queries, batch_pos, max_length=max_length, padding=True,
                            truncation='longest_first', return_tensors='pt').to(device)
        # squeeze(-1) drops only the single-logit dimension, so the result stays
        # 1-D even for a batch of one.
        pred_pos = model(**inp_pos).logits.squeeze(-1)

        inp_neg = tokenizer(batch_queries, batch_neg, max_length=max_length, padding=True,
                            truncation='longest_first', return_tensors='pt').to(device)
        pred_neg = model(**inp_neg).logits.squeeze(-1)

        # Regress the predicted pos-neg margin onto the target margin.
        pred_diff = pred_pos - pred_neg
        loss_value = loss_fct(pred_diff, scores)

    scaler.scale(loss_value).backward()
    # Unscale before clipping so the clip threshold applies to the true gradients.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    scaler.step(optimizer)
    scaler.update()

    optimizer.zero_grad()
    scheduler.step()

    if (step_idx + 1) % auto_save == 0:
        _save_and_evaluate(step_idx + 1)

# The final weights go through the same evaluate-and-compare step. An
# unconditional save to output_save_path would overwrite the best checkpoint
# with possibly worse final weights. The first evaluation always passes the
# comparison (best starts at 0), so output_save_path always receives a model.
_save_and_evaluate(len(train_dataloader))
|
|
|
|
|
|
# Script was called via:
|
|
#python train_cross-encoder_mse_multilingual.py microsoft/Multilingual-MiniLM-L12-H384 6 |