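# Fine-tune tinyllama-1.1b-chat-v1.0 on a local text dataset with the Hugging
# Face Trainer, running on CPU and resuming from an earlier checkpoint.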
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from datasets import load_from_disk
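
# Load the dataset from disk; it is expected to contain a 'text' column and to
# have been saved earlier with Dataset.save_to_disk().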
dataset = load_from_disk('finn_wake_dataset')
tokenizer = AutoTokenizer.from_pretrained("tinyllama/tinyllama-1.1b-chat-v1.0")
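# Write the tokenizer files into the checkpoint directory so the checkpoint can
# be reloaded as a complete model; Trainer checkpoints do not include the
# tokenizer unless it is passed to the Trainer.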
tokenizer.save_pretrained("./results/checkpoint-12000/")
model = AutoModelForCausalLM.from_pretrained("tinyllama/tinyllama-1.1b-chat-v1.0")
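
# The Llama-family tokenizer ships without a pad token, so reuse the EOS token
# to make padding="max_length" work below.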
if tokenizer.pad_token is None:
    print("Tokenizer does not have a pad token set. Setting pad_token to eos_token.")
    tokenizer.pad_token = tokenizer.eos_token

def tokenize_function(examples):
    # Pad/truncate every example to a fixed 128 tokens; for causal language
    # modelling the labels are simply a copy of the input ids.
    tokenized_inputs = tokenizer(examples['text'], padding="max_length", truncation=True, max_length=128)
    tokenized_inputs["labels"] = tokenized_inputs["input_ids"].copy()
    return tokenized_inputs
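
# Note (an optional refinement, not in the original script): with
# padding="max_length", pad tokens end up in the labels and contribute to the
# loss. They can be masked out with -100, which Hugging Face models ignore
# when computing the loss:
#
#   tokenized_inputs["labels"] = [
#       [tok if mask == 1 else -100 for tok, mask in zip(ids, attn)]
#       for ids, attn in zip(tokenized_inputs["input_ids"],
#                            tokenized_inputs["attention_mask"])
#   ]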
tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
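
# Hold out 10% of the tokenized examples for evaluation.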
train_test_split = tokenized_dataset.train_test_split(test_size=0.1)
train_dataset = train_test_split['train']
eval_dataset = train_test_split['test']

training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=1,  # batch size of 1 keeps memory usage low for CPU training
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=2,  # keep at most two checkpoints on disk
    use_cpu=True,
)
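
# No custom data collator is needed: every example is already padded to a fixed
# length, so the Trainer's default collator can batch the features directly.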
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

# The call below was modified because I initially ran out of disk storage, so I
# had to resume from a checkpoint and adjust the save_strategy above.
trainer.train(resume_from_checkpoint="./results/checkpoint-10000")