47 lines
1.4 KiB
Python
47 lines
1.4 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
# CREDITS: Almost as-is from the Nystromformer repo
|
|
# https://github.com/mlpen/Nystromformer
|
|
|
|
import logging
|
|
import pickle
|
|
|
|
import torch
|
|
from torch.utils.data.dataset import Dataset
|
|
|
|
# Import-time side effect: set the root logger's level to INFO so that the
# logging.info() call made when a dataset is loaded is not filtered out.
logging.getLogger().setLevel(logging.INFO)
|
|
|
|
|
|
class LRADataset(Dataset):
    """Map-style dataset over a pickled collection of examples.

    The pickle file is loaded in full at construction time. Each example is
    indexed by key in :meth:`create_inst`: it must provide ``"input_ids_0"``
    and ``"label"``, and may provide ``"input_ids_1"``.
    """

    def __init__(self, file_path, seq_len):
        """Load every example from ``file_path``.

        Args:
            file_path: Path of the pickle file holding the examples.
            seq_len: Maximum number of ids kept per input sequence.
        """
        # NOTE(security): pickle.load can execute arbitrary code while
        # deserializing. Only point this at files from a trusted source.
        with open(file_path, "rb") as f:
            self.examples = pickle.load(f)

        self.seq_len = seq_len
        logging.info(f"Loaded {file_path}... size={len(self.examples)}")

    def __len__(self):
        """Return the number of loaded examples."""
        return len(self.examples)

    def __getitem__(self, i):
        """Return example ``i`` as a dict of tensors (see ``create_inst``)."""
        return self.create_inst(self.examples[i], self.seq_len)

    @staticmethod
    def _ids_and_mask(ids, seq_len):
        """Return ``(input_ids, mask)`` for one sequence cut to ``seq_len``."""
        # Truncate BEFORE converting: slicing the tensor afterwards would
        # return a view that keeps the full-length storage alive, and would
        # spend conversion time on ids that are thrown away.
        input_ids = torch.tensor(ids[:seq_len], dtype=torch.long)
        # 1.0 where the id is non-zero, 0.0 where it is zero
        # (presumably 0 is the padding id -- confirm against the data prep).
        return input_ids, (input_ids != 0).float()

    @staticmethod
    def create_inst(inst, seq_len):
        """Convert one raw example into a dict of tensors.

        Args:
            inst: Mapping with ``"input_ids_0"`` and ``"label"``, and
                optionally ``"input_ids_1"``.
            seq_len: Maximum number of ids kept per input sequence.

        Returns:
            A dict with ``"input_ids_0"`` (long), ``"mask_0"`` (float) and
            ``"label"`` (long); ``"input_ids_1"`` and ``"mask_1"`` are added
            only when ``inst`` contains ``"input_ids_1"``.
        """
        output = {}
        output["input_ids_0"], output["mask_0"] = LRADataset._ids_and_mask(
            inst["input_ids_0"], seq_len
        )

        if "input_ids_1" in inst:
            output["input_ids_1"], output["mask_1"] = LRADataset._ids_and_mask(
                inst["input_ids_1"], seq_len
            )

        output["label"] = torch.tensor(inst["label"], dtype=torch.long)
        return output
|