init
This commit is contained in:
109
transformers/examples/legacy/seq2seq/old_test_calculate_rouge.py
Normal file
109
transformers/examples/legacy/seq2seq/old_test_calculate_rouge.py
Normal file
@@ -0,0 +1,109 @@
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
from rouge_cli import calculate_rouge_path
|
||||
|
||||
from utils import calculate_rouge
|
||||
|
||||
|
||||
# Model-style summaries used as the "prediction" side of the ROUGE tests below.
# Each list element is one summary, written as adjacent string literals that
# Python concatenates at compile time; PRED[i] is scored against TGT[i].
PRED = [
    # Summary 0.
    'Prosecutor: "No videos were used in the crash investigation" German papers say they saw a cell phone video of the'
    ' final seconds on board Flight 9525. The Germanwings co-pilot says he had a "previous episode of severe'
    " depression\" German airline confirms it knew of Andreas Lubitz's depression years before he took control.",
    # Summary 1.
    "The Palestinian Authority officially becomes the 123rd member of the International Criminal Court. The formal"
    " accession was marked with a ceremony at The Hague, in the Netherlands. The Palestinians signed the ICC's"
    " founding Rome Statute in January. Israel and the United States opposed the Palestinians' efforts to join the"
    " body.",
    # Summary 2.
    "Amnesty International releases its annual report on the death penalty. The report catalogs the use of"
    " state-sanctioned killing as a punitive measure across the globe. At least 607 people were executed around the"
    " world in 2014, compared to 778 in 2013. The U.S. remains one of the worst offenders for imposing capital"
    " punishment.",
]
|
||||
|
||||
# Reference summaries used as the "target" side of the ROUGE tests below.
# Each list element is one reference, written as adjacent string literals that
# Python concatenates at compile time; TGT[i] is the reference for PRED[i].
# Sentences are delimited by " ." (space before the period) in this data.
TGT = [
    # Reference 0.
    'Marseille prosecutor says "so far no videos were used in the crash investigation" despite media reports .'
    ' Journalists at Bild and Paris Match are "very confident" the video clip is real, an editor says . Andreas Lubitz'
    " had informed his Lufthansa training school of an episode of severe depression, airline says .",
    # Reference 1.
    "Membership gives the ICC jurisdiction over alleged crimes committed in Palestinian territories since last June ."
    " Israel and the United States opposed the move, which could open the door to war crimes investigations against"
    " Israelis .",
    # Reference 2.
    "Amnesty's annual death penalty report catalogs encouraging signs, but setbacks in numbers of those sentenced to"
    " death . Organization claims that governments around the world are using the threat of terrorism to advance"
    " executions . The number of executions worldwide has gone down by almost 22% compared with 2013, but death"
    " sentences up by 28% .",
]
|
||||
|
||||
|
||||
def test_disaggregated_scores_are_determinstic():
    """Check that un-aggregated rouge2 scores do not depend on the other keys requested.

    ``calculate_rouge`` is called twice with ``bootstrap_aggregation=False``:
    once asking for both ``rouge2`` and ``rougeL``, once for ``rouge2`` only.
    The mean ``fmeasure`` of the ``rouge2`` entries must be exactly equal in
    both results.
    """
    no_aggregation = calculate_rouge(PRED, TGT, bootstrap_aggregation=False, rouge_keys=["rouge2", "rougeL"])
    # Without bootstrap aggregation the result is expected to be a defaultdict.
    assert isinstance(no_aggregation, defaultdict)
    no_aggregation_just_r2 = calculate_rouge(PRED, TGT, bootstrap_aggregation=False, rouge_keys=["rouge2"])
    # Exact float equality is intentional: the test is about determinism.
    assert (
        pd.DataFrame(no_aggregation["rouge2"]).fmeasure.mean()
        == pd.DataFrame(no_aggregation_just_r2["rouge2"]).fmeasure.mean()
    )
|
||||
|
||||
|
||||
def test_newline_cnn_improvement():
    """Check that ``newline_sep=True`` yields a higher rougeLsum than ``newline_sep=False``.

    Both calls score PRED against TGT and request only the ``rougeLsum`` key;
    the score with newline separation must be strictly greater.
    """
    k = "rougeLsum"
    score = calculate_rouge(PRED, TGT, newline_sep=True, rouge_keys=[k])[k]
    score_no_sep = calculate_rouge(PRED, TGT, newline_sep=False, rouge_keys=[k])[k]
    assert score > score_no_sep
|
||||
|
||||
|
||||
def test_newline_irrelevant_for_other_metrics():
    """Check that ``newline_sep`` does not change rouge1, rouge2 or rougeL.

    The full results returned for ``newline_sep=True`` and
    ``newline_sep=False`` must compare equal when only those three keys are
    requested.
    """
    k = ["rouge1", "rouge2", "rougeL"]
    score_sep = calculate_rouge(PRED, TGT, newline_sep=True, rouge_keys=k)
    score_no_sep = calculate_rouge(PRED, TGT, newline_sep=False, rouge_keys=k)
    assert score_sep == score_no_sep
|
||||
|
||||
|
||||
def test_single_sent_scores_dont_depend_on_newline_sep():
    """Check that ``newline_sep`` has no effect on the local single-sentence inputs.

    Uses its own two-element ``pred``/``tgt`` lists (not the module-level
    PRED/TGT) and the default rouge keys; the results for
    ``newline_sep=True`` and ``newline_sep=False`` must compare equal.
    """
    pred = [
        "Her older sister, Margot Frank, died in 1945, a month earlier than previously thought.",
        'Marseille prosecutor says "so far no videos were used in the crash investigation" despite media reports .',
    ]
    tgt = [
        "Margot Frank, died in 1945, a month earlier than previously thought.",
        # Two adjacent literals form the single second reference.
        'Prosecutor: "No videos were used in the crash investigation" German papers say they saw a cell phone video of'
        " the final seconds on board Flight 9525.",
    ]
    assert calculate_rouge(pred, tgt, newline_sep=True) == calculate_rouge(pred, tgt, newline_sep=False)
|
||||
|
||||
|
||||
def test_pegasus_newline():
    """Check that the default ``calculate_rouge`` call beats ``newline_sep=False`` on rougeLsum.

    The single prediction contains a literal ``<n>`` marker between its two
    quoted sentences. NOTE(review): presumably ``calculate_rouge`` treats
    ``<n>`` as a sentence separator when newline separation is on (its
    default here) -- confirm against ``utils.calculate_rouge``.
    """
    # The literals keep their exact leading/trailing characters: the first
    # starts with `" ` right after the opening triple quote and ends with a
    # space before the closing one.
    pred = [
        """" "a person who has such a video needs to immediately give it to the investigators," prosecutor says .<n> "it is a very disturbing scene," editor-in-chief of bild online tells "erin burnett: outfront" """
    ]
    tgt = [
        """ Marseille prosecutor says "so far no videos were used in the crash investigation" despite media reports . Journalists at Bild and Paris Match are "very confident" the video clip is real, an editor says . Andreas Lubitz had informed his Lufthansa training school of an episode of severe depression, airline says ."""
    ]

    prev_score = calculate_rouge(pred, tgt, rouge_keys=["rougeLsum"], newline_sep=False)["rougeLsum"]
    new_score = calculate_rouge(pred, tgt, rouge_keys=["rougeLsum"])["rougeLsum"]
    assert new_score > prev_score
|
||||
|
||||
|
||||
def test_rouge_cli():
    """Check the return types of ``calculate_rouge_path`` with and without aggregation.

    The default call must return a ``dict``; with
    ``bootstrap_aggregation=False`` it must return a ``defaultdict``. Both
    calls pass ``test.source`` and ``test.target`` from the same data
    directory.
    """
    # NOTE(review): this path is relative to the current working directory and
    # names examples/seq2seq, while this file appears to live under
    # examples/legacy/seq2seq -- confirm where the fixture directory is and
    # from which directory the test is expected to be run.
    data_dir = Path("examples/seq2seq/test_data/wmt_en_ro")
    metrics = calculate_rouge_path(data_dir.joinpath("test.source"), data_dir.joinpath("test.target"))
    assert isinstance(metrics, dict)
    metrics_default_dict = calculate_rouge_path(
        data_dir.joinpath("test.source"), data_dir.joinpath("test.target"), bootstrap_aggregation=False
    )
    assert isinstance(metrics_default_dict, defaultdict)
|
||||
Reference in New Issue
Block a user