# (file-viewer metadata captured with the original paste: 82 lines, 3.0 KiB, Python)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Directory the generation transcripts are appended to.
BASE_PATH = "/fsx/loubna/projects/alignment-handbook/recipes/cosmo2/sft/data"

# Sampling parameters passed to `generate`; they are also encoded in the
# transcript filename so runs with different settings don't mix.
TEMPERATURE = 0.2
TOP_P = 0.9

# Model (Hub id or local path) whose chat behaviour is being spot-checked.
CHECKPOINT = "HuggingFaceTB/smollm-350M-instruct-add-basics"

print(f"💾 Loading the model and tokenizer: {CHECKPOINT}...")
# Fall back to CPU when no GPU is visible instead of crashing on `.to("cuda")`.
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
model_s = AutoModelForCausalLM.from_pretrained(CHECKPOINT).to(device)
print("🧪 Testing single-turn conversations...")

# Single-turn prompts, grouped by category. The groups are concatenated in
# the order below to form `L`, the list of prompts sent to the model.

# Writing and general knowledge prompts
_WRITING_AND_GENERAL = [
    "Discuss the ethical implications of using AI in hiring processes.",
    "Give me some tips to improve my time management skills?",
    "Write a short dialogue between a customer and a waiter at a restaurant.",
    "wassup?",
    "Tell me a joke",
    "Hi, what are some popular dishes from Japan?",
    "What is the capital of Switzerland?",
    "What is the capital of France?",
    "What's the capital of Portugal?",
    "What is the capital of Morocco?",
    "How do I make pancakes?",
    "Write a poem about Helium",
    "Do you think it's important for a company to have a strong company culture? Why or why not?",
    "What is your favorite book?",
    "What is the most interesting fact you know?",
    "What is your favorite movie?",
]

# Science prompts (including a few simple arithmetic checks)
_SCIENCE = [
    "Can you tell me what is gravity?",
    "Who discovered gravity?",
    "How does a rainbow form?",
    "What are the three states of matter?",
    "Why is the sky blue?",
    "What is the water cycle?",
    "How do magnets work?",
    "What is buoyancy?",
    "What is the speed of light?",
    "What's 2+2?",
    "what's the sum of 2 and 2?",
    "what's the sum of 2 and 3?",
    "What is the term for the process by which plants make their own food?",
    "If you have 8 apples and you give away 3, how many apples do you have left?",
]

# Python prompts
_PYTHON = [
    "How do I define a function in Python?",
    "Can you explain what a dictionary is in Python?",
    "Write a sort alrogithm in Python",
    "Write a fibonacci sequence in Python",
    "How do I read a file in Python?",
    "How do I make everything uppercase in Python?",
    "implement bubble sort in Python",
]

# Creative prompts
_CREATIVE = [
    "Write a short story about a time traveler",
    "Describe a futuristic city in three sentences",
    "Describe a new color that doesn't exist",
    "Create a slogan for a time machine company",
    "Describe a world where plants can speak",
]

L = _WRITING_AND_GENERAL + _SCIENCE + _PYTHON + _CREATIVE
# Generate one sampled completion per prompt in `L` and append the full
# transcript (templated prompt + completion) to a file named after the
# checkpoint and sampling settings. Append mode means repeated runs accumulate.
MAX_NEW_TOKENS = 200
output_path = (
    f"{BASE_PATH}/{CHECKPOINT.split('/')[-1]}_temp_{TEMPERATURE}_topp{TOP_P}.txt"
)

# Open once; UTF-8 explicitly, since prompts/completions may contain
# non-ASCII text and the platform default encoding may not handle it.
with open(output_path, "a", encoding="utf-8") as f:
    for prompt in L:
        print(f"🔮 {prompt}")
        messages = [{"role": "user", "content": prompt}]
        # add_generation_prompt=True appends the assistant-turn header, so the
        # model is prompted to answer rather than left at the end of the user turn.
        input_text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # The rendered template already contains its special tokens; don't let
        # the tokenizer add them a second time. Calling the tokenizer (rather
        # than `encode`) also yields the attention mask for `generate`.
        inputs = tokenizer(
            input_text, return_tensors="pt", add_special_tokens=False
        ).to(device)
        outputs = model_s.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            top_p=TOP_P,
            do_sample=True,
            temperature=TEMPERATURE,
        )
        f.write("=" * 50 + "\n")
        # outputs[0] holds the prompt tokens followed by the generated tokens.
        f.write(tokenizer.decode(outputs[0]))
        f.write("\n")
        # Persist each result immediately (the original reopened the file per
        # prompt), so an interrupted run keeps everything generated so far.
        f.flush()