First commit
This commit is contained in:
46
pkgs/xformers/benchmarks/LRA/code/dataset.py
Normal file
46
pkgs/xformers/benchmarks/LRA/code/dataset.py
Normal file
@@ -0,0 +1,46 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
# CREDITS: Almost as-is from the Nystromformer repo
|
||||
# https://github.com/mlpen/Nystromformer
|
||||
|
||||
import logging
|
||||
import pickle
|
||||
|
||||
import torch
|
||||
from torch.utils.data.dataset import Dataset
|
||||
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
|
||||
|
||||
class LRADataset(Dataset):
|
||||
def __init__(self, file_path, seq_len):
|
||||
with open(file_path, "rb") as f:
|
||||
self.examples = pickle.load(f)
|
||||
|
||||
self.seq_len = seq_len
|
||||
logging.info(f"Loaded {file_path}... size={len(self.examples)}")
|
||||
|
||||
def __len__(self):
|
||||
return len(self.examples)
|
||||
|
||||
def __getitem__(self, i):
|
||||
return self.create_inst(self.examples[i], self.seq_len)
|
||||
|
||||
@staticmethod
|
||||
def create_inst(inst, seq_len):
|
||||
output = {
|
||||
"input_ids_0": torch.tensor(inst["input_ids_0"], dtype=torch.long)[:seq_len]
|
||||
}
|
||||
output["mask_0"] = (output["input_ids_0"] != 0).float()
|
||||
|
||||
if "input_ids_1" in inst:
|
||||
output["input_ids_1"] = torch.tensor(inst["input_ids_1"], dtype=torch.long)[
|
||||
:seq_len
|
||||
]
|
||||
output["mask_1"] = (output["input_ids_1"] != 0).float()
|
||||
output["label"] = torch.tensor(inst["label"], dtype=torch.long)
|
||||
return output
|
||||
Reference in New Issue
Block a user