初始化项目,由ModelHub XC社区提供模型
Model: cross-encoder/ms-marco-TinyBERT-L6 Source: Original Platform
This commit is contained in:
10
.gitattributes
vendored
Normal file
10
.gitattributes
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
*.bin.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.bin filter=lfs diff=lfs merge=lfs -text
|
||||
*.h5 filter=lfs diff=lfs merge=lfs -text
|
||||
*.tflite filter=lfs diff=lfs merge=lfs -text
|
||||
*.tar.gz filter=lfs diff=lfs merge=lfs -text
|
||||
*.ot filter=lfs diff=lfs merge=lfs -text
|
||||
*.onnx filter=lfs diff=lfs merge=lfs -text
|
||||
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
||||
model.safetensors filter=lfs diff=lfs merge=lfs -text
|
||||
127
CERerankingEvaluator_results.csv
Normal file
127
CERerankingEvaluator_results.csv
Normal file
@@ -0,0 +1,127 @@
|
||||
epoch,steps,MRR@10
|
||||
0,5000,0.5650095238095237
|
||||
0,10000,0.5849968253968254
|
||||
0,15000,0.6097650793650794
|
||||
0,20000,0.6246285714285715
|
||||
0,25000,0.6100253968253967
|
||||
0,30000,0.6270730158730159
|
||||
0,35000,0.6138888888888888
|
||||
0,40000,0.6240317460317462
|
||||
0,45000,0.6327619047619049
|
||||
0,50000,0.619631746031746
|
||||
0,55000,0.5871142857142856
|
||||
0,60000,0.6175809523809525
|
||||
0,65000,0.6081968253968254
|
||||
0,70000,0.6151301587301587
|
||||
0,75000,0.6093269841269842
|
||||
0,80000,0.6032571428571428
|
||||
0,85000,0.6138063492063491
|
||||
0,90000,0.6156952380952381
|
||||
0,95000,0.6303523809523809
|
||||
0,100000,0.6061523809523809
|
||||
0,105000,0.6133174603174603
|
||||
0,110000,0.6226063492063493
|
||||
0,115000,0.6176349206349206
|
||||
0,120000,0.6104761904761905
|
||||
0,125000,0.6332253968253967
|
||||
0,130000,0.6289523809523808
|
||||
0,135000,0.6181809523809524
|
||||
0,140000,0.6399841269841271
|
||||
0,145000,0.623073015873016
|
||||
0,150000,0.5963587301587302
|
||||
0,155000,0.6157301587301588
|
||||
0,160000,0.613120634920635
|
||||
0,165000,0.6089936507936508
|
||||
0,170000,0.6203301587301587
|
||||
0,175000,0.6171269841269841
|
||||
0,180000,0.5939269841269841
|
||||
0,185000,0.6417873015873015
|
||||
0,190000,0.6164476190476191
|
||||
0,195000,0.6215841269841269
|
||||
0,200000,0.6298984126984126
|
||||
0,205000,0.6030507936507936
|
||||
0,210000,0.6084730158730158
|
||||
0,215000,0.6092730158730159
|
||||
0,220000,0.5939650793650793
|
||||
0,225000,0.6124190476190475
|
||||
0,230000,0.6039269841269841
|
||||
0,235000,0.6253301587301587
|
||||
0,240000,0.634904761904762
|
||||
0,245000,0.6317015873015873
|
||||
0,250000,0.6196603174603175
|
||||
0,255000,0.6287396825396825
|
||||
0,260000,0.6095746031746031
|
||||
0,265000,0.6263492063492063
|
||||
0,270000,0.6171079365079365
|
||||
0,275000,0.6289523809523809
|
||||
0,280000,0.6202634920634921
|
||||
0,285000,0.6255301587301587
|
||||
0,290000,0.5993841269841268
|
||||
0,295000,0.6191841269841271
|
||||
0,300000,0.6203396825396825
|
||||
0,305000,0.6128412698412699
|
||||
0,310000,0.6090825396825398
|
||||
0,315000,0.5950539682539682
|
||||
0,320000,0.5990444444444444
|
||||
0,325000,0.6042412698412698
|
||||
0,330000,0.5960190476190476
|
||||
0,335000,0.6106222222222223
|
||||
0,340000,0.6055968253968255
|
||||
0,345000,0.5984095238095238
|
||||
0,350000,0.6142984126984128
|
||||
0,355000,0.6137746031746032
|
||||
0,360000,0.6018412698412698
|
||||
0,365000,0.6123079365079365
|
||||
0,370000,0.6130285714285715
|
||||
0,375000,0.6008412698412698
|
||||
0,380000,0.6020698412698412
|
||||
0,385000,0.6100222222222222
|
||||
0,390000,0.5971650793650793
|
||||
0,395000,0.5941968253968255
|
||||
0,400000,0.5871428571428571
|
||||
0,405000,0.6100190476190476
|
||||
0,410000,0.5903174603174602
|
||||
0,415000,0.5988317460317459
|
||||
0,420000,0.6132380952380952
|
||||
0,425000,0.6144412698412698
|
||||
0,430000,0.5980888888888888
|
||||
0,435000,0.5973746031746032
|
||||
0,440000,0.595384126984127
|
||||
0,445000,0.5871714285714286
|
||||
0,450000,0.6012412698412699
|
||||
0,455000,0.5873047619047618
|
||||
0,460000,0.595584126984127
|
||||
0,465000,0.5804285714285713
|
||||
0,470000,0.5887619047619047
|
||||
0,475000,0.5872761904761904
|
||||
0,480000,0.5871396825396825
|
||||
0,485000,0.5907174603174602
|
||||
0,490000,0.5880412698412699
|
||||
0,495000,0.5807968253968254
|
||||
0,500000,0.5909746031746032
|
||||
0,505000,0.5912984126984128
|
||||
0,510000,0.5942761904761905
|
||||
0,515000,0.5840222222222223
|
||||
0,520000,0.5852380952380952
|
||||
0,525000,0.582784126984127
|
||||
0,530000,0.5916190476190476
|
||||
0,535000,0.5777269841269841
|
||||
0,540000,0.582120634920635
|
||||
0,545000,0.5746634920634921
|
||||
0,550000,0.5746444444444445
|
||||
0,555000,0.5632444444444444
|
||||
0,560000,0.5799650793650795
|
||||
0,565000,0.5932507936507936
|
||||
0,570000,0.5816190476190476
|
||||
0,575000,0.5838857142857143
|
||||
0,580000,0.5859650793650794
|
||||
0,585000,0.5843968253968255
|
||||
0,590000,0.5840634920634921
|
||||
0,595000,0.5958285714285714
|
||||
0,600000,0.5842857142857142
|
||||
0,605000,0.5892507936507937
|
||||
0,610000,0.5914507936507937
|
||||
0,615000,0.5953968253968254
|
||||
0,620000,0.5925174603174603
|
||||
0,625000,0.5890857142857143
|
||||
0,-1,0.5890857142857143
|
||||
|
75
README.md
Normal file
75
README.md
Normal file
@@ -0,0 +1,75 @@
|
||||
---
|
||||
license: apache-2.0
|
||||
datasets:
|
||||
- sentence-transformers/msmarco
|
||||
language:
|
||||
- en
|
||||
base_model:
|
||||
- nreimers/TinyBERT_L-6_H-768_v2
|
||||
pipeline_tag: text-ranking
|
||||
library_name: sentence-transformers
|
||||
tags:
|
||||
- transformers
|
||||
---
|
||||
# Cross-Encoder for MS Marco
|
||||
|
||||
This model was trained on the [MS Marco Passage Ranking](https://github.com/microsoft/MSMARCO-Passage-Ranking) task.
|
||||
|
||||
The model can be used for Information Retrieval: Given a query, encode the query will all possible passages (e.g. retrieved with ElasticSearch). Then sort the passages in a decreasing order. See [SBERT.net Retrieve & Re-rank](https://www.sbert.net/examples/applications/retrieve_rerank/README.html) for more details. The training code is available here: [SBERT.net Training MS Marco](https://github.com/UKPLab/sentence-transformers/tree/master/examples/cross_encoder/training/ms_marco)
|
||||
|
||||
|
||||
## Usage with Transformers
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
||||
import torch
|
||||
|
||||
model = AutoModelForSequenceClassification.from_pretrained('cross-encoder/ms-marco-TinyBERT-L6')
|
||||
tokenizer = AutoTokenizer.from_pretrained('cross-encoder/ms-marco-TinyBERT-L6')
|
||||
|
||||
features = tokenizer(['How many people live in Berlin?', 'How many people live in Berlin?'], ['Berlin has a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.', 'New York City is famous for the Metropolitan Museum of Art.'], padding=True, truncation=True, return_tensors="pt")
|
||||
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
scores = model(**features).logits
|
||||
print(scores)
|
||||
```
|
||||
|
||||
|
||||
## Usage with SentenceTransformers
|
||||
|
||||
The usage becomes easier when you have [SentenceTransformers](https://www.sbert.net/) installed. Then, you can use the pre-trained models like this:
|
||||
```python
|
||||
from sentence_transformers import CrossEncoder
|
||||
|
||||
model = CrossEncoder('cross-encoder/ms-marco-TinyBERT-L6', max_length=512)
|
||||
scores = model.predict([('Query', 'Paragraph1'), ('Query', 'Paragraph2') , ('Query', 'Paragraph3')])
|
||||
```
|
||||
|
||||
|
||||
## Performance
|
||||
In the following table, we provide various pre-trained Cross-Encoders together with their performance on the [TREC Deep Learning 2019](https://microsoft.github.io/TREC-2019-Deep-Learning/) and the [MS Marco Passage Reranking](https://github.com/microsoft/MSMARCO-Passage-Ranking/) dataset.
|
||||
|
||||
|
||||
| Model-Name | NDCG@10 (TREC DL 19) | MRR@10 (MS Marco Dev) | Docs / Sec |
|
||||
| ------------- |:-------------| -----| --- |
|
||||
| **Version 2 models** | | |
|
||||
| cross-encoder/ms-marco-TinyBERT-L2-v2 | 69.84 | 32.56 | 9000
|
||||
| cross-encoder/ms-marco-MiniLM-L2-v2 | 71.01 | 34.85 | 4100
|
||||
| cross-encoder/ms-marco-MiniLM-L4-v2 | 73.04 | 37.70 | 2500
|
||||
| cross-encoder/ms-marco-MiniLM-L6-v2 | 74.30 | 39.01 | 1800
|
||||
| cross-encoder/ms-marco-MiniLM-L12-v2 | 74.31 | 39.02 | 960
|
||||
| **Version 1 models** | | |
|
||||
| cross-encoder/ms-marco-TinyBERT-L2 | 67.43 | 30.15 | 9000
|
||||
| cross-encoder/ms-marco-TinyBERT-L4 | 68.09 | 34.50 | 2900
|
||||
| cross-encoder/ms-marco-TinyBERT-L6 | 69.57 | 36.13 | 680
|
||||
| cross-encoder/ms-marco-electra-base | 71.99 | 36.41 | 340
|
||||
| **Other models** | | |
|
||||
| nboost/pt-tinybert-msmarco | 63.63 | 28.80 | 2900
|
||||
| nboost/pt-bert-base-uncased-msmarco | 70.94 | 34.75 | 340
|
||||
| nboost/pt-bert-large-msmarco | 73.36 | 36.48 | 100
|
||||
| Capreolus/electra-base-msmarco | 71.23 | 36.89 | 340
|
||||
| amberoad/bert-multilingual-passage-reranking-msmarco | 68.40 | 35.54 | 330
|
||||
| sebastian-hofstaetter/distilbert-cat-margin_mse-T2-msmarco | 72.82 | 37.88 | 720
|
||||
|
||||
Note: Runtime was computed on a V100 GPU.
|
||||
27
config.json
Normal file
27
config.json
Normal file
@@ -0,0 +1,27 @@
|
||||
{
|
||||
"_name_or_path": "nreimers/TinyBERT_L-6_H-768_v2",
|
||||
"architectures": [
|
||||
"BertForSequenceClassification"
|
||||
],
|
||||
"attention_probs_dropout_prob": 0.1,
|
||||
"gradient_checkpointing": false,
|
||||
"hidden_act": "gelu",
|
||||
"hidden_dropout_prob": 0.1,
|
||||
"hidden_size": 768,
|
||||
"id2label": {
|
||||
"0": "LABEL_0"
|
||||
},
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 3072,
|
||||
"label2id": {
|
||||
"LABEL_0": 0
|
||||
},
|
||||
"layer_norm_eps": 1e-12,
|
||||
"max_position_embeddings": 512,
|
||||
"model_type": "bert",
|
||||
"num_attention_heads": 12,
|
||||
"num_hidden_layers": 6,
|
||||
"pad_token_id": 0,
|
||||
"type_vocab_size": 2,
|
||||
"vocab_size": 30522
|
||||
}
|
||||
3
flax_model.msgpack
Normal file
3
flax_model.msgpack
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:8dca17d03d10927dc2cae61af57f617da0cbeb4cf86318fa6e4f866ebe177359
|
||||
size 267826914
|
||||
3
model.safetensors
Normal file
3
model.safetensors
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:44e975b3a25903d3d9cefbddc06505a0d6cc2eb8ff311f0da6ace9118b366753
|
||||
size 267839500
|
||||
3
onnx/model.onnx
Normal file
3
onnx/model.onnx
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:76285b39ad4f9fe2638aeb4d4373e839644b6a46c57a94e7df34aa2f1d5910a5
|
||||
size 267979984
|
||||
3
onnx/model_O1.onnx
Normal file
3
onnx/model_O1.onnx
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:41d359e2246be0906dd9d02ff71129f332e4814d9d7eb27752da54d6de3a1ca8
|
||||
size 267931340
|
||||
3
onnx/model_O2.onnx
Normal file
3
onnx/model_O2.onnx
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:58c001c932ae5f292b7b00069bb30eae34f9e0418a1c91b3e325d4716f148c6d
|
||||
size 267846024
|
||||
3
onnx/model_O3.onnx
Normal file
3
onnx/model_O3.onnx
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:ac0afd86cd65c5f238591501b55a156d6ab69979236ca1f44c9e501864dff16b
|
||||
size 267845955
|
||||
3
onnx/model_O4.onnx
Normal file
3
onnx/model_O4.onnx
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a36dfbc48c898574fd6e0cae98a4ae3ca2e049669e783805180fd6d19fd8934a
|
||||
size 134000992
|
||||
3
onnx/model_qint8_arm64.onnx
Normal file
3
onnx/model_qint8_arm64.onnx
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:7bf5d227ffe3b070d4704a1646613ddf6df77a120ea683369583e144fcf4f12e
|
||||
size 67641829
|
||||
3
onnx/model_qint8_avx512.onnx
Normal file
3
onnx/model_qint8_avx512.onnx
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:7bf5d227ffe3b070d4704a1646613ddf6df77a120ea683369583e144fcf4f12e
|
||||
size 67641829
|
||||
3
onnx/model_qint8_avx512_vnni.onnx
Normal file
3
onnx/model_qint8_avx512_vnni.onnx
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:7bf5d227ffe3b070d4704a1646613ddf6df77a120ea683369583e144fcf4f12e
|
||||
size 67641829
|
||||
3
onnx/model_quint8_avx2.onnx
Normal file
3
onnx/model_quint8_avx2.onnx
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:6ab7d03ca90b69264bdaa5014ce9d8beef6df39afd3eccd2734f75bd251299bd
|
||||
size 67641828
|
||||
3
openvino/openvino_model.bin
Normal file
3
openvino/openvino_model.bin
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:2bfbc62d0401ad098b9ab737425ad055ebd0d9fda68a7da4847c1e992c4afddb
|
||||
size 267827364
|
||||
6647
openvino/openvino_model.xml
Normal file
6647
openvino/openvino_model.xml
Normal file
File diff suppressed because it is too large
Load Diff
3
openvino/openvino_model_qint8_quantized.bin
Normal file
3
openvino/openvino_model_qint8_quantized.bin
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:d57aa37147cef8bca75421c893d64a50eb6a3fba7b2151d3c05166908f984fd3
|
||||
size 67575240
|
||||
11319
openvino/openvino_model_qint8_quantized.xml
Normal file
11319
openvino/openvino_model_qint8_quantized.xml
Normal file
File diff suppressed because it is too large
Load Diff
3
pytorch_model.bin
Normal file
3
pytorch_model.bin
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:0cf70691b32e33dc1a57845ceeeb81bc70cfec62a0d584d063f13576403f2759
|
||||
size 267871721
|
||||
7
special_tokens_map.json
Normal file
7
special_tokens_map.json
Normal file
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"cls_token": "[CLS]",
|
||||
"mask_token": "[MASK]",
|
||||
"pad_token": "[PAD]",
|
||||
"sep_token": "[SEP]",
|
||||
"unk_token": "[UNK]"
|
||||
}
|
||||
30672
tokenizer.json
Normal file
30672
tokenizer.json
Normal file
File diff suppressed because it is too large
Load Diff
58
tokenizer_config.json
Normal file
58
tokenizer_config.json
Normal file
@@ -0,0 +1,58 @@
|
||||
{
|
||||
"added_tokens_decoder": {
|
||||
"0": {
|
||||
"content": "[PAD]",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"100": {
|
||||
"content": "[UNK]",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"101": {
|
||||
"content": "[CLS]",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"102": {
|
||||
"content": "[SEP]",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"103": {
|
||||
"content": "[MASK]",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
}
|
||||
},
|
||||
"clean_up_tokenization_spaces": true,
|
||||
"cls_token": "[CLS]",
|
||||
"do_basic_tokenize": true,
|
||||
"do_lower_case": true,
|
||||
"extra_special_tokens": {},
|
||||
"mask_token": "[MASK]",
|
||||
"model_max_length": 512,
|
||||
"never_split": null,
|
||||
"pad_token": "[PAD]",
|
||||
"sep_token": "[SEP]",
|
||||
"strip_accents": null,
|
||||
"tokenize_chinese_chars": true,
|
||||
"tokenizer_class": "BertTokenizer",
|
||||
"unk_token": "[UNK]"
|
||||
}
|
||||
193
train_script.py
Normal file
193
train_script.py
Normal file
@@ -0,0 +1,193 @@
|
||||
from torch.utils.data import DataLoader
|
||||
from sentence_transformers import LoggingHandler
|
||||
from sentence_transformers.cross_encoder import CrossEncoder
|
||||
from sentence_transformers.cross_encoder.evaluation import CEBinaryClassificationEvaluator
|
||||
from sentence_transformers import InputExample
|
||||
import logging
|
||||
from datetime import datetime
|
||||
import gzip
|
||||
import sys
|
||||
import numpy as np
|
||||
import os
|
||||
from shutil import copyfile
|
||||
import csv
|
||||
import tqdm
|
||||
|
||||
#### Just some code to print debug information to stdout
|
||||
logging.basicConfig(format='%(asctime)s - %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S',
|
||||
level=logging.INFO,
|
||||
handlers=[LoggingHandler()])
|
||||
#### /print debug information to stdout
|
||||
|
||||
|
||||
#Define our Cross-Encoder
|
||||
model_name = sys.argv[1] #'google/electra-small-discriminator'
|
||||
train_batch_size = 32
|
||||
num_epochs = 1
|
||||
model_save_path = 'output/training_ms-marco_cross-encoder-'+model_name.replace("/", "-")+'-'+datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
|
||||
#We set num_labels=1, which predicts a continous score between 0 and 1
|
||||
model = CrossEncoder(model_name, num_labels=1, max_length=512)
|
||||
|
||||
|
||||
# Write self to path
|
||||
os.makedirs(model_save_path, exist_ok=True)
|
||||
|
||||
train_script_path = os.path.join(model_save_path, 'train_script.py')
|
||||
copyfile(__file__, train_script_path)
|
||||
with open(train_script_path, 'a') as fOut:
|
||||
fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
|
||||
|
||||
|
||||
corpus = {}
|
||||
queries = {}
|
||||
|
||||
#### Read train file
|
||||
with gzip.open('../data/collection.tsv.gz', 'rt') as fIn:
|
||||
for line in fIn:
|
||||
pid, passage = line.strip().split("\t")
|
||||
corpus[pid] = passage
|
||||
|
||||
with open('../data/queries.train.tsv', 'r') as fIn:
|
||||
for line in fIn:
|
||||
qid, query = line.strip().split("\t")
|
||||
queries[qid] = query
|
||||
|
||||
|
||||
|
||||
pos_neg_ration = (4+1)
|
||||
cnt = 0
|
||||
train_samples = []
|
||||
dev_samples = {}
|
||||
|
||||
num_dev_queries = 125
|
||||
num_max_dev_negatives = 200
|
||||
|
||||
with gzip.open('../data/qidpidtriples.rnd-shuf.train-eval.tsv.gz', 'rt') as fIn:
|
||||
for line in fIn:
|
||||
qid, pos_id, neg_id = line.strip().split()
|
||||
|
||||
if qid not in dev_samples and len(dev_samples) < num_dev_queries:
|
||||
dev_samples[qid] = {'query': queries[qid], 'positive': set(), 'negative': set()}
|
||||
|
||||
if qid in dev_samples:
|
||||
dev_samples[qid]['positive'].add(corpus[pos_id])
|
||||
|
||||
if len(dev_samples[qid]['negative']) < num_max_dev_negatives:
|
||||
dev_samples[qid]['negative'].add(corpus[neg_id])
|
||||
|
||||
with gzip.open('../data/qidpidtriples.rnd-shuf.train.tsv.gz', 'rt') as fIn:
|
||||
for line in tqdm.tqdm(fIn, unit_scale=True):
|
||||
cnt += 1
|
||||
qid, pos_id, neg_id = line.strip().split()
|
||||
query = queries[qid]
|
||||
if (cnt % pos_neg_ration) == 0:
|
||||
passage = corpus[pos_id]
|
||||
label = 1
|
||||
else:
|
||||
passage = corpus[neg_id]
|
||||
label = 0
|
||||
|
||||
train_samples.append(InputExample(texts=[query, passage], label=label))
|
||||
|
||||
if len(train_samples) >= 2e7:
|
||||
break
|
||||
|
||||
|
||||
|
||||
train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size)
|
||||
|
||||
# We add an evaluator, which evaluates the performance during training
|
||||
|
||||
class CERerankingEvaluator:
|
||||
def __init__(self, samples, mrr_at_k: int = 10, name: str = ''):
|
||||
self.samples = samples
|
||||
self.name = name
|
||||
self.mrr_at_k = mrr_at_k
|
||||
|
||||
if isinstance(self.samples, dict):
|
||||
self.samples = list(self.samples.values())
|
||||
|
||||
self.csv_file = "CERerankingEvaluator" + ("_" + name if name else '') + "_results.csv"
|
||||
self.csv_headers = ["epoch", "steps", "MRR@{}".format(mrr_at_k)]
|
||||
|
||||
def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
|
||||
if epoch != -1:
|
||||
if steps == -1:
|
||||
out_txt = " after epoch {}:".format(epoch)
|
||||
else:
|
||||
out_txt = " in epoch {} after {} steps:".format(epoch, steps)
|
||||
else:
|
||||
out_txt = ":"
|
||||
|
||||
logging.info("CERerankingEvaluator: Evaluating the model on " + self.name + " dataset" + out_txt)
|
||||
|
||||
all_mrr_scores = []
|
||||
num_queries = 0
|
||||
num_positives = []
|
||||
num_negatives = []
|
||||
for instance in self.samples:
|
||||
query = instance['query']
|
||||
positive = list(instance['positive'])
|
||||
negative = list(instance['negative'])
|
||||
docs = positive + negative
|
||||
is_relevant = [True]*len(positive) + [False]*len(negative)
|
||||
|
||||
if len(positive) == 0 or len(negative) == 0:
|
||||
continue
|
||||
|
||||
num_queries += 1
|
||||
num_positives.append(len(positive))
|
||||
num_negatives.append(len(negative))
|
||||
|
||||
model_input = [[query, doc] for doc in docs]
|
||||
pred_scores = model.predict(model_input, convert_to_numpy=True, show_progress_bar=False)
|
||||
pred_scores_argsort = np.argsort(-pred_scores) #Sort in decreasing order
|
||||
|
||||
mrr_score = 0
|
||||
for rank, index in enumerate(pred_scores_argsort[0:self.mrr_at_k]):
|
||||
if is_relevant[index]:
|
||||
mrr_score = 1 / (rank+1)
|
||||
|
||||
all_mrr_scores.append(mrr_score)
|
||||
|
||||
mean_mrr = np.mean(all_mrr_scores)
|
||||
logging.info("Queries: {} \t Positives: Min {:.1f}, Mean {:.1f}, Max {:.1f} \t Negatives: Min {:.1f}, Mean {:.1f}, Max {:.1f}".format(num_queries, np.min(num_positives), np.mean(num_positives), np.max(num_positives), np.min(num_negatives), np.mean(num_negatives), np.max(num_negatives)))
|
||||
logging.info("MRR@{}: {:.2f}".format(self.mrr_at_k, mean_mrr*100))
|
||||
|
||||
if output_path is not None:
|
||||
csv_path = os.path.join(output_path, self.csv_file)
|
||||
output_file_exists = os.path.isfile(csv_path)
|
||||
with open(csv_path, mode="a" if output_file_exists else 'w', encoding="utf-8") as f:
|
||||
writer = csv.writer(f)
|
||||
if not output_file_exists:
|
||||
writer.writerow(self.csv_headers)
|
||||
|
||||
writer.writerow([epoch, steps, mean_mrr])
|
||||
|
||||
return mean_mrr
|
||||
|
||||
|
||||
evaluator = CERerankingEvaluator(dev_samples)
|
||||
|
||||
# Configure the training
|
||||
warmup_steps = 5000
|
||||
logging.info("Warmup-steps: {}".format(warmup_steps))
|
||||
|
||||
|
||||
# Train the model
|
||||
model.fit(train_dataloader=train_dataloader,
|
||||
evaluator=evaluator,
|
||||
epochs=num_epochs,
|
||||
evaluation_steps=5000,
|
||||
warmup_steps=warmup_steps,
|
||||
output_path=model_save_path,
|
||||
use_amp=True)
|
||||
|
||||
#Save latest model
|
||||
model.save(model_save_path+'-latest')
|
||||
|
||||
|
||||
# Script was called via:
|
||||
#python train_cross-encoder.py nreimers/TinyBERT_L-6_H-768_v2
|
||||
Reference in New Issue
Block a user