init

transformers/examples/legacy/token-classification/tasks.py (new file, 162 lines)

@@ -0,0 +1,162 @@
import logging
import os
from typing import TextIO, Union

from conllu import parse_incr
from utils_ner import InputExample, Split, TokenClassificationTask


logger = logging.getLogger(__name__)
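
# Added note (not in the original commit): this module implements three
# TokenClassificationTask subclasses used by the legacy example scripts:
# NER (CoNLL-style, label in the last column), Chunk (CoNLL-2003 chunking,
# label in the second-to-last column), and POS (CoNLL-U files parsed with
# the `conllu` library).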


class NER(TokenClassificationTask):
    def __init__(self, label_idx=-1):
        # in NER datasets, the last column is usually reserved for the NER label
        self.label_idx = label_idx

    def read_examples_from_file(self, data_dir, mode: Union[Split, str]) -> list[InputExample]:
        if isinstance(mode, Split):
            mode = mode.value
        file_path = os.path.join(data_dir, f"{mode}.txt")
        guid_index = 1
        examples = []
        with open(file_path, encoding="utf-8") as f:
            words = []
            labels = []
            for line in f:
                if line.startswith("-DOCSTART-") or line == "" or line == "\n":
                    if words:
                        examples.append(InputExample(guid=f"{mode}-{guid_index}", words=words, labels=labels))
                        guid_index += 1
                        words = []
                        labels = []
                else:
                    splits = line.split(" ")
                    words.append(splits[0])
                    if len(splits) > 1:
                        labels.append(splits[self.label_idx].replace("\n", ""))
                    else:
                        # Examples could have no label for mode = "test"
                        labels.append("O")
            if words:
                examples.append(InputExample(guid=f"{mode}-{guid_index}", words=words, labels=labels))
        return examples
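
    # Added note (not in the original commit): read_examples_from_file above
    # expects CoNLL-style input, one token per line with the label taken from
    # column self.label_idx, and blank lines separating sentences, e.g.:
    #
    #   EU NNP B-NP B-ORG
    #   rejects VBZ B-VP O
    #   German JJ B-NP B-MISC
    #   call NN I-NP O
    #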
    def write_predictions_to_file(self, writer: TextIO, test_input_reader: TextIO, preds_list: list):
        example_id = 0
        for line in test_input_reader:
            if line.startswith("-DOCSTART-") or line == "" or line == "\n":
                # Copy document boundaries and blank lines through unchanged;
                # advance to the next example once its predictions are used up.
                writer.write(line)
                if not preds_list[example_id]:
                    example_id += 1
            elif preds_list[example_id]:
                # Emit "token predicted_label" for the next token of the current example.
                output_line = line.split()[0] + " " + preds_list[example_id].pop(0) + "\n"
                writer.write(output_line)
            else:
                # The model truncated this sentence, so some tokens have no prediction.
                logger.warning("Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0])

    def get_labels(self, path: str) -> list[str]:
        if path:
            with open(path) as f:
                labels = f.read().splitlines()
            if "O" not in labels:
                labels = ["O"] + labels
            return labels
        else:
            return ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"]


class Chunk(NER):
    def __init__(self):
        # in the CoNLL-2003 dataset, the chunk column is second-to-last
        super().__init__(label_idx=-2)
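
    # Added note (not in the original commit): Chunk inherits
    # read_examples_from_file and write_predictions_to_file from NER; only the
    # label column index (-2) and the default label set below differ.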
    def get_labels(self, path: str) -> list[str]:
        if path:
            with open(path) as f:
                labels = f.read().splitlines()
            if "O" not in labels:
                labels = ["O"] + labels
            return labels
        else:
            return [
                "O",
                "B-ADVP",
                "B-INTJ",
                "B-LST",
                "B-PRT",
                "B-NP",
                "B-SBAR",
                "B-VP",
                "B-ADJP",
                "B-CONJP",
                "B-PP",
                "I-ADVP",
                "I-INTJ",
                "I-LST",
                "I-PRT",
                "I-NP",
                "I-SBAR",
                "I-VP",
                "I-ADJP",
                "I-CONJP",
                "I-PP",
            ]


class POS(TokenClassificationTask):
    def read_examples_from_file(self, data_dir, mode: Union[Split, str]) -> list[InputExample]:
        if isinstance(mode, Split):
            mode = mode.value
        file_path = os.path.join(data_dir, f"{mode}.txt")
        guid_index = 1
        examples = []

        with open(file_path, encoding="utf-8") as f:
            for sentence in parse_incr(f):
                words = []
                labels = []
                for token in sentence:
                    words.append(token["form"])
                    labels.append(token["upos"])
                assert len(words) == len(labels)
                if words:
                    examples.append(InputExample(guid=f"{mode}-{guid_index}", words=words, labels=labels))
                    guid_index += 1
        return examples
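
    # Added note (not in the original commit): parse_incr streams sentences
    # from CoNLL-U input, where each token line carries ten tab-separated
    # fields; token["form"] is the surface word and token["upos"] the
    # universal POS tag, e.g. (columns shown space-separated here):
    #
    #   1  The  the  DET   DT  Definite=Def|PronType=Art  2  det   _  _
    #   2  cat  cat  NOUN  NN  Number=Sing                0  root  _  _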
    def write_predictions_to_file(self, writer: TextIO, test_input_reader: TextIO, preds_list: list):
        example_id = 0
        for sentence in parse_incr(test_input_reader):
            s_p = preds_list[example_id]
            out = ""
            for token in sentence:
                # Write each token as "form (gold_upos|predicted_label)".
                out += f"{token['form']} ({token['upos']}|{s_p.pop(0)}) "
            out += "\n"
            writer.write(out)
            example_id += 1

    def get_labels(self, path: str) -> list[str]:
        if path:
            with open(path) as f:
                return f.read().splitlines()
        else:
            return [
                "ADJ",
                "ADP",
                "ADV",
                "AUX",
                "CCONJ",
                "DET",
                "INTJ",
                "NOUN",
                "NUM",
                "PART",
                "PRON",
                "PROPN",
                "PUNCT",
                "SCONJ",
                "SYM",
                "VERB",
                "X",
            ]
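
A hypothetical usage sketch (not part of the commit), assuming a CoNLL-formatted data/train.txt exists and that Split.train from utils_ner holds the value "train":

    task = NER()
    labels = task.get_labels(path="")  # empty path falls back to the CoNLL-2003 defaults
    examples = task.read_examples_from_file("data", Split.train)
    print(f"loaded {len(examples)} examples with {len(labels)} labels")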