437 lines
16 KiB
Python
437 lines
16 KiB
Python
import dataclasses
|
|
import pprint
|
|
import time
|
|
from functools import partial
|
|
import json
|
|
import base64
|
|
from multiprocessing import Pool
|
|
|
|
import h5py
|
|
import mlxu
|
|
from ml_collections.config_dict import config_dict
|
|
from ml_collections import ConfigDict
|
|
from tqdm import tqdm, trange
|
|
import numpy as np
|
|
|
|
from datasets import load_dataset, load_from_disk
|
|
|
|
|
|
class DatasetFactory(object):
    """ Static dataset builder.

    ``config.type`` selects which dataset class is constructed ('huggingface'
    or 'json'); the matching sub-config is handed to that class together with
    a ``TextProcessor`` built from ``config.text_processor``.
    """

    @staticmethod
    def get_default_config(updates=None):
        """Return the default factory config, merged with ``updates`` if given.

        References inside ``updates`` are resolved before merging.
        """
        config = ConfigDict()
        config.type = 'huggingface'
        config.text_processor = TextProcessor.get_default_config()
        config.huggingface_dataset = HuggingfaceDataset.get_default_config()
        config.json_dataset = JsonDataset.get_default_config()

        if updates is None:
            return config
        config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    @classmethod
    def load_dataset(cls, config, tokenizer, **kwargs):
        """Build the dataset selected by ``config.type``.

        Args:
            config: partial or full factory config, merged over the defaults.
            tokenizer: tokenizer shared by the text processor and the dataset.
            **kwargs: forwarded unchanged to the dataset constructor.

        Raises:
            ValueError: if ``config.type`` is not 'huggingface' or 'json'.
        """
        full_config = cls.get_default_config(config)
        processor = TextProcessor(full_config.text_processor, tokenizer)
        dataset_type = full_config.type

        if dataset_type == 'huggingface':
            return HuggingfaceDataset(
                full_config.huggingface_dataset, tokenizer, processor, **kwargs
            )
        if dataset_type == 'json':
            return JsonDataset(
                full_config.json_dataset, tokenizer, processor, **kwargs
            )
        raise ValueError(f'Unknown dataset type: {dataset_type}')

    def __init__(self):
        # Every entry point is a static/class method; instances are never needed.
        raise ValueError('DatasetFactory is a static class and should not be instantiated.')


class TextProcessor(object):
    """ Example processor that converts a dictionary of texts into tokens. """

    @staticmethod
    def get_default_config(updates=None):
        """Return the default processor config, merged with ``updates`` if given.

        Keys:
            fields_from_example: if non-empty, the key in each example whose
                value is the comma-separated field spec for that example.
            fields: comma-separated field spec used when
                ``fields_from_example`` is empty.
            subfield_separator: string placed between '+'-joined text subfields.
            add_bos_token / add_eos_token: wrap the output in BOS / EOS.
            prepend_text: prefixed to the first field when it is a text field.
            base64_token_dtype: numpy dtype used to decode '{...}' fields.
        """
        config = ConfigDict()
        config.fields_from_example = ''
        config.fields = ''
        config.subfield_separator = ' '
        config.add_bos_token = True
        config.add_eos_token = True
        config.prepend_text = ''
        config.base64_token_dtype = 'i4'

        if updates is None:
            return config
        config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, config, tokenizer):
        self.config = self.get_default_config(config)
        has_field_spec = (
            self.config.fields != '' or self.config.fields_from_example != ''
        )
        assert has_field_spec, (
            'Either fields or fields_from_example must be specified.'
        )
        self.tokenizer = tokenizer

    def _special_token_id(self, name):
        """Map 'bos' / 'eos' to the tokenizer's ids; any other name is parsed as an int id."""
        if name == 'bos':
            return self.tokenizer.bos_token_id
        if name == 'eos':
            return self.tokenizer.eos_token_id
        return int(name)

    def __call__(self, example, has_aux=False):
        """Tokenize one example and return ``(tokens, loss_masks, *aux)``.

        Field spec entries (comma-separated):
            ``[spec]``   -- any entry below, but with loss mask 0.0.
            ``<|bos|>`` / ``<|eos|>`` -- the tokenizer's BOS / EOS id.
            ``<|123|>``  -- the literal token id 123.
            ``{key}``    -- base64-encoded raw token ids in ``example[key]``.
            ``a+b``      -- text of ``example['a']`` and ``example['b']`` joined
                            by ``subfield_separator`` and tokenized.

        When ``has_aux`` is true, ``example`` is a sequence whose first item is
        the example dict; the remaining items are returned after the buffers.
        """
        if has_aux:
            example, *aux = example
        else:
            aux = tuple()

        cfg = self.config
        tokens_out = []
        masks_out = []

        if cfg.add_bos_token:
            tokens_out.append(self.tokenizer.bos_token_id)
            masks_out.append(0.0)

        spec_text = (
            example[cfg.fields_from_example]
            if cfg.fields_from_example != ''
            else cfg.fields
        )

        for position, spec in enumerate(spec_text.split(',')):
            no_loss = spec.startswith('[') and spec.endswith(']')
            if no_loss:
                spec = spec[1:-1]
            weight = 0.0 if no_loss else 1.0

            if spec.startswith('<|') and spec.endswith('|>'):
                tokens_out.append(self._special_token_id(spec[2:-2]))
                masks_out.append(weight)
                continue

            if spec.startswith('{') and spec.endswith('}'):
                raw_bytes = base64.b64decode(example[spec[1:-1]])
                new_tokens = np.frombuffer(
                    raw_bytes, dtype=cfg.base64_token_dtype
                ).tolist()
            else:
                text = cfg.subfield_separator.join(
                    [example[key] for key in spec.split('+')]
                )
                if position == 0:
                    text = cfg.prepend_text + text
                new_tokens = self.tokenizer.encode(text)

            tokens_out.extend(new_tokens)
            masks_out.extend([weight] * len(new_tokens))

        if cfg.add_eos_token:
            tokens_out.append(self.tokenizer.eos_token_id)
            # Unlike BOS, the appended EOS carries loss mask 1.0.
            masks_out.append(1.0)

        return (tokens_out, masks_out, *aux)


class HuggingfaceDataset(object):
    """ Huggingface dataset loaded from disk with ``datasets.load_from_disk()``
    and iterated as a sharded ``IterableDataset``.

    Examples are tokenized by the supplied text processor, concatenated into a
    single token stream, and cut into ``(batch_size, seq_length)`` batches.
    In training mode the stream repeats forever: the first epoch is read in
    stored order and every later epoch goes through a seeded buffer shuffle.
    With ``eval_dataset=True`` the data is read exactly once.
    """

    @staticmethod
    def get_default_config(updates=None):
        """Return the default config, merged with ``updates`` if given.

        ``name`` and ``streaming`` are kept for config compatibility but are not
        read by this class. An empty ``split`` uses whatever ``load_from_disk``
        returns directly, which only works when ``path`` holds a single
        (unsplit) dataset. ``shuffle_seed`` (42), ``shuffle_buffer_size``
        (10000) and ``max_num_shards`` (128) control the iterable dataset.
        """
        config = ConfigDict()
        config.path = 'c4'
        config.name = 'en'
        config.split = 'train'
        config.streaming = False
        config.seq_length = 1024
        config.batch_size = 8
        config.always_start_with_bos = False
        config.start_seek_loc = 0
        config.tokens_count_at_start = 0
        config.batch_token_dtype = 'i4'
        config.reset_dataset_loc = False
        config.shuffle_seed = 42
        config.shuffle_buffer_size = 10000
        config.max_num_shards = 128

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, config, tokenizer, text_processor, eval_dataset=False):
        self.config = self.get_default_config(config)
        self._tokenizer = tokenizer
        self._text_processor = text_processor

        dataset = load_from_disk(self.config.path)
        if self.config.split != '':
            # A saved DatasetDict is indexed by split name.
            dataset = dataset[self.config.split]
        num_shards = min(self.config.max_num_shards, len(dataset))
        self._dataset = dataset.to_iterable_dataset(num_shards=num_shards)

        self._eval_dataset = eval_dataset
        # Whether the buffer shuffle has already been applied to self._dataset.
        self._shuffled = False
        self._train_epochs = 0
        self._dataset_loc = self.config.start_seek_loc
        self._total_tokens = self.config.tokens_count_at_start
        self._index = 0
        self.reset_dataset_loc = self.config.reset_dataset_loc

    def _shuffle_once(self):
        """Wrap the dataset in the seeded buffer shuffle, at most once per instance.

        Later epochs reuse this shuffle and only advance it via ``set_epoch``;
        re-wrapping would stack shuffles and change the order on repeated
        ``__iter__`` calls.
        """
        if not self._shuffled:
            self._dataset = self._dataset.shuffle(
                seed=self.config.shuffle_seed,
                buffer_size=self.config.shuffle_buffer_size,
            )
            self._shuffled = True

    def _make_batch(self, token_buffer, loss_mask_buffer, chunk_size):
        """Build one batch from the front of the buffers.

        Targets and loss masks are the input slice shifted by one token; every
        array is reshaped to ``(batch_size, -1)``.
        """
        rows = self.config.batch_size
        token_dtype = self.config.batch_token_dtype
        batch = {
            'input_tokens': np.array(
                token_buffer[:chunk_size], dtype=token_dtype
            ).reshape(rows, -1),
            'target_tokens': np.array(
                token_buffer[1:chunk_size + 1], dtype=token_dtype
            ).reshape(rows, -1),
            'loss_masks': np.array(
                loss_mask_buffer[1:chunk_size + 1], dtype=np.float32
            ).reshape(rows, -1),
        }
        if self.config.always_start_with_bos:
            batch['input_tokens'][:, 0] = self.tokenizer.bos_token_id
        return batch

    def __iter__(self):
        """Yield ``(batch, metrics)`` pairs.

        Training mode loops over epochs forever; eval mode stops after one pass.
        Tokens left over at the end of an epoch are dropped.
        """
        training = not self._eval_dataset
        if training and self._train_epochs > 0:
            # Resuming past epoch 0: every epoch after the first is shuffled.
            self._shuffle_once()
        chunk_size = self.config.batch_size * self.config.seq_length
        while True:
            token_buffer = []
            loss_mask_buffer = []
            if training and self._train_epochs > 0:
                self._dataset.set_epoch(self._train_epochs)
            for index, example in enumerate(self._dataset):
                self._index = index
                # On resume, skip examples before the saved location.
                if training and self._dataset_loc > index:
                    continue
                tokens, loss_masks = self.text_processor(example)
                token_buffer.extend(tokens)
                loss_mask_buffer.extend(loss_masks)
                while len(token_buffer) > chunk_size + 1:
                    self._total_tokens += chunk_size
                    metrics = {
                        'dataset_example_index': index,
                        'dataset_total_tokens': self._total_tokens,
                        'epoch': self._train_epochs,
                    }
                    batch = self._make_batch(
                        token_buffer, loss_mask_buffer, chunk_size
                    )
                    yield batch, metrics
                    token_buffer = token_buffer[chunk_size:]
                    loss_mask_buffer = loss_mask_buffer[chunk_size:]

            if self._eval_dataset:
                break
            self._shuffle_once()
            self._dataset_loc = 0
            self._train_epochs += 1

    def get_state_dict(self):
        """Return resumable state.

        ``dataset_loc`` is the index of the last example visited. On resume that
        example is tokenized again, and tokens still buffered in ``__iter__``
        are not saved.
        """
        return dict(
            config=self.config,
            dataset_loc=self._index,
            total_tokens=self._total_tokens,
            epochs=self._train_epochs,
        )

    def load_state_dict(self, state_dict):
        """Restore state from ``get_state_dict()`` output.

        If ``reset_dataset_loc`` was set in the config at construction time,
        the example location and epoch count restart from zero; the token
        count is still restored.
        """
        if 'config' in state_dict:
            self.config.update(ConfigDict(state_dict['config']))
        self._dataset_loc = state_dict.get('dataset_loc', self.config.start_seek_loc)
        self._total_tokens = state_dict.get('total_tokens', self.config.tokens_count_at_start)
        self._train_epochs = state_dict.get('epochs', 0)
        if self.reset_dataset_loc:
            self._dataset_loc = 0
            self._train_epochs = 0

    @property
    def seq_length(self):
        return self.config.seq_length

    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def text_processor(self):
        return self._text_processor

    @property
    def dataset(self):
        return self._dataset

    @property
    def vocab_size(self):
        return len(self._tokenizer)


class JsonDataset(object):
    """ JSON dataset, where each line of the data file contains a JSON
    dictionary with text fields.

    The file is read forever: at EOF the reader seeks back to the start and
    resets the example index to 0. Tokenized examples are concatenated into a
    single token stream that is cut into ``(batch_size, seq_length)`` batches.
    """

    # Lower bound on measured durations, so a coarse clock cannot cause a
    # division by zero in the throughput metrics.
    _MIN_ELAPSED_SECONDS = 1e-9

    @staticmethod
    def get_default_config(updates=None):
        """Return the default config, merged with ``updates`` if given.

        ``batch_token_dtype`` (default 'i4', i.e. int32) sets the numpy dtype of
        the token arrays, matching HuggingfaceDataset's option of the same name.
        """
        config = ConfigDict()
        config.path = ''
        config.seq_length = 1024
        config.batch_size = 8
        config.always_start_with_bos = False
        config.start_seek_loc = 0
        config.example_index_at_start = 0
        config.tokens_count_at_start = 0
        config.tokenizer_processes = 1
        config.tokenizer_parallel_chunk_size = 32
        config.tokenizer_parallel_batch_size = 1024
        config.throughput_average_window_size = 200
        config.batch_token_dtype = 'i4'

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, config, tokenizer, text_processor, eval_dataset=False):
        self.config = self.get_default_config(config)
        assert self.config.path != ''
        self._tokenizer = tokenizer
        self._text_processor = text_processor
        # Accepted so DatasetFactory.load_dataset can forward the same kwargs
        # (e.g. eval_dataset) to either dataset type; iteration here does not
        # currently depend on it.
        self._eval_dataset = eval_dataset
        # Reader position: example index and byte offset into the file.
        self._index = self.config.example_index_at_start
        self._file_loc = self.config.start_seek_loc
        self._total_tokens = self.config.tokens_count_at_start

    def parse_json(self, line):
        """Parse one line; return None for blank/whitespace-only or malformed lines."""
        if not line or line.isspace():
            return None
        try:
            data = json.loads(line)
        except json.decoder.JSONDecodeError:
            print(f'Error parsing json line:\n{line}')
            return None
        return data

    def json_iterator(self):
        """Yield ``(example, file_loc, index)`` for each parseable line, forever.

        ``file_loc`` is the byte offset just after the line. ``index`` counts
        every line read since the last wrap-around, including lines that
        failed to parse.
        """
        with mlxu.open_file(self.config.path, 'r') as fin:
            fin.seek(self._file_loc)
            while True:
                line = fin.readline()
                self._file_loc = fin.tell()
                if not line:  # Reached EOF: wrap around to the start.
                    self._index = 0
                    fin.seek(0)
                    continue

                data = self.parse_json(line)
                if data is not None:
                    # JSON parsing succeeded
                    yield data, self._file_loc, self._index
                self._index += 1

    def batched(self, iterator, batch_size):
        """Group ``iterator`` into lists of ``batch_size``; the last list may be shorter."""
        batch = []
        for example in iterator:
            batch.append(example)
            if len(batch) == batch_size:
                yield batch
                batch = []
        if len(batch) > 0:
            yield batch

    def parallel_example_iterator(self):
        """Yield ``(tokens, loss_masks, file_loc, index)`` for each parsed line.

        With ``tokenizer_processes == 1`` lines are tokenized inline. Otherwise
        lines are grouped into batches of ``tokenizer_parallel_batch_size`` and
        tokenized in a process pool. The next batch is submitted before the
        current one is consumed, so reading runs ahead of what has been yielded.
        """
        if self.config.tokenizer_processes == 1:
            for example, loc, index in self.json_iterator():
                yield self.text_processor((example, loc, index), has_aux=True)
        else:
            process_pool = Pool(self.config.tokenizer_processes)
            batched_iterator = self.batched(
                self.json_iterator(), self.config.tokenizer_parallel_batch_size
            )
            with process_pool as pool:
                map_fn = partial(self.text_processor, has_aux=True)
                next_batch = pool.map_async(
                    map_fn, next(batched_iterator),
                    chunksize=self.config.tokenizer_parallel_chunk_size
                )
                while True:
                    current_batch = next_batch
                    next_batch = pool.map_async(
                        map_fn, next(batched_iterator),
                        chunksize=self.config.tokenizer_parallel_chunk_size
                    )
                    for example in current_batch.get():
                        yield example

    def _make_batch(self, token_buffer, loss_mask_buffer, chunk_size):
        """Build one batch from the front of the buffers.

        Targets and loss masks are the input slice shifted by one token; every
        array is reshaped to ``(batch_size, -1)``.
        """
        rows = self.config.batch_size
        token_dtype = self.config.batch_token_dtype
        batch = {
            'input_tokens': np.array(
                token_buffer[:chunk_size], dtype=token_dtype
            ).reshape(rows, -1),
            'target_tokens': np.array(
                token_buffer[1:chunk_size + 1], dtype=token_dtype
            ).reshape(rows, -1),
            'loss_masks': np.array(
                loss_mask_buffer[1:chunk_size + 1], dtype=np.float32
            ).reshape(rows, -1),
        }
        if self.config.always_start_with_bos:
            batch['input_tokens'][:, 0] = self.tokenizer.bos_token_id
        return batch

    def __iter__(self):
        """Yield ``(batch, metrics)`` pairs indefinitely.

        ``metrics`` reports the file offset and index of the example that
        completed the batch, the running token count, and tokens/second both
        since iteration started and over the last
        ``throughput_average_window_size`` steps.
        """
        chunk_size = self.config.batch_size * self.config.seq_length
        window = self.config.throughput_average_window_size
        token_buffer = []
        loss_mask_buffer = []
        step_times = []
        # perf_counter is monotonic and high-resolution, suited to intervals.
        start_time = time.perf_counter()
        # Measure the first step from the start of iteration, not from 0.0;
        # otherwise one bogus, enormous interval dominates the windowed average.
        last_time = start_time
        start_tokens = self._total_tokens
        for tokens, loss_masks, loc, index in self.parallel_example_iterator():
            token_buffer.extend(tokens)
            loss_mask_buffer.extend(loss_masks)
            while len(token_buffer) > chunk_size + 1:
                self._total_tokens += chunk_size
                now = time.perf_counter()
                step_times.append(now - last_time)
                last_time = now
                if len(step_times) > window:
                    step_times = step_times[-window:]
                average_throughput = chunk_size / max(
                    float(np.mean(step_times)), self._MIN_ELAPSED_SECONDS
                )
                accumulated_throughput = (
                    (self._total_tokens - start_tokens)
                    / max(now - start_time, self._MIN_ELAPSED_SECONDS)
                )
                metrics = {
                    'dataset_file_loc': loc,
                    'dataset_example_index': index,
                    'dataset_total_tokens': self._total_tokens,
                    'dataset_accumulated_tps': accumulated_throughput,
                    'dataset_average_tps': average_throughput,
                }
                batch = self._make_batch(token_buffer, loss_mask_buffer, chunk_size)
                yield batch, metrics
                token_buffer = token_buffer[chunk_size:]
                loss_mask_buffer = loss_mask_buffer[chunk_size:]

    def get_state_dict(self):
        """Return resumable state: config, reader index, file offset, token count.

        NOTE(review): ``index`` and ``file_loc`` track the *reader*, which can be
        ahead of the batches already yielded (by up to about two
        ``tokenizer_parallel_batch_size`` batches when ``tokenizer_processes > 1``).
        Tokens still buffered in ``__iter__`` are not saved either, so a resume
        may skip some examples.
        """
        return dict(
            config=self.config,
            index=self._index,
            file_loc=self._file_loc,
            total_tokens=self._total_tokens,
        )

    def load_state_dict(self, state_dict):
        """Restore state from ``get_state_dict()`` output, defaulting to the config values."""
        if 'config' in state_dict:
            self.config.update(ConfigDict(state_dict['config']))
        self._index = state_dict.get('index', self.config.example_index_at_start)
        self._file_loc = state_dict.get('file_loc', self.config.start_seek_loc)
        self._total_tokens = state_dict.get('total_tokens', self.config.tokens_count_at_start)

    @property
    def seq_length(self):
        return self.config.seq_length

    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def text_processor(self):
        return self._text_processor

    @property
    def vocab_size(self):
        return len(self.tokenizer)