初始化项目,由ModelHub XC社区提供模型

Model: cross-encoder/ms-marco-TinyBERT-L4
Source: Original Platform
This commit is contained in:
ModelHub XC
2026-05-13 16:08:40 +08:00
commit 2ad1aaaa93
25 changed files with 74486 additions and 0 deletions

10
.gitattributes vendored Normal file
View File

@@ -0,0 +1,10 @@
*.bin.* filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tar.gz filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
model.safetensors filter=lfs diff=lfs merge=lfs -text

View File

@@ -0,0 +1,127 @@
epoch,steps,MRR@10
0,5000,0.5220761904761905
0,10000,0.5474539682539683
0,15000,0.5190095238095239
0,20000,0.5671206349206349
0,25000,0.5630952380952381
0,30000,0.565215873015873
0,35000,0.5821079365079366
0,40000,0.5496603174603175
0,45000,0.5716761904761904
0,50000,0.5707079365079364
0,55000,0.5709047619047619
0,60000,0.5643650793650793
0,65000,0.5586507936507936
0,70000,0.5792857142857142
0,75000,0.5964857142857143
0,80000,0.5759936507936507
0,85000,0.5640507936507936
0,90000,0.6016063492063491
0,95000,0.594984126984127
0,100000,0.5770507936507936
0,105000,0.605984126984127
0,110000,0.6106380952380953
0,115000,0.5763650793650794
0,120000,0.5977269841269841
0,125000,0.5764190476190475
0,130000,0.5846825396825397
0,135000,0.5810380952380951
0,140000,0.5902317460317461
0,145000,0.6034063492063492
0,150000,0.5953714285714286
0,155000,0.5992349206349207
0,160000,0.6026666666666668
0,165000,0.6046603174603176
0,170000,0.5939269841269842
0,175000,0.6007714285714286
0,180000,0.574752380952381
0,185000,0.5923619047619049
0,190000,0.600431746031746
0,195000,0.6104984126984127
0,200000,0.6154095238095239
0,205000,0.5908285714285714
0,210000,0.590936507936508
0,215000,0.6043174603174603
0,220000,0.6032825396825396
0,225000,0.6210666666666667
0,230000,0.6113396825396825
0,235000,0.6135873015873016
0,240000,0.6162285714285715
0,245000,0.6064317460317461
0,250000,0.6072285714285715
0,255000,0.6073746031746032
0,260000,0.6112857142857142
0,265000,0.6156412698412698
0,270000,0.6350095238095238
0,275000,0.6074158730158731
0,280000,0.6154761904761905
0,285000,0.6236507936507937
0,290000,0.6162412698412697
0,295000,0.616184126984127
0,300000,0.5997523809523809
0,305000,0.5937492063492064
0,310000,0.6227968253968252
0,315000,0.6274952380952382
0,320000,0.6269523809523809
0,325000,0.6306698412698413
0,330000,0.6235079365079363
0,335000,0.6206190476190477
0,340000,0.6209936507936508
0,345000,0.613095238095238
0,350000,0.6196952380952381
0,355000,0.6197301587301588
0,360000,0.6274634920634922
0,365000,0.6152730158730159
0,370000,0.6053968253968254
0,375000,0.615352380952381
0,380000,0.6110285714285715
0,385000,0.621184126984127
0,390000,0.6025619047619047
0,395000,0.6122507936507936
0,400000,0.6189079365079365
0,405000,0.6252285714285714
0,410000,0.6022634920634921
0,415000,0.6053492063492063
0,420000,0.6239619047619047
0,425000,0.6127523809523809
0,430000,0.6231873015873016
0,435000,0.6233968253968254
0,440000,0.6186825396825396
0,445000,0.6279079365079365
0,450000,0.6075079365079366
0,455000,0.603352380952381
0,460000,0.5917142857142857
0,465000,0.5998285714285714
0,470000,0.5949492063492064
0,475000,0.6139714285714285
0,480000,0.6100507936507936
0,485000,0.6057619047619048
0,490000,0.6255714285714286
0,495000,0.6058158730158729
0,500000,0.63004126984127
0,505000,0.6207269841269841
0,510000,0.6126857142857143
0,515000,0.6224825396825397
0,520000,0.6282730158730159
0,525000,0.6256634920634919
0,530000,0.6199079365079365
0,535000,0.6065555555555556
0,540000,0.6166158730158731
0,545000,0.6133936507936507
0,550000,0.6265428571428572
0,555000,0.6077619047619048
0,560000,0.6010984126984126
0,565000,0.6134158730158731
0,570000,0.6211714285714286
0,575000,0.6167301587301588
0,580000,0.6193968253968253
0,585000,0.605352380952381
0,590000,0.6013523809523811
0,595000,0.6070285714285715
0,600000,0.6075492063492064
0,605000,0.6051396825396825
0,610000,0.609984126984127
0,615000,0.6076412698412699
0,620000,0.604384126984127
0,625000,0.6051396825396825
0,-1,0.6051396825396825
1 epoch steps MRR@10
2 0 5000 0.5220761904761905
3 0 10000 0.5474539682539683
4 0 15000 0.5190095238095239
5 0 20000 0.5671206349206349
6 0 25000 0.5630952380952381
7 0 30000 0.565215873015873
8 0 35000 0.5821079365079366
9 0 40000 0.5496603174603175
10 0 45000 0.5716761904761904
11 0 50000 0.5707079365079364
12 0 55000 0.5709047619047619
13 0 60000 0.5643650793650793
14 0 65000 0.5586507936507936
15 0 70000 0.5792857142857142
16 0 75000 0.5964857142857143
17 0 80000 0.5759936507936507
18 0 85000 0.5640507936507936
19 0 90000 0.6016063492063491
20 0 95000 0.594984126984127
21 0 100000 0.5770507936507936
22 0 105000 0.605984126984127
23 0 110000 0.6106380952380953
24 0 115000 0.5763650793650794
25 0 120000 0.5977269841269841
26 0 125000 0.5764190476190475
27 0 130000 0.5846825396825397
28 0 135000 0.5810380952380951
29 0 140000 0.5902317460317461
30 0 145000 0.6034063492063492
31 0 150000 0.5953714285714286
32 0 155000 0.5992349206349207
33 0 160000 0.6026666666666668
34 0 165000 0.6046603174603176
35 0 170000 0.5939269841269842
36 0 175000 0.6007714285714286
37 0 180000 0.574752380952381
38 0 185000 0.5923619047619049
39 0 190000 0.600431746031746
40 0 195000 0.6104984126984127
41 0 200000 0.6154095238095239
42 0 205000 0.5908285714285714
43 0 210000 0.590936507936508
44 0 215000 0.6043174603174603
45 0 220000 0.6032825396825396
46 0 225000 0.6210666666666667
47 0 230000 0.6113396825396825
48 0 235000 0.6135873015873016
49 0 240000 0.6162285714285715
50 0 245000 0.6064317460317461
51 0 250000 0.6072285714285715
52 0 255000 0.6073746031746032
53 0 260000 0.6112857142857142
54 0 265000 0.6156412698412698
55 0 270000 0.6350095238095238
56 0 275000 0.6074158730158731
57 0 280000 0.6154761904761905
58 0 285000 0.6236507936507937
59 0 290000 0.6162412698412697
60 0 295000 0.616184126984127
61 0 300000 0.5997523809523809
62 0 305000 0.5937492063492064
63 0 310000 0.6227968253968252
64 0 315000 0.6274952380952382
65 0 320000 0.6269523809523809
66 0 325000 0.6306698412698413
67 0 330000 0.6235079365079363
68 0 335000 0.6206190476190477
69 0 340000 0.6209936507936508
70 0 345000 0.613095238095238
71 0 350000 0.6196952380952381
72 0 355000 0.6197301587301588
73 0 360000 0.6274634920634922
74 0 365000 0.6152730158730159
75 0 370000 0.6053968253968254
76 0 375000 0.615352380952381
77 0 380000 0.6110285714285715
78 0 385000 0.621184126984127
79 0 390000 0.6025619047619047
80 0 395000 0.6122507936507936
81 0 400000 0.6189079365079365
82 0 405000 0.6252285714285714
83 0 410000 0.6022634920634921
84 0 415000 0.6053492063492063
85 0 420000 0.6239619047619047
86 0 425000 0.6127523809523809
87 0 430000 0.6231873015873016
88 0 435000 0.6233968253968254
89 0 440000 0.6186825396825396
90 0 445000 0.6279079365079365
91 0 450000 0.6075079365079366
92 0 455000 0.603352380952381
93 0 460000 0.5917142857142857
94 0 465000 0.5998285714285714
95 0 470000 0.5949492063492064
96 0 475000 0.6139714285714285
97 0 480000 0.6100507936507936
98 0 485000 0.6057619047619048
99 0 490000 0.6255714285714286
100 0 495000 0.6058158730158729
101 0 500000 0.63004126984127
102 0 505000 0.6207269841269841
103 0 510000 0.6126857142857143
104 0 515000 0.6224825396825397
105 0 520000 0.6282730158730159
106 0 525000 0.6256634920634919
107 0 530000 0.6199079365079365
108 0 535000 0.6065555555555556
109 0 540000 0.6166158730158731
110 0 545000 0.6133936507936507
111 0 550000 0.6265428571428572
112 0 555000 0.6077619047619048
113 0 560000 0.6010984126984126
114 0 565000 0.6134158730158731
115 0 570000 0.6211714285714286
116 0 575000 0.6167301587301588
117 0 580000 0.6193968253968253
118 0 585000 0.605352380952381
119 0 590000 0.6013523809523811
120 0 595000 0.6070285714285715
121 0 600000 0.6075492063492064
122 0 605000 0.6051396825396825
123 0 610000 0.609984126984127
124 0 615000 0.6076412698412699
125 0 620000 0.604384126984127
126 0 625000 0.6051396825396825
127 0 -1 0.6051396825396825

75
README.md Normal file
View File

@@ -0,0 +1,75 @@
---
license: apache-2.0
datasets:
- sentence-transformers/msmarco
language:
- en
base_model:
- nreimers/TinyBERT_L-4_H-312_v2
pipeline_tag: text-ranking
library_name: sentence-transformers
tags:
- transformers
---
# Cross-Encoder for MS Marco
This model was trained on the [MS Marco Passage Ranking](https://github.com/microsoft/MSMARCO-Passage-Ranking) task.
The model can be used for Information Retrieval: Given a query, encode the query with all possible passages (e.g. retrieved with ElasticSearch). Then sort the passages in decreasing order of their scores. See [SBERT.net Retrieve & Re-rank](https://www.sbert.net/examples/applications/retrieve_rerank/README.html) for more details. The training code is available here: [SBERT.net Training MS Marco](https://github.com/UKPLab/sentence-transformers/tree/master/examples/cross_encoder/training/ms_marco)
## Usage with Transformers
```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
model = AutoModelForSequenceClassification.from_pretrained('cross-encoder/ms-marco-TinyBERT-L4')
tokenizer = AutoTokenizer.from_pretrained('cross-encoder/ms-marco-TinyBERT-L4')
features = tokenizer(['How many people live in Berlin?', 'How many people live in Berlin?'], ['Berlin has a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.', 'New York City is famous for the Metropolitan Museum of Art.'], padding=True, truncation=True, return_tensors="pt")
model.eval()
with torch.no_grad():
scores = model(**features).logits
print(scores)
```
## Usage with SentenceTransformers
The usage becomes easier when you have [SentenceTransformers](https://www.sbert.net/) installed. Then, you can use the pre-trained models like this:
```python
from sentence_transformers import CrossEncoder
model = CrossEncoder('cross-encoder/ms-marco-TinyBERT-L4', max_length=512)
scores = model.predict([('Query', 'Paragraph1'), ('Query', 'Paragraph2') , ('Query', 'Paragraph3')])
```
## Performance
In the following table, we provide various pre-trained Cross-Encoders together with their performance on the [TREC Deep Learning 2019](https://microsoft.github.io/TREC-2019-Deep-Learning/) and the [MS Marco Passage Reranking](https://github.com/microsoft/MSMARCO-Passage-Ranking/) dataset.
| Model-Name | NDCG@10 (TREC DL 19) | MRR@10 (MS Marco Dev) | Docs / Sec |
| ------------- |:-------------| -----| --- |
| **Version 2 models** | | |
| cross-encoder/ms-marco-TinyBERT-L2-v2 | 69.84 | 32.56 | 9000
| cross-encoder/ms-marco-MiniLM-L2-v2 | 71.01 | 34.85 | 4100
| cross-encoder/ms-marco-MiniLM-L4-v2 | 73.04 | 37.70 | 2500
| cross-encoder/ms-marco-MiniLM-L6-v2 | 74.30 | 39.01 | 1800
| cross-encoder/ms-marco-MiniLM-L12-v2 | 74.31 | 39.02 | 960
| **Version 1 models** | | |
| cross-encoder/ms-marco-TinyBERT-L2 | 67.43 | 30.15 | 9000
| cross-encoder/ms-marco-TinyBERT-L4 | 68.09 | 34.50 | 2900
| cross-encoder/ms-marco-TinyBERT-L6 | 69.57 | 36.13 | 680
| cross-encoder/ms-marco-electra-base | 71.99 | 36.41 | 340
| **Other models** | | |
| nboost/pt-tinybert-msmarco | 63.63 | 28.80 | 2900
| nboost/pt-bert-base-uncased-msmarco | 70.94 | 34.75 | 340
| nboost/pt-bert-large-msmarco | 73.36 | 36.48 | 100
| Capreolus/electra-base-msmarco | 71.23 | 36.89 | 340
| amberoad/bert-multilingual-passage-reranking-msmarco | 68.40 | 35.54 | 330
| sebastian-hofstaetter/distilbert-cat-margin_mse-T2-msmarco | 72.82 | 37.88 | 720
Note: Runtime was computed on a V100 GPU.

27
config.json Normal file
View File

@@ -0,0 +1,27 @@
{
"_name_or_path": "nreimers/TinyBERT_L-4_H-312_v2",
"architectures": [
"BertForSequenceClassification"
],
"attention_probs_dropout_prob": 0.1,
"gradient_checkpointing": false,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 312,
"id2label": {
"0": "LABEL_0"
},
"initializer_range": 0.02,
"intermediate_size": 1200,
"label2id": {
"LABEL_0": 0
},
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 12,
"num_hidden_layers": 4,
"pad_token_id": 0,
"type_vocab_size": 2,
"vocab_size": 30522
}

3
flax_model.msgpack Normal file
View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5de37974f4240be0be8cb67d0bd011c96a84e0af7eefc45129d8e95bb8b7cfa2
size 57404914

3
model.safetensors Normal file
View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c252c9e3b6b9643359d9530048d44e811e7515cc432b192292eb2768722da835
size 57414738

3
onnx/model.onnx Normal file
View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:12bf366208f8f62b51d31c1dde6442b62215878ff7e6aec7dad24b365eed0279
size 57510828

3
onnx/model_O1.onnx Normal file
View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c82ce80ac1824aad347af9e792ae60955708b83ad081095b36086f8c25f7fafa
size 57478044

3
onnx/model_O2.onnx Normal file
View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8ea522783cea175573e286e6ebe8cf89ac7b0014a6dd532297c7ec76fab5711b
size 57420067

3
onnx/model_O3.onnx Normal file
View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:51a85a89a75a8543dfcc7a7f2eca8a929b0877608b37704fc0e820b0c7b7f8b3
size 57420025

3
onnx/model_O4.onnx Normal file
View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:99f2b42ebef41788c4b88ad61346ef3cf210f112dca9dec44e49c8ab06758d09
size 28764362

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:be6cd308a6807e03d6ff90bf3666cb4ee4d55352e61fbfa5f5351e4e3d9575cc
size 14656537

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:be6cd308a6807e03d6ff90bf3666cb4ee4d55352e61fbfa5f5351e4e3d9575cc
size 14656537

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:be6cd308a6807e03d6ff90bf3666cb4ee4d55352e61fbfa5f5351e4e3d9575cc
size 14656537

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:40738175864e381e3b0d0d1f172916456c4067f9cf4dc6e7bb799d84abd72253
size 14656535

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:09c8324cc36f5b5f04d7cc31491aebd46a9241fea6bc3dad8c9ddf55487d4812
size 57406500

4753
openvino/openvino_model.xml Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:85759cbab01f9df1bbd135d51138639208815c69a1261b6834c59c11e960dcbe
size 14612160

File diff suppressed because it is too large Load Diff

3
pytorch_model.bin Normal file
View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b06785cbb737ac18adf49a4ac3ce4d724479afc9c133d5e83172b12db996dc91
size 57436041

7
special_tokens_map.json Normal file
View File

@@ -0,0 +1,7 @@
{
"cls_token": "[CLS]",
"mask_token": "[MASK]",
"pad_token": "[PAD]",
"sep_token": "[SEP]",
"unk_token": "[UNK]"
}

30672
tokenizer.json Normal file

File diff suppressed because it is too large Load Diff

58
tokenizer_config.json Normal file
View File

@@ -0,0 +1,58 @@
{
"added_tokens_decoder": {
"0": {
"content": "[PAD]",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"100": {
"content": "[UNK]",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"101": {
"content": "[CLS]",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"102": {
"content": "[SEP]",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"103": {
"content": "[MASK]",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
}
},
"clean_up_tokenization_spaces": true,
"cls_token": "[CLS]",
"do_basic_tokenize": true,
"do_lower_case": true,
"extra_special_tokens": {},
"mask_token": "[MASK]",
"model_max_length": 512,
"never_split": null,
"pad_token": "[PAD]",
"sep_token": "[SEP]",
"strip_accents": null,
"tokenize_chinese_chars": true,
"tokenizer_class": "BertTokenizer",
"unk_token": "[UNK]"
}

192
train_script.py Normal file
View File

@@ -0,0 +1,192 @@
from torch.utils.data import DataLoader
from sentence_transformers import LoggingHandler
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CEBinaryClassificationEvaluator
from sentence_transformers import InputExample
import logging
from datetime import datetime
import gzip
import sys
import numpy as np
import os
from shutil import copyfile
import csv
import tqdm
# Logging setup: timestamped messages at INFO level, emitted via LoggingHandler.
logging.basicConfig(
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.INFO,
    handlers=[LoggingHandler()],
)

# Cross-Encoder configuration. The base model name is the first command-line
# argument, e.g. 'google/electra-small-discriminator'.
model_name = sys.argv[1]
train_batch_size = 32
num_epochs = 1

# The output directory name embeds the model name ("/" replaced by "-") and
# the start time of this run.
run_timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
model_save_path = 'output/training_ms-marco_cross-encoder-' + model_name.replace("/", "-") + '-' + run_timestamp

# We set num_labels=1, which predicts a continuous score between 0 and 1
model = CrossEncoder(model_name, num_labels=1, max_length=512)

# Store a copy of this script next to the model and append the exact command
# line it was started with.
os.makedirs(model_save_path, exist_ok=True)
train_script_path = os.path.join(model_save_path, 'train_script.py')
copyfile(__file__, train_script_path)
invocation = " ".join(sys.argv)
with open(train_script_path, 'a') as script_copy:
    script_copy.write("\n\n# Script was called via:\n#python " + invocation)
corpus = {}   # passage id (str) -> passage text
queries = {}  # query id (str) -> query text

#### Read train file
# Both files are decoded explicitly as UTF-8. Without an explicit encoding,
# text mode falls back to the platform locale, which can raise
# UnicodeDecodeError or silently mis-decode non-ASCII passages.
with gzip.open('../data/collection.tsv.gz', 'rt', encoding='utf8') as fIn:
    for line in fIn:
        # Each line is "<pid>\t<passage>"
        pid, passage = line.strip().split("\t")
        corpus[pid] = passage

with open('../data/queries.train.tsv', 'r', encoding='utf8') as fIn:
    for line in fIn:
        # Each line is "<qid>\t<query>"
        qid, query = line.strip().split("\t")
        queries[qid] = query
# Every pos_neg_ration-th training triple contributes its positive passage;
# all other triples contribute their negative passage.
pos_neg_ration = 4 + 1
cnt = 0
train_samples = []

# dev_samples maps qid -> {'query': str, 'positive': set, 'negative': set}.
dev_samples = {}
num_dev_queries = 125          # maximum number of distinct dev queries
num_max_dev_negatives = 200    # maximum number of negatives kept per dev query

with gzip.open('../data/qidpidtriples.rnd-shuf.train-eval.tsv.gz', 'rt') as fIn:
    for line in fIn:
        qid, pos_id, neg_id = line.strip().split()

        dev_entry = dev_samples.get(qid)
        if dev_entry is None:
            # Only the first num_dev_queries distinct queries are collected;
            # triples of any later query are ignored.
            if len(dev_samples) >= num_dev_queries:
                continue
            dev_entry = {'query': queries[qid], 'positive': set(), 'negative': set()}
            dev_samples[qid] = dev_entry

        dev_entry['positive'].add(corpus[pos_id])
        if len(dev_entry['negative']) < num_max_dev_negatives:
            dev_entry['negative'].add(corpus[neg_id])
# Turn the (qid, positive pid, negative pid) triples into labelled
# (query, passage) pairs. Only one passage of each triple is used.
with gzip.open('../data/qidpidtriples.rnd-shuf.train.tsv.gz', 'rt') as fIn:
    for line in tqdm.tqdm(fIn, unit_scale=True):
        cnt += 1
        qid, pos_id, neg_id = line.strip().split()
        query = queries[qid]

        # Every pos_neg_ration-th triple yields a positive pair (label 1),
        # the rest yield a negative pair (label 0).
        take_positive = (cnt % pos_neg_ration) == 0
        passage = corpus[pos_id] if take_positive else corpus[neg_id]
        label = 1 if take_positive else 0

        train_samples.append(InputExample(texts=[query, passage], label=label))

        # Stop reading once 20 million training pairs were collected
        if len(train_samples) >= 2e7:
            break

train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size)
# We add an evaluator, which evaluates the performance during training
class CERerankingEvaluator:
    """Evaluate a cross-encoder on a re-ranking task and report MRR@k.

    For every sample, all positive and negative passages are scored against
    the query, sorted by decreasing score, and the reciprocal rank of the
    first relevant passage within the top ``mrr_at_k`` results is recorded
    (0 if none of the top ``mrr_at_k`` passages is relevant).

    :param samples: list of dicts, or a dict whose values are such dicts, each
        with the keys ``'query'``, ``'positive'`` and ``'negative'``
        (``'positive'`` / ``'negative'`` are iterables of passages).
    :param mrr_at_k: number of top-ranked passages considered per query.
    :param name: optional name used in log messages and in the CSV file name.
    """

    def __init__(self, samples, mrr_at_k: int = 10, name: str = ''):
        self.samples = samples
        self.name = name
        self.mrr_at_k = mrr_at_k

        # A dict (e.g. keyed by query id) is accepted; only its values are used.
        if isinstance(self.samples, dict):
            self.samples = list(self.samples.values())

        self.csv_file = "CERerankingEvaluator" + ("_" + name if name else '') + "_results.csv"
        self.csv_headers = ["epoch", "steps", "MRR@{}".format(mrr_at_k)]

    def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
        """Score all samples with ``model`` and return the mean MRR@k.

        :param model: object with a ``predict(pairs, convert_to_numpy=True,
            show_progress_bar=False)`` method returning one score per
            ``[query, passage]`` pair; higher scores rank first.
        :param output_path: directory in which the result row is appended to
            ``self.csv_file``; nothing is written when ``None``.
        :param epoch: epoch number used for logging and the CSV row (-1 = unset).
        :param steps: step number used for logging and the CSV row (-1 = unset).
        :return: mean reciprocal rank over all evaluated queries, as a
            fraction (not a percentage).
        """
        if epoch != -1:
            if steps == -1:
                out_txt = " after epoch {}:".format(epoch)
            else:
                out_txt = " in epoch {} after {} steps:".format(epoch, steps)
        else:
            out_txt = ":"

        logging.info("CERerankingEvaluator: Evaluating the model on " + self.name + " dataset" + out_txt)

        all_mrr_scores = []
        num_queries = 0
        num_positives = []
        num_negatives = []
        for instance in self.samples:
            query = instance['query']
            positive = list(instance['positive'])
            negative = list(instance['negative'])
            docs = positive + negative
            is_relevant = [True]*len(positive) + [False]*len(negative)

            # Ranking is only meaningful with at least one passage of each kind
            if len(positive) == 0 or len(negative) == 0:
                continue

            num_queries += 1
            num_positives.append(len(positive))
            num_negatives.append(len(negative))

            model_input = [[query, doc] for doc in docs]
            pred_scores = model.predict(model_input, convert_to_numpy=True, show_progress_bar=False)
            pred_scores_argsort = np.argsort(-pred_scores)  #Sort in decreasing order

            mrr_score = 0
            for rank, index in enumerate(pred_scores_argsort[0:self.mrr_at_k]):
                if is_relevant[index]:
                    mrr_score = 1 / (rank+1)
                    # MRR uses the FIRST relevant hit. Without this break a
                    # lower-ranked positive would overwrite the score and
                    # under-report the metric.
                    break

            all_mrr_scores.append(mrr_score)

        mean_mrr = np.mean(all_mrr_scores)
        logging.info("Queries: {} \t Positives: Min {:.1f}, Mean {:.1f}, Max {:.1f} \t Negatives: Min {:.1f}, Mean {:.1f}, Max {:.1f}".format(num_queries, np.min(num_positives), np.mean(num_positives), np.max(num_positives), np.min(num_negatives), np.mean(num_negatives), np.max(num_negatives)))
        logging.info("MRR@{}: {:.2f}".format(self.mrr_at_k, mean_mrr*100))

        if output_path is not None:
            csv_path = os.path.join(output_path, self.csv_file)
            output_file_exists = os.path.isfile(csv_path)
            # newline="" lets the csv module control line endings itself
            # (otherwise extra blank lines appear on some platforms).
            with open(csv_path, mode="a" if output_file_exists else 'w', encoding="utf-8", newline="") as f:
                writer = csv.writer(f)
                # The header is only written when the file is created
                if not output_file_exists:
                    writer.writerow(self.csv_headers)

                writer.writerow([epoch, steps, mean_mrr])

        return mean_mrr
# Evaluator that re-ranks the held-out dev queries during training
evaluator = CERerankingEvaluator(dev_samples)

# Configure the training
warmup_steps = 5000
logging.info("Warmup-steps: {}".format(warmup_steps))

# Train the model; the evaluator is run every 5000 steps
fit_options = dict(
    train_dataloader=train_dataloader,
    evaluator=evaluator,
    epochs=num_epochs,
    evaluation_steps=5000,
    warmup_steps=warmup_steps,
    output_path=model_save_path,
    use_amp=True,
)
model.fit(**fit_options)

# Save the final state of the model separately, under the "-latest" suffix
latest_model_path = model_save_path+'-latest'
model.save(latest_model_path)
# Script was called via:
#python train_cross-encoder.py nreimers/TinyBERT_L-4_H-312_v2

30522
vocab.txt Normal file

File diff suppressed because it is too large Load Diff