Files
enginex-bi_series-vllm/pkgs/xformers/benchmarks/LRA/code/dataset.py
2025-08-05 19:02:46 +08:00

47 lines
1.4 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# CREDITS: Almost as-is from the Nystromformer repo
# https://github.com/mlpen/Nystromformer
import logging
import pickle
import torch
from torch.utils.data.dataset import Dataset
logging.getLogger().setLevel(logging.INFO)
class LRADataset(Dataset):
def __init__(self, file_path, seq_len):
with open(file_path, "rb") as f:
self.examples = pickle.load(f)
self.seq_len = seq_len
logging.info(f"Loaded {file_path}... size={len(self.examples)}")
def __len__(self):
return len(self.examples)
def __getitem__(self, i):
return self.create_inst(self.examples[i], self.seq_len)
@staticmethod
def create_inst(inst, seq_len):
output = {
"input_ids_0": torch.tensor(inst["input_ids_0"], dtype=torch.long)[:seq_len]
}
output["mask_0"] = (output["input_ids_0"] != 0).float()
if "input_ids_1" in inst:
output["input_ids_1"] = torch.tensor(inst["input_ids_1"], dtype=torch.long)[
:seq_len
]
output["mask_1"] = (output["input_ids_1"] != 0).float()
output["label"] = torch.tensor(inst["label"], dtype=torch.long)
return output