init
87 transformers/examples/legacy/seq2seq/pack_dataset.py (Executable file)
@@ -0,0 +1,87 @@
#!/usr/bin/env python
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fill examples with bitext up to max_tokens without breaking up examples.

[['I went', 'yo fui'],
 ['to the store', 'a la tienda']]
=> ['I went to the store', 'yo fui a la tienda']
"""

import argparse
import shutil
from pathlib import Path

from tqdm import tqdm

from transformers import AutoTokenizer


def pack_examples(tok, src_examples, tgt_examples, max_tokens=1024):
    finished_src, finished_tgt = [], []

    sorted_examples = list(zip(src_examples, tgt_examples))
    new_src, new_tgt = sorted_examples[0]

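    # An example is "too big" when the tokenizer turns it into more than max_tokens input ids.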
    def is_too_big(strang):
        return tok(strang, return_tensors="pt").input_ids.shape[1] > max_tokens

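    # Greedy packing: keep appending the next (src, tgt) pair to the running example;
    # once either side would exceed max_tokens, finalize it and start a new one.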
    for src, tgt in tqdm(sorted_examples[1:]):
        cand_src = new_src + " " + src
        cand_tgt = new_tgt + " " + tgt
        if is_too_big(cand_src) or is_too_big(cand_tgt):  # can't fit, finalize example
            finished_src.append(new_src)
            finished_tgt.append(new_tgt)
            new_src, new_tgt = src, tgt
        else:  # can fit, keep adding
            new_src, new_tgt = cand_src, cand_tgt

    # cleanup
    if new_src:
        assert new_tgt
        finished_src.append(new_src)
        finished_tgt.append(new_tgt)
    return finished_src, finished_tgt


def pack_data_dir(tok, data_dir: Path, max_tokens, save_path):
    save_path = Path(save_path)
    save_path.mkdir(exist_ok=True)
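    # Only the train split is packed; val and test are copied through unchanged below.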
for split in ["train"]:
|
||||
src_path, tgt_path = data_dir / f"{split}.source", data_dir / f"{split}.target"
|
||||
src_docs = [x.rstrip() for x in Path(src_path).open().readlines()]
|
||||
tgt_docs = [x.rstrip() for x in Path(tgt_path).open().readlines()]
|
||||
packed_src, packed_tgt = pack_examples(tok, src_docs, tgt_docs, max_tokens)
|
||||
print(f"packed {split} split from {len(src_docs)} examples -> {len(packed_src)}.")
|
||||
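        # write one packed example per line, mirroring the input format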
Path(save_path / f"{split}.source").open("w").write("\n".join(packed_src))
|
||||
Path(save_path / f"{split}.target").open("w").write("\n".join(packed_tgt))
|
||||
for split in ["val", "test"]:
|
||||
src_path, tgt_path = data_dir / f"{split}.source", data_dir / f"{split}.target"
|
||||
shutil.copyfile(src_path, save_path / f"{split}.source")
|
||||
shutil.copyfile(tgt_path, save_path / f"{split}.target")
|
||||
|
||||
|
||||
def packer_cli():
    parser = argparse.ArgumentParser()
    parser.add_argument("--tok_name", type=str, help="like facebook/bart-large-cnn, google-t5/t5-base, etc.")
    parser.add_argument("--max_seq_len", type=int, default=128)
    parser.add_argument("--data_dir", type=str)
    parser.add_argument("--save_path", type=str)
    args = parser.parse_args()
    tokenizer = AutoTokenizer.from_pretrained(args.tok_name)
    return pack_data_dir(tokenizer, Path(args.data_dir), args.max_seq_len, args.save_path)


if __name__ == "__main__":
    packer_cli()