初始化项目,由ModelHub XC社区提供模型
Model: npc-worldwide/TinyTimV1 Source: Original Platform
This commit is contained in:
47
fine_tune_joyce.py
Normal file
47
fine_tune_joyce.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
|
||||
from datasets import load_dataset, load_from_disk
|
||||
|
||||
dataset = load_from_disk('finn_wake_dataset')
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("tinyllama/tinyllama-1.1b-chat-v1.0")
|
||||
|
||||
tokenizer.save_pretrained(".results/checkpoint-12000/")
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("tinyllama/tinyllama-1.1b-chat-v1.0")
|
||||
|
||||
if tokenizer.pad_token is None:
|
||||
print("Tokenizer does not have a pad token set. Setting pad_token to eos_token.")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
def tokenize_function(examples):
|
||||
|
||||
tokenized_inputs = tokenizer(examples['text'], padding="max_length", truncation=True, max_length=128)
|
||||
tokenized_inputs["labels"] = tokenized_inputs["input_ids"].copy()
|
||||
|
||||
return tokenized_inputs
|
||||
|
||||
tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
|
||||
train_test_split = tokenized_dataset.train_test_split(test_size=0.1)
|
||||
|
||||
train_dataset = train_test_split['train']
|
||||
eval_dataset = train_test_split['test']
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./results",
|
||||
num_train_epochs=3,
|
||||
per_device_train_batch_size=1,
|
||||
warmup_steps=500,
|
||||
weight_decay=0.01,
|
||||
logging_dir="./logs",
|
||||
logging_steps=10,
|
||||
save_strategy="steps",
|
||||
save_steps=500,
|
||||
save_total_limit=2,
|
||||
use_cpu=True)
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
)
|
||||
#below has been modified because i ran out of disk storage initially so had to resume and adjust the save_strategy above.
|
||||
trainer.train(resume_from_checkpoint="./results/checkpoint-10000")
|
||||
Reference in New Issue
Block a user