初始化项目,由ModelHub XC社区提供模型
Model: Finnish-NLP/Ahma-7B Source: Original Platform
This commit is contained in:
436
EasyLM/data.py
Normal file
436
EasyLM/data.py
Normal file
@@ -0,0 +1,436 @@
|
||||
import dataclasses
|
||||
import pprint
|
||||
import time
|
||||
from functools import partial
|
||||
import json
|
||||
import base64
|
||||
from multiprocessing import Pool
|
||||
|
||||
import h5py
|
||||
import mlxu
|
||||
from ml_collections.config_dict import config_dict
|
||||
from ml_collections import ConfigDict
|
||||
from tqdm import tqdm, trange
|
||||
import numpy as np
|
||||
|
||||
from datasets import load_dataset, load_from_disk
|
||||
|
||||
|
||||
class DatasetFactory(object):
|
||||
""" Datset builder class. """
|
||||
|
||||
@staticmethod
|
||||
def get_default_config(updates=None):
|
||||
config = ConfigDict()
|
||||
config.type = 'huggingface'
|
||||
config.text_processor = TextProcessor.get_default_config()
|
||||
config.huggingface_dataset = HuggingfaceDataset.get_default_config()
|
||||
config.json_dataset = JsonDataset.get_default_config()
|
||||
|
||||
if updates is not None:
|
||||
config.update(ConfigDict(updates).copy_and_resolve_references())
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def load_dataset(cls, config, tokenizer, **kwargs):
|
||||
config = cls.get_default_config(config)
|
||||
text_processor = TextProcessor(config.text_processor, tokenizer)
|
||||
if config.type == 'huggingface':
|
||||
return HuggingfaceDataset(
|
||||
config.huggingface_dataset, tokenizer, text_processor, **kwargs
|
||||
)
|
||||
elif config.type == 'json':
|
||||
return JsonDataset(config.json_dataset, tokenizer, text_processor, **kwargs)
|
||||
else:
|
||||
raise ValueError(f'Unknown dataset type: {config.type}')
|
||||
|
||||
def __init__(self):
|
||||
raise ValueError('DatasetFactory is a static class and should not be instantiated.')
|
||||
|
||||
|
||||
class TextProcessor(object):
|
||||
""" Example processor that converts a dictionary of texts into tokens. """
|
||||
|
||||
@staticmethod
|
||||
def get_default_config(updates=None):
|
||||
config = ConfigDict()
|
||||
config.fields_from_example = ''
|
||||
config.fields = ''
|
||||
config.subfield_separator = ' '
|
||||
config.add_bos_token = True
|
||||
config.add_eos_token = True
|
||||
config.prepend_text = ''
|
||||
config.base64_token_dtype = 'i4'
|
||||
if updates is not None:
|
||||
config.update(ConfigDict(updates).copy_and_resolve_references())
|
||||
return config
|
||||
|
||||
def __init__(self, config, tokenizer):
|
||||
self.config = self.get_default_config(config)
|
||||
assert self.config.fields != '' or self.config.fields_from_example != '', (
|
||||
'Either fields or fields_from_example must be specified.'
|
||||
)
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def __call__(self, example, has_aux=False):
|
||||
if has_aux:
|
||||
example, *aux = example
|
||||
else:
|
||||
aux = tuple()
|
||||
token_buffer = []
|
||||
loss_mask_buffer = []
|
||||
|
||||
if self.config.add_bos_token:
|
||||
token_buffer.append(self.tokenizer.bos_token_id)
|
||||
loss_mask_buffer.append(0.0)
|
||||
|
||||
if self.config.fields_from_example != '':
|
||||
fields = example[self.config.fields_from_example].split(',')
|
||||
else:
|
||||
fields = self.config.fields.split(',')
|
||||
|
||||
for i, field in enumerate(fields):
|
||||
if field.startswith('[') and field.endswith(']'):
|
||||
# No loss for this field.
|
||||
field = field[1:-1]
|
||||
mask = 0.0
|
||||
else:
|
||||
mask = 1.0
|
||||
|
||||
if field.startswith('<|') and field.endswith('|>'):
|
||||
# Special tokens.
|
||||
field = field[2:-2]
|
||||
if field == 'bos':
|
||||
token_buffer.append(self.tokenizer.bos_token_id)
|
||||
elif field == 'eos':
|
||||
token_buffer.append(self.tokenizer.eos_token_id)
|
||||
else:
|
||||
# Token ID specified directly.
|
||||
token_buffer.append(int(field))
|
||||
loss_mask_buffer.append(mask)
|
||||
elif field.startswith('{') and field.endswith('}'):
|
||||
field = field[1:-1]
|
||||
# Base64 encoded raw tokens.
|
||||
tokens = np.frombuffer(
|
||||
base64.b64decode(example[field]),
|
||||
dtype=self.config.base64_token_dtype
|
||||
).tolist()
|
||||
token_buffer.extend(tokens)
|
||||
loss_mask_buffer.extend([mask for _ in range(len(tokens))])
|
||||
else:
|
||||
subfields = field.split('+')
|
||||
text = self.config.subfield_separator.join(
|
||||
[example[subfield] for subfield in subfields]
|
||||
)
|
||||
if i == 0:
|
||||
text = self.config.prepend_text + text
|
||||
tokens = self.tokenizer.encode(text)
|
||||
token_buffer.extend(tokens)
|
||||
loss_mask_buffer.extend([mask for _ in range(len(tokens))])
|
||||
|
||||
if self.config.add_eos_token:
|
||||
token_buffer.append(self.tokenizer.eos_token_id)
|
||||
loss_mask_buffer.append(1.0)
|
||||
|
||||
return token_buffer, loss_mask_buffer, *aux
|
||||
|
||||
|
||||
class HuggingfaceDataset(object):
|
||||
""" Huggingface dataset, where the dataset is loaded using the huggingface
|
||||
datasets.load_dataset() function.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def get_default_config(updates=None):
|
||||
config = ConfigDict()
|
||||
config.path = 'c4'
|
||||
config.name = 'en'
|
||||
config.split = 'train'
|
||||
config.streaming = False
|
||||
config.seq_length = 1024
|
||||
config.batch_size = 8
|
||||
config.always_start_with_bos = False
|
||||
config.start_seek_loc = 0
|
||||
config.tokens_count_at_start = 0
|
||||
config.batch_token_dtype = 'i4'
|
||||
config.reset_dataset_loc = False
|
||||
|
||||
if updates is not None:
|
||||
config.update(ConfigDict(updates).copy_and_resolve_references())
|
||||
return config
|
||||
|
||||
def __init__(self, config, tokenizer, text_processor, eval_dataset=False):
|
||||
self.config = self.get_default_config(config)
|
||||
name = self.config.name if self.config.name != '' else None
|
||||
split = self.config.split if self.config.split != '' else None
|
||||
self._tokenizer = tokenizer
|
||||
self._text_processor = text_processor
|
||||
self._dataset = load_from_disk(
|
||||
self.config.path
|
||||
)[split]
|
||||
self._dataset = self._dataset.to_iterable_dataset(num_shards=128 if len(self._dataset) > 128 else len(self._dataset))
|
||||
self._eval_dataset = eval_dataset
|
||||
self._train_epochs = 0
|
||||
self._dataset_loc = self.config.start_seek_loc
|
||||
self._total_tokens = self.config.tokens_count_at_start
|
||||
self._index = 0
|
||||
self.reset_dataset_loc = self.config.reset_dataset_loc
|
||||
|
||||
|
||||
def __iter__(self):
|
||||
if not self._eval_dataset and self._train_epochs > 0:
|
||||
self._dataset = self._dataset.shuffle(seed=42, buffer_size=10000)
|
||||
chunk_size = self.config.batch_size * self.config.seq_length
|
||||
while True:
|
||||
token_buffer = []
|
||||
loss_mask_buffer = []
|
||||
if not self._eval_dataset and self._train_epochs > 0:
|
||||
self._dataset.set_epoch(self._train_epochs)
|
||||
for index, example in enumerate(self._dataset):
|
||||
self._index = index
|
||||
if not self._eval_dataset and self._dataset_loc > index:
|
||||
continue
|
||||
tokens, loss_masks = self.text_processor(example)
|
||||
token_buffer.extend(tokens)
|
||||
loss_mask_buffer.extend(loss_masks)
|
||||
while len(token_buffer) > chunk_size + 1:
|
||||
self._total_tokens += chunk_size
|
||||
metrics = {
|
||||
'dataset_example_index': index,
|
||||
'dataset_total_tokens': self._total_tokens,
|
||||
'epoch': self._train_epochs,
|
||||
}
|
||||
batch = {
|
||||
'input_tokens': np.array(token_buffer[:chunk_size], dtype=self.config.batch_token_dtype).reshape(
|
||||
self.config.batch_size, -1
|
||||
),
|
||||
'target_tokens': np.array(token_buffer[1:chunk_size + 1], dtype=self.config.batch_token_dtype).reshape(
|
||||
self.config.batch_size, -1
|
||||
),
|
||||
'loss_masks': np.array(loss_mask_buffer[1:chunk_size + 1], dtype=np.float32).reshape(
|
||||
self.config.batch_size, -1
|
||||
),
|
||||
}
|
||||
if self.config.always_start_with_bos:
|
||||
batch['input_tokens'][:, 0] = self.tokenizer.bos_token_id
|
||||
yield batch, metrics
|
||||
token_buffer = token_buffer[chunk_size:]
|
||||
loss_mask_buffer = loss_mask_buffer[chunk_size:]
|
||||
|
||||
if self._eval_dataset:
|
||||
break
|
||||
else:
|
||||
if self._train_epochs == 0:
|
||||
self._dataset = self._dataset.shuffle(seed=42, buffer_size=10000)
|
||||
self._dataset_loc = 0
|
||||
self._train_epochs += 1
|
||||
|
||||
def get_state_dict(self):
|
||||
return dict(
|
||||
config=self.config,
|
||||
dataset_loc=self._index,
|
||||
total_tokens=self._total_tokens,
|
||||
epochs=self._train_epochs,
|
||||
)
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
if 'config' in state_dict:
|
||||
self.config.update(ConfigDict(state_dict['config']))
|
||||
self._dataset_loc = state_dict.get('dataset_loc', self.config.start_seek_loc)
|
||||
self._total_tokens = state_dict.get('total_tokens', self.config.tokens_count_at_start)
|
||||
self._train_epochs = state_dict.get('epochs', 0)
|
||||
if self.reset_dataset_loc:
|
||||
self._dataset_loc = 0
|
||||
self._train_epochs = 0
|
||||
|
||||
|
||||
@property
|
||||
def seq_length(self):
|
||||
return self.config.seq_length
|
||||
|
||||
@property
|
||||
def tokenizer(self):
|
||||
return self._tokenizer
|
||||
|
||||
@property
|
||||
def text_processor(self):
|
||||
return self._text_processor
|
||||
|
||||
@property
|
||||
def dataset(self):
|
||||
return self._dataset
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return len(self._tokenizer)
|
||||
|
||||
|
||||
class JsonDataset(object):
|
||||
""" JSON dataset, where each line of the data file contains a JSON
|
||||
dictionary with text fields.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def get_default_config(updates=None):
|
||||
config = ConfigDict()
|
||||
config.path = ''
|
||||
config.seq_length = 1024
|
||||
config.batch_size = 8
|
||||
config.always_start_with_bos = False
|
||||
config.start_seek_loc = 0
|
||||
config.example_index_at_start = 0
|
||||
config.tokens_count_at_start = 0
|
||||
config.tokenizer_processes = 1
|
||||
config.tokenizer_parallel_chunk_size = 32
|
||||
config.tokenizer_parallel_batch_size = 1024
|
||||
config.throughput_average_window_size = 200
|
||||
|
||||
if updates is not None:
|
||||
config.update(ConfigDict(updates).copy_and_resolve_references())
|
||||
return config
|
||||
|
||||
def __init__(self, config, tokenizer, text_processor):
|
||||
self.config = self.get_default_config(config)
|
||||
assert self.config.path != ''
|
||||
self._tokenizer = tokenizer
|
||||
self._text_processor = text_processor
|
||||
self._index = self.config.example_index_at_start
|
||||
self._file_loc = self.config.start_seek_loc
|
||||
self._total_tokens = self.config.tokens_count_at_start
|
||||
|
||||
def parse_json(self, line):
|
||||
if not line or line == '\n':
|
||||
return None
|
||||
try:
|
||||
data = json.loads(line)
|
||||
except json.decoder.JSONDecodeError:
|
||||
print(f'Error parsing json line:\n{line}')
|
||||
return None
|
||||
return data
|
||||
|
||||
def json_iterator(self):
|
||||
with mlxu.open_file(self.config.path, 'r') as fin:
|
||||
fin.seek(self._file_loc)
|
||||
while True:
|
||||
line = fin.readline()
|
||||
self._file_loc = fin.tell()
|
||||
if not line: # Reached EOF
|
||||
self._index = 0
|
||||
fin.seek(0)
|
||||
continue
|
||||
|
||||
data = self.parse_json(line)
|
||||
if data is not None:
|
||||
# JSON parsing succeeded
|
||||
yield data, self._file_loc, self._index
|
||||
self._index += 1
|
||||
|
||||
def batched(self, iterator, batch_size):
|
||||
batch = []
|
||||
for example in iterator:
|
||||
batch.append(example)
|
||||
if len(batch) == batch_size:
|
||||
yield batch
|
||||
batch = []
|
||||
if len(batch) > 0:
|
||||
yield batch
|
||||
|
||||
def parallel_example_iterator(self):
|
||||
if self.config.tokenizer_processes == 1:
|
||||
for example, loc, index in self.json_iterator():
|
||||
yield self.text_processor((example, loc, index), has_aux=True)
|
||||
else:
|
||||
process_pool = Pool(self.config.tokenizer_processes)
|
||||
batched_iterator = self.batched(
|
||||
self.json_iterator(), self.config.tokenizer_parallel_batch_size
|
||||
)
|
||||
with process_pool as pool:
|
||||
map_fn = partial(self.text_processor, has_aux=True)
|
||||
next_batch = pool.map_async(
|
||||
map_fn, next(batched_iterator),
|
||||
chunksize=self.config.tokenizer_parallel_chunk_size
|
||||
)
|
||||
while True:
|
||||
current_batch = next_batch
|
||||
next_batch = pool.map_async(
|
||||
map_fn, next(batched_iterator),
|
||||
chunksize=self.config.tokenizer_parallel_chunk_size
|
||||
)
|
||||
for example in current_batch.get():
|
||||
yield example
|
||||
|
||||
def __iter__(self):
|
||||
chunk_size = self.config.batch_size * self.config.seq_length
|
||||
token_buffer = []
|
||||
loss_mask_buffer = []
|
||||
last_time = 0.0
|
||||
step_times = []
|
||||
start_time = time.time()
|
||||
start_tokens = self._total_tokens
|
||||
for tokens, loss_masks, loc, index in self.parallel_example_iterator():
|
||||
token_buffer.extend(tokens)
|
||||
loss_mask_buffer.extend(loss_masks)
|
||||
while len(token_buffer) > chunk_size + 1:
|
||||
self._total_tokens += chunk_size
|
||||
step_times.append(time.time() - last_time)
|
||||
last_time = time.time()
|
||||
if len(step_times) > self.config.throughput_average_window_size:
|
||||
step_times = step_times[-self.config.throughput_average_window_size:]
|
||||
average_throughput = chunk_size / np.mean(step_times)
|
||||
accumulated_throughput = (
|
||||
(self._total_tokens - start_tokens) / (time.time() - start_time)
|
||||
)
|
||||
metrics = {
|
||||
'dataset_file_loc': loc,
|
||||
'dataset_example_index': index,
|
||||
'dataset_total_tokens': self._total_tokens,
|
||||
'dataset_accumulated_tps': accumulated_throughput,
|
||||
'dataset_average_tps': average_throughput,
|
||||
}
|
||||
batch = {
|
||||
'input_tokens': np.array(token_buffer[:chunk_size], dtype=np.int32).reshape(
|
||||
self.config.batch_size, -1
|
||||
),
|
||||
'target_tokens': np.array(token_buffer[1:chunk_size + 1], dtype=np.int32).reshape(
|
||||
self.config.batch_size, -1
|
||||
),
|
||||
'loss_masks': np.array(loss_mask_buffer[1:chunk_size + 1], dtype=np.float32).reshape(
|
||||
self.config.batch_size, -1
|
||||
),
|
||||
}
|
||||
if self.config.always_start_with_bos:
|
||||
batch['input_tokens'][:, 0] = self.tokenizer.bos_token_id
|
||||
yield batch, metrics
|
||||
token_buffer = token_buffer[chunk_size:]
|
||||
loss_mask_buffer = loss_mask_buffer[chunk_size:]
|
||||
|
||||
def get_state_dict(self):
|
||||
return dict(
|
||||
config=self.config,
|
||||
index=self._index,
|
||||
file_loc=self._file_loc,
|
||||
total_tokens=self._total_tokens,
|
||||
)
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
if 'config' in state_dict:
|
||||
self.config.update(ConfigDict(state_dict['config']))
|
||||
self._index = state_dict.get('index', self.config.example_index_at_start)
|
||||
self._file_loc = state_dict.get('file_loc', self.config.start_seek_loc)
|
||||
self._total_tokens = state_dict.get('total_tokens', self.config.tokens_count_at_start)
|
||||
|
||||
@property
|
||||
def seq_length(self):
|
||||
return self.config.seq_length
|
||||
|
||||
@property
|
||||
def tokenizer(self):
|
||||
return self._tokenizer
|
||||
|
||||
@property
|
||||
def text_processor(self):
|
||||
return self._text_processor
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return len(self.tokenizer)
|
||||
Reference in New Issue
Block a user