初始化项目,由ModelHub XC社区提供模型

Model: cross-encoder/ms-marco-TinyBERT-L6
Source: Original Platform
This commit is contained in:
ModelHub XC
2026-05-13 16:14:35 +08:00
commit c692b1463f
25 changed files with 79699 additions and 0 deletions

10
.gitattributes vendored Normal file
View File

@@ -0,0 +1,10 @@
*.bin.* filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tar.gz filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
model.safetensors filter=lfs diff=lfs merge=lfs -text

View File

@@ -0,0 +1,127 @@
epoch,steps,MRR@10
0,5000,0.5650095238095237
0,10000,0.5849968253968254
0,15000,0.6097650793650794
0,20000,0.6246285714285715
0,25000,0.6100253968253967
0,30000,0.6270730158730159
0,35000,0.6138888888888888
0,40000,0.6240317460317462
0,45000,0.6327619047619049
0,50000,0.619631746031746
0,55000,0.5871142857142856
0,60000,0.6175809523809525
0,65000,0.6081968253968254
0,70000,0.6151301587301587
0,75000,0.6093269841269842
0,80000,0.6032571428571428
0,85000,0.6138063492063491
0,90000,0.6156952380952381
0,95000,0.6303523809523809
0,100000,0.6061523809523809
0,105000,0.6133174603174603
0,110000,0.6226063492063493
0,115000,0.6176349206349206
0,120000,0.6104761904761905
0,125000,0.6332253968253967
0,130000,0.6289523809523808
0,135000,0.6181809523809524
0,140000,0.6399841269841271
0,145000,0.623073015873016
0,150000,0.5963587301587302
0,155000,0.6157301587301588
0,160000,0.613120634920635
0,165000,0.6089936507936508
0,170000,0.6203301587301587
0,175000,0.6171269841269841
0,180000,0.5939269841269841
0,185000,0.6417873015873015
0,190000,0.6164476190476191
0,195000,0.6215841269841269
0,200000,0.6298984126984126
0,205000,0.6030507936507936
0,210000,0.6084730158730158
0,215000,0.6092730158730159
0,220000,0.5939650793650793
0,225000,0.6124190476190475
0,230000,0.6039269841269841
0,235000,0.6253301587301587
0,240000,0.634904761904762
0,245000,0.6317015873015873
0,250000,0.6196603174603175
0,255000,0.6287396825396825
0,260000,0.6095746031746031
0,265000,0.6263492063492063
0,270000,0.6171079365079365
0,275000,0.6289523809523809
0,280000,0.6202634920634921
0,285000,0.6255301587301587
0,290000,0.5993841269841268
0,295000,0.6191841269841271
0,300000,0.6203396825396825
0,305000,0.6128412698412699
0,310000,0.6090825396825398
0,315000,0.5950539682539682
0,320000,0.5990444444444444
0,325000,0.6042412698412698
0,330000,0.5960190476190476
0,335000,0.6106222222222223
0,340000,0.6055968253968255
0,345000,0.5984095238095238
0,350000,0.6142984126984128
0,355000,0.6137746031746032
0,360000,0.6018412698412698
0,365000,0.6123079365079365
0,370000,0.6130285714285715
0,375000,0.6008412698412698
0,380000,0.6020698412698412
0,385000,0.6100222222222222
0,390000,0.5971650793650793
0,395000,0.5941968253968255
0,400000,0.5871428571428571
0,405000,0.6100190476190476
0,410000,0.5903174603174602
0,415000,0.5988317460317459
0,420000,0.6132380952380952
0,425000,0.6144412698412698
0,430000,0.5980888888888888
0,435000,0.5973746031746032
0,440000,0.595384126984127
0,445000,0.5871714285714286
0,450000,0.6012412698412699
0,455000,0.5873047619047618
0,460000,0.595584126984127
0,465000,0.5804285714285713
0,470000,0.5887619047619047
0,475000,0.5872761904761904
0,480000,0.5871396825396825
0,485000,0.5907174603174602
0,490000,0.5880412698412699
0,495000,0.5807968253968254
0,500000,0.5909746031746032
0,505000,0.5912984126984128
0,510000,0.5942761904761905
0,515000,0.5840222222222223
0,520000,0.5852380952380952
0,525000,0.582784126984127
0,530000,0.5916190476190476
0,535000,0.5777269841269841
0,540000,0.582120634920635
0,545000,0.5746634920634921
0,550000,0.5746444444444445
0,555000,0.5632444444444444
0,560000,0.5799650793650795
0,565000,0.5932507936507936
0,570000,0.5816190476190476
0,575000,0.5838857142857143
0,580000,0.5859650793650794
0,585000,0.5843968253968255
0,590000,0.5840634920634921
0,595000,0.5958285714285714
0,600000,0.5842857142857142
0,605000,0.5892507936507937
0,610000,0.5914507936507937
0,615000,0.5953968253968254
0,620000,0.5925174603174603
0,625000,0.5890857142857143
0,-1,0.5890857142857143
1 epoch steps MRR@10
2 0 5000 0.5650095238095237
3 0 10000 0.5849968253968254
4 0 15000 0.6097650793650794
5 0 20000 0.6246285714285715
6 0 25000 0.6100253968253967
7 0 30000 0.6270730158730159
8 0 35000 0.6138888888888888
9 0 40000 0.6240317460317462
10 0 45000 0.6327619047619049
11 0 50000 0.619631746031746
12 0 55000 0.5871142857142856
13 0 60000 0.6175809523809525
14 0 65000 0.6081968253968254
15 0 70000 0.6151301587301587
16 0 75000 0.6093269841269842
17 0 80000 0.6032571428571428
18 0 85000 0.6138063492063491
19 0 90000 0.6156952380952381
20 0 95000 0.6303523809523809
21 0 100000 0.6061523809523809
22 0 105000 0.6133174603174603
23 0 110000 0.6226063492063493
24 0 115000 0.6176349206349206
25 0 120000 0.6104761904761905
26 0 125000 0.6332253968253967
27 0 130000 0.6289523809523808
28 0 135000 0.6181809523809524
29 0 140000 0.6399841269841271
30 0 145000 0.623073015873016
31 0 150000 0.5963587301587302
32 0 155000 0.6157301587301588
33 0 160000 0.613120634920635
34 0 165000 0.6089936507936508
35 0 170000 0.6203301587301587
36 0 175000 0.6171269841269841
37 0 180000 0.5939269841269841
38 0 185000 0.6417873015873015
39 0 190000 0.6164476190476191
40 0 195000 0.6215841269841269
41 0 200000 0.6298984126984126
42 0 205000 0.6030507936507936
43 0 210000 0.6084730158730158
44 0 215000 0.6092730158730159
45 0 220000 0.5939650793650793
46 0 225000 0.6124190476190475
47 0 230000 0.6039269841269841
48 0 235000 0.6253301587301587
49 0 240000 0.634904761904762
50 0 245000 0.6317015873015873
51 0 250000 0.6196603174603175
52 0 255000 0.6287396825396825
53 0 260000 0.6095746031746031
54 0 265000 0.6263492063492063
55 0 270000 0.6171079365079365
56 0 275000 0.6289523809523809
57 0 280000 0.6202634920634921
58 0 285000 0.6255301587301587
59 0 290000 0.5993841269841268
60 0 295000 0.6191841269841271
61 0 300000 0.6203396825396825
62 0 305000 0.6128412698412699
63 0 310000 0.6090825396825398
64 0 315000 0.5950539682539682
65 0 320000 0.5990444444444444
66 0 325000 0.6042412698412698
67 0 330000 0.5960190476190476
68 0 335000 0.6106222222222223
69 0 340000 0.6055968253968255
70 0 345000 0.5984095238095238
71 0 350000 0.6142984126984128
72 0 355000 0.6137746031746032
73 0 360000 0.6018412698412698
74 0 365000 0.6123079365079365
75 0 370000 0.6130285714285715
76 0 375000 0.6008412698412698
77 0 380000 0.6020698412698412
78 0 385000 0.6100222222222222
79 0 390000 0.5971650793650793
80 0 395000 0.5941968253968255
81 0 400000 0.5871428571428571
82 0 405000 0.6100190476190476
83 0 410000 0.5903174603174602
84 0 415000 0.5988317460317459
85 0 420000 0.6132380952380952
86 0 425000 0.6144412698412698
87 0 430000 0.5980888888888888
88 0 435000 0.5973746031746032
89 0 440000 0.595384126984127
90 0 445000 0.5871714285714286
91 0 450000 0.6012412698412699
92 0 455000 0.5873047619047618
93 0 460000 0.595584126984127
94 0 465000 0.5804285714285713
95 0 470000 0.5887619047619047
96 0 475000 0.5872761904761904
97 0 480000 0.5871396825396825
98 0 485000 0.5907174603174602
99 0 490000 0.5880412698412699
100 0 495000 0.5807968253968254
101 0 500000 0.5909746031746032
102 0 505000 0.5912984126984128
103 0 510000 0.5942761904761905
104 0 515000 0.5840222222222223
105 0 520000 0.5852380952380952
106 0 525000 0.582784126984127
107 0 530000 0.5916190476190476
108 0 535000 0.5777269841269841
109 0 540000 0.582120634920635
110 0 545000 0.5746634920634921
111 0 550000 0.5746444444444445
112 0 555000 0.5632444444444444
113 0 560000 0.5799650793650795
114 0 565000 0.5932507936507936
115 0 570000 0.5816190476190476
116 0 575000 0.5838857142857143
117 0 580000 0.5859650793650794
118 0 585000 0.5843968253968255
119 0 590000 0.5840634920634921
120 0 595000 0.5958285714285714
121 0 600000 0.5842857142857142
122 0 605000 0.5892507936507937
123 0 610000 0.5914507936507937
124 0 615000 0.5953968253968254
125 0 620000 0.5925174603174603
126 0 625000 0.5890857142857143
127 0 -1 0.5890857142857143

75
README.md Normal file
View File

@@ -0,0 +1,75 @@
---
license: apache-2.0
datasets:
- sentence-transformers/msmarco
language:
- en
base_model:
- nreimers/TinyBERT_L-6_H-768_v2
pipeline_tag: text-ranking
library_name: sentence-transformers
tags:
- transformers
---
# Cross-Encoder for MS Marco
This model was trained on the [MS Marco Passage Ranking](https://github.com/microsoft/MSMARCO-Passage-Ranking) task.
The model can be used for Information Retrieval: Given a query, encode the query will all possible passages (e.g. retrieved with ElasticSearch). Then sort the passages in a decreasing order. See [SBERT.net Retrieve & Re-rank](https://www.sbert.net/examples/applications/retrieve_rerank/README.html) for more details. The training code is available here: [SBERT.net Training MS Marco](https://github.com/UKPLab/sentence-transformers/tree/master/examples/cross_encoder/training/ms_marco)
## Usage with Transformers
```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
model = AutoModelForSequenceClassification.from_pretrained('cross-encoder/ms-marco-TinyBERT-L6')
tokenizer = AutoTokenizer.from_pretrained('cross-encoder/ms-marco-TinyBERT-L6')
features = tokenizer(['How many people live in Berlin?', 'How many people live in Berlin?'], ['Berlin has a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.', 'New York City is famous for the Metropolitan Museum of Art.'], padding=True, truncation=True, return_tensors="pt")
model.eval()
with torch.no_grad():
scores = model(**features).logits
print(scores)
```
## Usage with SentenceTransformers
The usage becomes easier when you have [SentenceTransformers](https://www.sbert.net/) installed. Then, you can use the pre-trained models like this:
```python
from sentence_transformers import CrossEncoder
model = CrossEncoder('cross-encoder/ms-marco-TinyBERT-L6', max_length=512)
scores = model.predict([('Query', 'Paragraph1'), ('Query', 'Paragraph2') , ('Query', 'Paragraph3')])
```
## Performance
In the following table, we provide various pre-trained Cross-Encoders together with their performance on the [TREC Deep Learning 2019](https://microsoft.github.io/TREC-2019-Deep-Learning/) and the [MS Marco Passage Reranking](https://github.com/microsoft/MSMARCO-Passage-Ranking/) dataset.
| Model-Name | NDCG@10 (TREC DL 19) | MRR@10 (MS Marco Dev) | Docs / Sec |
| ------------- |:-------------| -----| --- |
| **Version 2 models** | | |
| cross-encoder/ms-marco-TinyBERT-L2-v2 | 69.84 | 32.56 | 9000
| cross-encoder/ms-marco-MiniLM-L2-v2 | 71.01 | 34.85 | 4100
| cross-encoder/ms-marco-MiniLM-L4-v2 | 73.04 | 37.70 | 2500
| cross-encoder/ms-marco-MiniLM-L6-v2 | 74.30 | 39.01 | 1800
| cross-encoder/ms-marco-MiniLM-L12-v2 | 74.31 | 39.02 | 960
| **Version 1 models** | | |
| cross-encoder/ms-marco-TinyBERT-L2 | 67.43 | 30.15 | 9000
| cross-encoder/ms-marco-TinyBERT-L4 | 68.09 | 34.50 | 2900
| cross-encoder/ms-marco-TinyBERT-L6 | 69.57 | 36.13 | 680
| cross-encoder/ms-marco-electra-base | 71.99 | 36.41 | 340
| **Other models** | | |
| nboost/pt-tinybert-msmarco | 63.63 | 28.80 | 2900
| nboost/pt-bert-base-uncased-msmarco | 70.94 | 34.75 | 340
| nboost/pt-bert-large-msmarco | 73.36 | 36.48 | 100
| Capreolus/electra-base-msmarco | 71.23 | 36.89 | 340
| amberoad/bert-multilingual-passage-reranking-msmarco | 68.40 | 35.54 | 330
| sebastian-hofstaetter/distilbert-cat-margin_mse-T2-msmarco | 72.82 | 37.88 | 720
Note: Runtime was computed on a V100 GPU.

27
config.json Normal file
View File

@@ -0,0 +1,27 @@
{
"_name_or_path": "nreimers/TinyBERT_L-6_H-768_v2",
"architectures": [
"BertForSequenceClassification"
],
"attention_probs_dropout_prob": 0.1,
"gradient_checkpointing": false,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"id2label": {
"0": "LABEL_0"
},
"initializer_range": 0.02,
"intermediate_size": 3072,
"label2id": {
"LABEL_0": 0
},
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 12,
"num_hidden_layers": 6,
"pad_token_id": 0,
"type_vocab_size": 2,
"vocab_size": 30522
}

3
flax_model.msgpack Normal file
View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8dca17d03d10927dc2cae61af57f617da0cbeb4cf86318fa6e4f866ebe177359
size 267826914

3
model.safetensors Normal file
View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:44e975b3a25903d3d9cefbddc06505a0d6cc2eb8ff311f0da6ace9118b366753
size 267839500

3
onnx/model.onnx Normal file
View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:76285b39ad4f9fe2638aeb4d4373e839644b6a46c57a94e7df34aa2f1d5910a5
size 267979984

3
onnx/model_O1.onnx Normal file
View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:41d359e2246be0906dd9d02ff71129f332e4814d9d7eb27752da54d6de3a1ca8
size 267931340

3
onnx/model_O2.onnx Normal file
View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:58c001c932ae5f292b7b00069bb30eae34f9e0418a1c91b3e325d4716f148c6d
size 267846024

3
onnx/model_O3.onnx Normal file
View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ac0afd86cd65c5f238591501b55a156d6ab69979236ca1f44c9e501864dff16b
size 267845955

3
onnx/model_O4.onnx Normal file
View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a36dfbc48c898574fd6e0cae98a4ae3ca2e049669e783805180fd6d19fd8934a
size 134000992

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7bf5d227ffe3b070d4704a1646613ddf6df77a120ea683369583e144fcf4f12e
size 67641829

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7bf5d227ffe3b070d4704a1646613ddf6df77a120ea683369583e144fcf4f12e
size 67641829

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7bf5d227ffe3b070d4704a1646613ddf6df77a120ea683369583e144fcf4f12e
size 67641829

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6ab7d03ca90b69264bdaa5014ce9d8beef6df39afd3eccd2734f75bd251299bd
size 67641828

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2bfbc62d0401ad098b9ab737425ad055ebd0d9fda68a7da4847c1e992c4afddb
size 267827364

6647
openvino/openvino_model.xml Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d57aa37147cef8bca75421c893d64a50eb6a3fba7b2151d3c05166908f984fd3
size 67575240

File diff suppressed because it is too large Load Diff

3
pytorch_model.bin Normal file
View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0cf70691b32e33dc1a57845ceeeb81bc70cfec62a0d584d063f13576403f2759
size 267871721

7
special_tokens_map.json Normal file
View File

@@ -0,0 +1,7 @@
{
"cls_token": "[CLS]",
"mask_token": "[MASK]",
"pad_token": "[PAD]",
"sep_token": "[SEP]",
"unk_token": "[UNK]"
}

30672
tokenizer.json Normal file

File diff suppressed because it is too large Load Diff

58
tokenizer_config.json Normal file
View File

@@ -0,0 +1,58 @@
{
"added_tokens_decoder": {
"0": {
"content": "[PAD]",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"100": {
"content": "[UNK]",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"101": {
"content": "[CLS]",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"102": {
"content": "[SEP]",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"103": {
"content": "[MASK]",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
}
},
"clean_up_tokenization_spaces": true,
"cls_token": "[CLS]",
"do_basic_tokenize": true,
"do_lower_case": true,
"extra_special_tokens": {},
"mask_token": "[MASK]",
"model_max_length": 512,
"never_split": null,
"pad_token": "[PAD]",
"sep_token": "[SEP]",
"strip_accents": null,
"tokenize_chinese_chars": true,
"tokenizer_class": "BertTokenizer",
"unk_token": "[UNK]"
}

193
train_script.py Normal file
View File

@@ -0,0 +1,193 @@
from torch.utils.data import DataLoader
from sentence_transformers import LoggingHandler
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CEBinaryClassificationEvaluator
from sentence_transformers import InputExample
import logging
from datetime import datetime
import gzip
import sys
import numpy as np
import os
from shutil import copyfile
import csv
import tqdm
#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
level=logging.INFO,
handlers=[LoggingHandler()])
#### /print debug information to stdout
#Define our Cross-Encoder
model_name = sys.argv[1] #'google/electra-small-discriminator'
train_batch_size = 32
num_epochs = 1
model_save_path = 'output/training_ms-marco_cross-encoder-'+model_name.replace("/", "-")+'-'+datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
#We set num_labels=1, which predicts a continous score between 0 and 1
model = CrossEncoder(model_name, num_labels=1, max_length=512)
# Write self to path
os.makedirs(model_save_path, exist_ok=True)
train_script_path = os.path.join(model_save_path, 'train_script.py')
copyfile(__file__, train_script_path)
with open(train_script_path, 'a') as fOut:
fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
corpus = {}
queries = {}
#### Read train file
with gzip.open('../data/collection.tsv.gz', 'rt') as fIn:
for line in fIn:
pid, passage = line.strip().split("\t")
corpus[pid] = passage
with open('../data/queries.train.tsv', 'r') as fIn:
for line in fIn:
qid, query = line.strip().split("\t")
queries[qid] = query
pos_neg_ration = (4+1)
cnt = 0
train_samples = []
dev_samples = {}
num_dev_queries = 125
num_max_dev_negatives = 200
with gzip.open('../data/qidpidtriples.rnd-shuf.train-eval.tsv.gz', 'rt') as fIn:
for line in fIn:
qid, pos_id, neg_id = line.strip().split()
if qid not in dev_samples and len(dev_samples) < num_dev_queries:
dev_samples[qid] = {'query': queries[qid], 'positive': set(), 'negative': set()}
if qid in dev_samples:
dev_samples[qid]['positive'].add(corpus[pos_id])
if len(dev_samples[qid]['negative']) < num_max_dev_negatives:
dev_samples[qid]['negative'].add(corpus[neg_id])
with gzip.open('../data/qidpidtriples.rnd-shuf.train.tsv.gz', 'rt') as fIn:
for line in tqdm.tqdm(fIn, unit_scale=True):
cnt += 1
qid, pos_id, neg_id = line.strip().split()
query = queries[qid]
if (cnt % pos_neg_ration) == 0:
passage = corpus[pos_id]
label = 1
else:
passage = corpus[neg_id]
label = 0
train_samples.append(InputExample(texts=[query, passage], label=label))
if len(train_samples) >= 2e7:
break
train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size)
# We add an evaluator, which evaluates the performance during training
class CERerankingEvaluator:
def __init__(self, samples, mrr_at_k: int = 10, name: str = ''):
self.samples = samples
self.name = name
self.mrr_at_k = mrr_at_k
if isinstance(self.samples, dict):
self.samples = list(self.samples.values())
self.csv_file = "CERerankingEvaluator" + ("_" + name if name else '') + "_results.csv"
self.csv_headers = ["epoch", "steps", "MRR@{}".format(mrr_at_k)]
def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
if epoch != -1:
if steps == -1:
out_txt = " after epoch {}:".format(epoch)
else:
out_txt = " in epoch {} after {} steps:".format(epoch, steps)
else:
out_txt = ":"
logging.info("CERerankingEvaluator: Evaluating the model on " + self.name + " dataset" + out_txt)
all_mrr_scores = []
num_queries = 0
num_positives = []
num_negatives = []
for instance in self.samples:
query = instance['query']
positive = list(instance['positive'])
negative = list(instance['negative'])
docs = positive + negative
is_relevant = [True]*len(positive) + [False]*len(negative)
if len(positive) == 0 or len(negative) == 0:
continue
num_queries += 1
num_positives.append(len(positive))
num_negatives.append(len(negative))
model_input = [[query, doc] for doc in docs]
pred_scores = model.predict(model_input, convert_to_numpy=True, show_progress_bar=False)
pred_scores_argsort = np.argsort(-pred_scores) #Sort in decreasing order
mrr_score = 0
for rank, index in enumerate(pred_scores_argsort[0:self.mrr_at_k]):
if is_relevant[index]:
mrr_score = 1 / (rank+1)
all_mrr_scores.append(mrr_score)
mean_mrr = np.mean(all_mrr_scores)
logging.info("Queries: {} \t Positives: Min {:.1f}, Mean {:.1f}, Max {:.1f} \t Negatives: Min {:.1f}, Mean {:.1f}, Max {:.1f}".format(num_queries, np.min(num_positives), np.mean(num_positives), np.max(num_positives), np.min(num_negatives), np.mean(num_negatives), np.max(num_negatives)))
logging.info("MRR@{}: {:.2f}".format(self.mrr_at_k, mean_mrr*100))
if output_path is not None:
csv_path = os.path.join(output_path, self.csv_file)
output_file_exists = os.path.isfile(csv_path)
with open(csv_path, mode="a" if output_file_exists else 'w', encoding="utf-8") as f:
writer = csv.writer(f)
if not output_file_exists:
writer.writerow(self.csv_headers)
writer.writerow([epoch, steps, mean_mrr])
return mean_mrr
evaluator = CERerankingEvaluator(dev_samples)
# Configure the training
warmup_steps = 5000
logging.info("Warmup-steps: {}".format(warmup_steps))
# Train the model
model.fit(train_dataloader=train_dataloader,
evaluator=evaluator,
epochs=num_epochs,
evaluation_steps=5000,
warmup_steps=warmup_steps,
output_path=model_save_path,
use_amp=True)
#Save latest model
model.save(model_save_path+'-latest')
# Script was called via:
#python train_cross-encoder.py nreimers/TinyBERT_L-6_H-768_v2

30522
vocab.txt Normal file

File diff suppressed because it is too large Load Diff