初始化项目,由ModelHub XC社区提供模型
Model: npc-worldwide/TinyTimV1 Source: Original Platform
This commit is contained in:
27
text_gen.py
Normal file
27
text_gen.py
Normal file
@@ -0,0 +1,27 @@
|
||||
|
||||
|
||||
|
||||
"""Generate text from a fine-tuned causal LM checkpoint.

Loads a local fine-tuned checkpoint (post-training), tokenizes a hard-coded
prompt, runs generation, and prints the decoded continuation.
"""
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from datasets import load_dataset, load_from_disk

# Post-training inference: load the fine-tuned weights produced by training.
model_path = "./results/checkpoint-12000"
model = AutoModelForCausalLM.from_pretrained(model_path)
# The tokenizer comes from the base model the checkpoint was fine-tuned from,
# since the local checkpoint directory may not ship its own tokenizer files.
tokenizer = AutoTokenizer.from_pretrained("tinyllama/tinyllama-1.1b-chat-v1.0")

input_text = "ae left to go to ireland and found a fairy"
input_ids = tokenizer.encode(input_text, return_tensors="pt")

# Fix: the original encoded the prompt a second time inside generate(),
# discarding the `input_ids` computed above — reuse them instead.
# NOTE(review): combining beam search (num_beams=5) with sampling
# (do_sample=True, temperature/top_k/top_p) is an unusual mix — confirm
# this decoding strategy is intended.
output = model.generate(
    input_ids=input_ids,
    max_length=400,           # total length cap (prompt + generated tokens)
    num_return_sequences=1,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    do_sample=True,
    num_beams=5,
)

# output is a batch of token-id sequences; decode the first (and only) one.
decoded_output = tokenizer.decode(output[0], skip_special_tokens=True)
print(decoded_output)
|
||||
Reference in New Issue
Block a user