Initialize project; model provided by the ModelHub XC community
Model: cross-encoder/msmarco-MiniLM-L6-en-de-v1 Source: Original Platform
18
.gitattributes
vendored
Normal file
@@ -0,0 +1,18 @@
*.bin.* filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tar.gz filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
model.safetensors filter=lfs diff=lfs merge=lfs -text
85
README.md
Normal file
@@ -0,0 +1,85 @@
---
license: apache-2.0
datasets:
- sentence-transformers/msmarco
language:
- en
- de
base_model:
- microsoft/Multilingual-MiniLM-L12-H384
pipeline_tag: text-ranking
library_name: sentence-transformers
tags:
- transformers
---
# Cross-Encoder for MS MARCO - EN-DE

This is a cross-lingual Cross-Encoder model for EN-DE that can be used for passage re-ranking. It was trained on the [MS MARCO Passage Ranking](https://github.com/microsoft/MSMARCO-Passage-Ranking) task.

The model can be used for Information Retrieval: See [SBERT.net Retrieve & Re-rank](https://www.sbert.net/examples/applications/retrieve_rerank/README.html).

The training code is available in this repository, see `train_script.py`.


## Usage with SentenceTransformers

When you have [SentenceTransformers](https://www.sbert.net/) installed, you can use the model like this:
```python
from sentence_transformers import CrossEncoder

# Model ID of this repository; inputs longer than max_length tokens are truncated.
model = CrossEncoder('cross-encoder/msmarco-MiniLM-L6-en-de-v1', max_length=512)

query = 'How many people live in Berlin?'
docs = ['Berlin has a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.', 'New York City is famous for the Metropolitan Museum of Art.']
# One (query, passage) pair per candidate passage.
pairs = [(query, doc) for doc in docs]
# One relevance score per pair, in the same order as `pairs`.
scores = model.predict(pairs)
```
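
The `scores` returned by `model.predict` follow the order of `pairs`, and a higher score indicates a more relevant passage. A minimal sketch of turning them into a ranking, assuming made-up score values in place of real model output (the helper name `rank_documents` is illustrative and not part of SentenceTransformers):

```python
from typing import List, Sequence, Tuple


def rank_documents(docs: Sequence[str], scores: Sequence[float]) -> List[Tuple[str, float]]:
    """Pair each document with its score and sort by descending score."""
    if len(docs) != len(scores):
        raise ValueError("docs and scores must have the same length")
    paired = zip(docs, (float(s) for s in scores))
    return sorted(paired, key=lambda item: item[1], reverse=True)


docs = [
    "New York City is famous for the Metropolitan Museum of Art.",
    "Berlin has a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.",
]
# Hypothetical scores; real values would come from model.predict(pairs).
example_scores = [-4.2, 9.1]

ranked = rank_documents(docs, example_scores)
for doc, score in ranked:
    print(f"{score:6.2f}  {doc}")
```

The same helper works on a NumPy array of scores, since each value is converted with `float()` before sorting.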


## Usage with Transformers
With the transformers library, you can use the model like this:

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# Model ID of this repository.
model = AutoModelForSequenceClassification.from_pretrained('cross-encoder/msmarco-MiniLM-L6-en-de-v1')
tokenizer = AutoTokenizer.from_pretrained('cross-encoder/msmarco-MiniLM-L6-en-de-v1')

# First list: queries; second list: passages. Entries are paired position by position.
features = tokenizer(['How many people live in Berlin?', 'How many people live in Berlin?'], ['Berlin has a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.', 'New York City is famous for the Metropolitan Museum of Art.'], padding=True, truncation=True, return_tensors="pt")

model.eval()
with torch.no_grad():
    scores = model(**features).logits
    print(scores)
```
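
The `logits` tensor holds one raw, unbounded score per (query, passage) pair, since the model is configured with a single output label. If a value in (0, 1) is more convenient, one option is to apply a sigmoid; because the sigmoid is monotonic, the ranking stays the same. A small sketch on made-up logits, written in plain Python so it runs without downloading the model:

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


# Hypothetical logits for two (query, passage) pairs;
# real values would come from model(**features).logits.
example_logits = [9.1, -4.2]
squashed = [sigmoid(v) for v in example_logits]

# Sorting by raw logit and sorting by sigmoid value give the same order.
order_by_logit = sorted(range(len(example_logits)), key=lambda i: example_logits[i], reverse=True)
order_by_sigmoid = sorted(range(len(squashed)), key=lambda i: squashed[i], reverse=True)
print(squashed, order_by_logit == order_by_sigmoid)
```

Whether the squashed values can be read as calibrated probabilities is not something the model card claims, so treat them only as a bounded score.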


## Performance
The performance was evaluated on three datasets:
- **TREC-DL19 EN-EN**: The original [TREC 2019 Deep Learning Track](https://microsoft.github.io/msmarco/TREC-Deep-Learning-2019.html): Given an English query and 1000 documents (retrieved by BM25 lexical search), rank the documents according to their relevance. We compute NDCG@10. BM25 achieves a score of 45.46; a perfect re-ranker can achieve a score of 95.47.
- **TREC-DL19 DE-EN**: The English queries of TREC-DL19 were translated into German by a native German speaker. We rank the German queries against the English passages from the original TREC-DL19 setup. We compute NDCG@10.
- **GermanDPR DE-DE**: The [GermanDPR](https://www.deepset.ai/germanquad) dataset provides German queries and German passages from Wikipedia. We indexed the 2.8 million paragraphs from German Wikipedia and, for each query, retrieved the top 100 most relevant passages using BM25 lexical search with Elasticsearch. We compute MRR@10. BM25 achieves a score of 35.85; a perfect re-ranker can achieve a score of 76.27.
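
For readers unfamiliar with the two metrics, here is a minimal sketch of how NDCG@10 and MRR@10 can be computed for a single query. It assumes the linear-gain form of DCG (relevance divided by `log2(rank + 1)`); the passage IDs and relevance labels are made up, and the reported numbers are means over all queries multiplied by 100.

```python
import math
from typing import Dict, List, Set


def ndcg_at_k(ranked_ids: List[str], relevance: Dict[str, int], k: int = 10) -> float:
    """NDCG@k with linear gain: DCG of the ranking divided by the ideal DCG."""
    dcg = sum(
        relevance.get(pid, 0) / math.log2(rank + 1)
        for rank, pid in enumerate(ranked_ids[:k], start=1)
    )
    ideal = sorted(relevance.values(), reverse=True)[:k]
    idcg = sum(rel / math.log2(rank + 1) for rank, rel in enumerate(ideal, start=1))
    return dcg / idcg if idcg > 0 else 0.0


def mrr_at_k(ranked_ids: List[str], relevant_ids: Set[str], k: int = 10) -> float:
    """Reciprocal rank of the first relevant passage within the top k, else 0."""
    for rank, pid in enumerate(ranked_ids[:k], start=1):
        if pid in relevant_ids:
            return 1.0 / rank
    return 0.0


# Hypothetical judgments and system ranking for one query.
relevance = {"p1": 3, "p2": 1}
ranking = ["p3", "p1", "p2"]

ndcg = ndcg_at_k(ranking, relevance)
mrr = mrr_at_k(ranking, set(relevance))
perfect = ndcg_at_k(["p1", "p2", "p3"], relevance)
print(f"NDCG@10 = {ndcg:.3f}, MRR@10 = {mrr:.3f}, perfect NDCG@10 = {perfect:.3f}")
```

The upper bounds quoted above (95.47 and 76.27) are below 100 because a re-ranker can only reorder the BM25 candidates; relevant passages that BM25 did not retrieve cannot be recovered.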

We also evaluated bi-encoders with the same setup: the documents retrieved by BM25 lexical search are re-ranked by the cosine similarity between query and passage embeddings. Bi-encoders can also be used for end-to-end semantic search.


| Model-Name | TREC-DL19 EN-EN (NDCG@10) | TREC-DL19 DE-EN (NDCG@10) | GermanDPR DE-DE (MRR@10) | Docs / Sec |
| ------------- |:-------------:| :-----: | :---: | :----: |
| BM25 | 45.46 | - | 35.85 | - |
| **Cross-Encoder Re-Rankers** | | | | |
| [cross-encoder/msmarco-MiniLM-L6-en-de-v1](https://huggingface.co/cross-encoder/msmarco-MiniLM-L6-en-de-v1) | 72.43 | 65.53 | 46.77 | 1600 |
| [cross-encoder/msmarco-MiniLM-L12-en-de-v1](https://huggingface.co/cross-encoder/msmarco-MiniLM-L12-en-de-v1) | 72.94 | 66.07 | 49.91 | 900 |
| [svalabs/cross-electra-ms-marco-german-uncased](https://huggingface.co/svalabs/cross-electra-ms-marco-german-uncased) (DE only) | - | - | 53.67 | 260 |
| [deepset/gbert-base-germandpr-reranking](https://huggingface.co/deepset/gbert-base-germandpr-reranking) (DE only) | - | - | 53.59 | 260 |
| **Bi-Encoders (re-ranking)** | | | | |
| [sentence-transformers/msmarco-distilbert-multilingual-en-de-v2-tmp-lng-aligned](https://huggingface.co/sentence-transformers/msmarco-distilbert-multilingual-en-de-v2-tmp-lng-aligned) | 63.38 | 58.28 | 37.88 | 940 |
| [sentence-transformers/msmarco-distilbert-multilingual-en-de-v2-tmp-trained-scratch](https://huggingface.co/sentence-transformers/msmarco-distilbert-multilingual-en-de-v2-tmp-trained-scratch) | 65.51 | 58.69 | 38.32 | 940 |
| [svalabs/bi-electra-ms-marco-german-uncased](https://huggingface.co/svalabs/bi-electra-ms-marco-german-uncased) (DE only) | - | - | 34.31 | 450 |
| [deepset/gbert-base-germandpr-question_encoder](https://huggingface.co/deepset/gbert-base-germandpr-question_encoder) (DE only) | - | - | 42.55 | 450 |

Note: Docs / Sec gives the number of (query, document) pairs we can re-rank within a second on a V100 GPU.
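
To put the Docs / Sec column in perspective, the following back-of-the-envelope sketch estimates the re-ranking time per query for a 1000-passage candidate list, as in the TREC-DL19 setup. It is a rough estimate only: it assumes the quoted V100 throughput holds for your passages and ignores any overhead outside the model.

```python
# Throughput values copied from the "Docs / Sec" column above (V100 GPU).
docs_per_sec = {
    "cross-encoder/msmarco-MiniLM-L6-en-de-v1": 1600,
    "cross-encoder/msmarco-MiniLM-L12-en-de-v1": 900,
    "svalabs/cross-electra-ms-marco-german-uncased": 260,
}
candidates_per_query = 1000  # size of the BM25 candidate list in the TREC-DL19 setup

# Estimated seconds needed to score all candidates of one query.
seconds_per_query = {name: candidates_per_query / rate for name, rate in docs_per_sec.items()}
for name, seconds in seconds_per_query.items():
    print(f"{name}: {seconds:.2f} s per query")
```

With fewer candidates (for example the top 100 used for GermanDPR) the estimate shrinks proportionally.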
32
config.json
Executable file
@@ -0,0 +1,32 @@
{
  "_name_or_path": "microsoft/Multilingual-MiniLM-L12-H384",
  "architectures": [
    "BertForSequenceClassification"
  ],
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 384,
  "id2label": {
    "0": "LABEL_0"
  },
  "initializer_range": 0.02,
  "intermediate_size": 1536,
  "label2id": {
    "LABEL_0": 0
  },
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 6,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "tokenizer_class": "XLMRobertaTokenizer",
  "transformers_version": "4.6.1",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 250037,
  "sbert_ce_default_activation_function": "torch.nn.modules.linear.Identity"
}
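
A short worked calculation may help relate these config values to the file sizes listed in this commit. The sketch below derives the per-head dimension and the size of the token-embedding matrix; the 4 bytes per weight (float32 storage) is an assumption, not something stated in the config.

```python
# Values copied from config.json above.
config = {
    "hidden_size": 384,
    "num_attention_heads": 12,
    "num_hidden_layers": 6,
    "vocab_size": 250037,
}

# Each attention head works on hidden_size / num_attention_heads dimensions.
head_dim = config["hidden_size"] // config["num_attention_heads"]

# The token-embedding matrix has one hidden_size-dimensional row per vocabulary entry.
embedding_params = config["vocab_size"] * config["hidden_size"]

BYTES_PER_WEIGHT = 4  # assumption: weights are stored as float32
embedding_bytes = embedding_params * BYTES_PER_WEIGHT

safetensors_bytes = 428045844  # size of model.safetensors as listed in this commit
embedding_share = embedding_bytes / safetensors_bytes

print(head_dim, embedding_params, f"{embedding_share:.1%}")
```

Under that assumption the multilingual vocabulary accounts for most of the checkpoint, which would explain why dropping from 12 to 6 transformer layers changes the file size far less than it changes inference speed.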
3
model.safetensors
Normal file
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:71ce561a97de0f28e9d5347210f526eb2091fb14a89e2543d2160dc472751588
size 428045844
3
onnx/model.onnx
Normal file
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:31e0be7b1ecfa83e9e28bed1c00a8db3f85901cea38c3d0c6932e94f767ac180
size 428186273
3
onnx/model_O1.onnx
Normal file
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cdd4bb545ad76e5318d899851a3f0e4fa291066a6de2d87f625a78207e66bfa5
size 428137629
3
onnx/model_O2.onnx
Normal file
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f5c8598992461182f4cadc6e5d5130ecf3105626942ec1fe40b0a54f8d0a7d6a
size 428052337
3
onnx/model_O3.onnx
Normal file
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6502966f25762cc9d8ed139ccda76fa0a929481226a7c8ae641c872ebb2e7313
size 428052268
3
onnx/model_O4.onnx
Normal file
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9af820cbe6c3e0375cb6fe435c6f1466b64c3a03835cf55079704a704d4343b8
size 214104136
3
onnx/model_qint8_arm64.onnx
Normal file
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bee1436cd1041bab174dd26d9cb1a7795ad14df79e55f8b9e4f538c96efa9b21
size 107494478
3
onnx/model_qint8_avx512.onnx
Normal file
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bee1436cd1041bab174dd26d9cb1a7795ad14df79e55f8b9e4f538c96efa9b21
size 107494478
3
onnx/model_qint8_avx512_vnni.onnx
Normal file
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bee1436cd1041bab174dd26d9cb1a7795ad14df79e55f8b9e4f538c96efa9b21
size 107494478
3
onnx/model_quint8_avx2.onnx
Normal file
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5d849b007fb3975d130f6f476eea9977980f5f260cbef8b0c4ba00efac67f1ab
size 107494476
3
openvino/openvino_model.bin
Normal file
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:361ea4de051f9271b6b66612f9df76b4b76c4953933b2a4fd9aa34fe81fd1f8a
size 428033696
6617
openvino/openvino_model.xml
Normal file
File diff suppressed because it is too large
3
pytorch_model.bin
Executable file
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6d211f85e0367b25564e5ff868d232285b1ab4b982589291546efb9fc6a76aa7
size 428078057
3
sentencepiece.bpe.model
Executable file
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cfc8146abe2a0488e9e2a0c56de7952f7c11ab059eca145a0a727afce0db2865
size 5069051
1
special_tokens_map.json
Executable file
@@ -0,0 +1 @@
{"bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>", "sep_token": "</s>", "pad_token": "<pad>", "cls_token": "<s>", "mask_token": "<mask>"}
1
tokenizer.json
Executable file
File diff suppressed because one or more lines are too long
1
tokenizer_config.json
Executable file
@@ -0,0 +1 @@
{"bos_token": "<s>", "eos_token": "</s>", "sep_token": "</s>", "cls_token": "<s>", "unk_token": "<unk>", "pad_token": "<pad>", "mask_token": {"content": "<mask>", "single_word": false, "lstrip": true, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "special_tokens_map_file": "/root/.cache/huggingface/transformers/8ed73a1ab9ef4e90a9451497bf96cfc38d34354352838a143f2dda1c81aed5ca.0dc5b1041f62041ebbd23b1297f2f573769d5c97d8b7c28180ec86b8f6185aa8", "name_or_path": "microsoft/Multilingual-MiniLM-L12-H384", "sp_model_kwargs": {}}
273
train_script.py
Executable file
@@ -0,0 +1,273 @@
import gzip
import random

from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig, AdamW
import sys
import torch
import transformers
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast
import tqdm
from datetime import datetime
from shutil import copyfile
import os
####################################

import gzip
from collections import defaultdict
import logging
import tqdm
import numpy as np
import sys
import pytrec_eval
from sentence_transformers import SentenceTransformer, util, CrossEncoder
import torch


model_name = sys.argv[1]
max_length = 350

######### Evaluation
queries_filepath = 'msmarco-data/trec2019/msmarco-test2019-queries.tsv.gz'
queries_eval = {}
with gzip.open(queries_filepath, 'rt', encoding='utf8') as fIn:
    for line in fIn:
        qid, query = line.strip().split("\t")[0:2]
        queries_eval[qid] = query

rel = defaultdict(lambda: defaultdict(int))

with open('msmarco-data/trec2019/2019qrels-pass.txt') as fIn:
    for line in fIn:
        qid, _, pid, score = line.strip().split()
        score = int(score)
        if score > 0:
            rel[qid][pid] = score

relevant_qid = []
for qid in queries_eval:
    if len(rel[qid]) > 0:
        relevant_qid.append(qid)

# Read top 1k
passage_cand = {}

with gzip.open('msmarco-data/trec2019/msmarco-passagetest2019-top1000.tsv.gz', 'rt', encoding='utf8') as fIn:
    for line in fIn:
        qid, pid, query, passage = line.strip().split("\t")
        if qid not in passage_cand:
            passage_cand[qid] = []

        passage_cand[qid].append([pid, passage])



def eval_modal(model_path):
    run = {}
    model = CrossEncoder(model_path, max_length=512)

    for qid in relevant_qid:
        query = queries_eval[qid]

        cand = passage_cand[qid]
        pids = [c[0] for c in cand]
        corpus_sentences = [c[1] for c in cand]

        ## CrossEncoder
        cross_inp = [[query, sent] for sent in corpus_sentences]
        if model.config.num_labels > 1:
            cross_scores = model.predict(cross_inp, apply_softmax=True)[:, 1].tolist()
        else:
            cross_scores = model.predict(cross_inp, activation_fct=torch.nn.Identity()).tolist()

        cross_scores_sparse = {}
        for idx, pid in enumerate(pids):
            cross_scores_sparse[pid] = cross_scores[idx]

        sparse_scores = cross_scores_sparse
        run[qid] = {}
        for pid in sparse_scores:
            run[qid][pid] = float(sparse_scores[pid])

    evaluator = pytrec_eval.RelevanceEvaluator(rel, {'ndcg_cut.10'})
    scores = evaluator.evaluate(run)
    scores_mean = np.mean([ele["ndcg_cut_10"] for ele in scores.values()])

    print("NDCG@10: {:.2f}".format(scores_mean * 100))
    return scores_mean

################################

device = 'cuda' if torch.cuda.is_available() else 'cpu'
config = AutoConfig.from_pretrained(model_name)
config.num_labels = 1
model = AutoModelForSequenceClassification.from_pretrained(model_name, config=config)
tokenizer = AutoTokenizer.from_pretrained(model_name)


############# Remove layers
if len(sys.argv) > 2:
    num_layers = int(sys.argv[2])
    if num_layers == 6:
        layers_to_keep = [0, 2, 4, 6, 8, 10]  #6 Layers
    elif num_layers == 4:
        layers_to_keep = [1, 4, 7, 10]  #4 Layers
    elif num_layers == 2:
        layers_to_keep = [3, 7]  #2 Layers
    else:
        print("Unknown number of layers to keep:", num_layers)
        exit()

    print("Reduce model to {} layers".format(len(layers_to_keep)))
    new_layers = torch.nn.ModuleList([layer_module for i, layer_module in enumerate(model.bert.encoder.layer) if i in layers_to_keep])
    model.bert.encoder.layer = new_layers
    model.bert.config.num_hidden_layers = len(layers_to_keep)
    model_name += "_L-{}".format(len(layers_to_keep))




#######################

queries = {}
corpus = {}

output_save_path = 'output/train_cross-encoder_mse-{}-{}'.format(model_name.replace("/", "_"), datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
output_save_path_latest = output_save_path+"-latest"
tokenizer.save_pretrained(output_save_path)
tokenizer.save_pretrained(output_save_path_latest)


# Write self to path
train_script_path = os.path.join(output_save_path, 'train_script.py')
copyfile(__file__, train_script_path)
with open(train_script_path, 'a') as fOut:
    fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))


####
train_script_path = os.path.join(output_save_path_latest, 'train_script.py')
copyfile(__file__, train_script_path)
with open(train_script_path, 'a') as fOut:
    fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))



#### Read train files
class MultilingualDataset(Dataset):
    def __init__(self):
        self.examples = defaultdict(lambda: defaultdict(list))  #[id][lang] => [samples...]

    def add(self, lang, filepath):
        open_method = gzip.open if filepath.endswith('.gz') else open
        with open_method(filepath, 'rt') as fIn:
            for line in fIn:
                pid, passage = line.strip().split("\t")
                self.examples[pid][lang].append(passage)


    def __len__(self):
        return len(self.examples)

    def __getitem__(self, item):
        all_examples = self.examples[item]  #All examples in all languages
        lang_examples = random.choice(list(all_examples.values()))  #Examples in one specific language
        return random.choice(lang_examples)  #One random example


train_corpus = MultilingualDataset()
train_corpus.add('en', 'msmarco-data/collection.tsv')
train_corpus.add('de', 'msmarco-data/de/collection.de.opus-mt.tsv.gz')
train_corpus.add('de', 'msmarco-data/de/collection.de.wmt19.tsv.gz')


train_queries = MultilingualDataset()
train_queries.add('en', 'msmarco-data/queries.train.tsv')
train_queries.add('de', 'msmarco-data/de/queries.train.de.opus-mt.tsv.gz')
train_queries.add('de', 'msmarco-data/de/queries.train.de.wmt19.tsv.gz')

############## MSE Dataset
class MSEDataset(Dataset):
    def __init__(self, filepath):
        super().__init__()

        self.examples = []
        with open(filepath) as fIn:
            for line in fIn:
                pos_score, neg_score, qid, pid1, pid2 = line.strip().split("\t")
                self.examples.append([qid, pid1, pid2, float(pos_score)-float(neg_score)])

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, item):
        return self.examples[item]

train_batch_size = 16
train_dataset = MSEDataset('msmarco-data/bert_cat_ensemble_msmarcopassage_train_scores_ids.tsv')
train_dataloader = DataLoader(train_dataset, drop_last=True, shuffle=True, batch_size=train_batch_size)


############## Optimizer

weight_decay = 0.01
max_grad_norm = 1
param_optimizer = list(model.named_parameters())

no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': weight_decay},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

optimizer = AdamW(optimizer_grouped_parameters, lr=1e-5)
scheduler = transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps=1000, num_training_steps=len(train_dataloader))
scaler = torch.cuda.amp.GradScaler()

loss_fct = torch.nn.MSELoss()
### Start training
model.to(device)

auto_save = 10000
best_ndcg_score = 0
for step_idx, batch in tqdm.tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
    batch_queries = [train_queries[qid] for qid in batch[0]]
    batch_pos = [train_corpus[cid] for cid in batch[1]]
    batch_neg = [train_corpus[cid] for cid in batch[2]]
    scores = batch[3].float().to(device)  #torch.tensor(batch[3], dtype=torch.float, device=device)

    with autocast():
        inp_pos = tokenizer(batch_queries, batch_pos, max_length=max_length, padding=True, truncation='longest_first', return_tensors='pt').to(device)
        pred_pos = model(**inp_pos).logits.squeeze()

        inp_neg = tokenizer(batch_queries, batch_neg, max_length=max_length, padding=True, truncation='longest_first', return_tensors='pt').to(device)
        pred_neg = model(**inp_neg).logits.squeeze()

        pred_diff = pred_pos - pred_neg
        loss_value = loss_fct(pred_diff, scores)


    scaler.scale(loss_value).backward()
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    scaler.step(optimizer)
    scaler.update()

    optimizer.zero_grad()
    scheduler.step()

    if (step_idx+1) % auto_save == 0:
        print("Step:", step_idx+1)
        model.save_pretrained(output_save_path_latest)
        ndcg_score = eval_modal(output_save_path_latest)

        if ndcg_score >= best_ndcg_score:
            best_ndcg_score = ndcg_score
            print("Save to:", output_save_path)
            model.save_pretrained(output_save_path)

model.save_pretrained(output_save_path)


# Script was called via:
#python train_cross-encoder_mse_multilingual.py microsoft/Multilingual-MiniLM-L12-H384 6
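
The training objective in `train_script.py` compares score differences rather than absolute scores: `MSEDataset` stores the teacher's margin (`pos_score - neg_score`), and the loop minimizes the mean squared error between that margin and the student's own difference `pred_pos - pred_neg`. A minimal plain-Python sketch of that loss on made-up numbers (the function name `margin_mse` is illustrative; the script itself uses `torch.nn.MSELoss`):

```python
from typing import Sequence


def margin_mse(pred_pos: Sequence[float], pred_neg: Sequence[float],
               teacher_margin: Sequence[float]) -> float:
    """Mean squared error between student margins and teacher margins."""
    if not (len(pred_pos) == len(pred_neg) == len(teacher_margin)):
        raise ValueError("all inputs must have the same length")
    squared_errors = [
        ((p - n) - t) ** 2
        for p, n, t in zip(pred_pos, pred_neg, teacher_margin)
    ]
    return sum(squared_errors) / len(squared_errors)


# Hypothetical student scores for two (query, positive, negative) triples.
pred_pos = [2.0, 0.5]
pred_neg = [1.0, 1.5]
# Hypothetical teacher margins (pos_score - neg_score from the score file).
teacher_margin = [3.0, -1.0]

loss = margin_mse(pred_pos, pred_neg, teacher_margin)
print(loss)
```

Because only the difference is supervised, the student is free to shift all of its scores by a constant, which does not affect the ranking of passages for a query.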