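"""Post-training sanity check: generate text from a fine-tuned TinyLlama checkpoint."""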
from transformers import AutoTokenizer, AutoModelForCausalLM

# Post-training: load the fine-tuned model from a local training checkpoint.
model_path = "./results/checkpoint-12000"
model = AutoModelForCausalLM.from_pretrained(model_path)
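# Load the tokenizer from the base model repo; the checkpoint directory is
# assumed to contain only model weights, not tokenizer files.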
tokenizer = AutoTokenizer.from_pretrained("tinyllama/tinyllama-1.1b-chat-v1.0")
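
# Encode a short test prompt to spot-check the fine-tuned model.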
input_text = "ae left to go to ireland and found a fairy"
input_ids = tokenizer.encode(input_text, return_tensors='pt')
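
# Beam search combined with sampling: each of the 5 beams is extended by
# sampling from the temperature / top-k / top-p filtered next-token distribution.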
output = model.generate(
    input_ids=input_ids,    # reuse the prompt encoded above instead of re-encoding it
    max_length=400,         # cap on total length, prompt tokens included
    num_return_sequences=1,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    do_sample=True,
    num_beams=5,
)
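
# Decode the single returned sequence, dropping special tokens.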
decoded_output = tokenizer.decode(output[0], skip_special_tokens=True)
print(decoded_output)