"""
|
|
Example inference for DANTE-Mosaic-3.5B.
|
|
|
|
Usage:
|
|
python example_inference.py
|
|
python example_inference.py --model YourOrg/DANTE-Mosaic-3.5B
|
|
python example_inference.py --model ./local_path/
|
|
|
|
Run on a single A100 / RTX 4090 / H100. ~5.8 GB VRAM in BF16.
|
|
"""
from __future__ import annotations

import argparse
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

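# One smoke-test prompt per capability bucket: math, code, logic, and Italian.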
PROMPTS = [
    ("MATH", "What is the derivative of f(x) = x^3 + 2x^2 - 5x + 1? Show step by step."),
    ("CODE", "Write a Python function that checks if a string is a palindrome. Include a docstring and edge cases."),
    ("LOGIC", "A bat and a ball cost $1.10 in total. The bat costs $1.00 more than the ball. How much does the ball cost? Explain."),
    # Italian: "Explain what machine learning is in simple terms, suitable for a high-school student."
    ("ITA", "Spiega cos'è il machine learning in termini semplici, adatti a uno studente delle superiori."),
]


def main():
    p = argparse.ArgumentParser()
    p.add_argument("--model", default="./",
                   help="HF repo id or local path to the model directory")
    p.add_argument("--max-new-tokens", type=int, default=256)
    p.add_argument("--temperature", type=float, default=0.7)
    p.add_argument("--top-p", type=float, default=0.9)
    args = p.parse_args()

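    # Prefer a GPU when available; per the header, BF16 needs ~5.8 GB of VRAM,
    # so a single A100 / RTX 4090 / H100 is plenty.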
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Loading {args.model} on {device} ...")
    t0 = time.time()
    tok = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
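    # BF16 halves memory vs FP32; device_map="auto" (requires the `accelerate`
    # package) places the weights on the available GPU(s) and falls back to CPU.
    # trust_remote_code only matters if the repo ships custom modeling code.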
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    ).eval()
    print(f"Loaded in {time.time()-t0:.1f}s "
          f"({sum(p.numel() for p in model.parameters())/1e9:.2f}B params)\n")

    for tag, prompt in PROMPTS:
        print("─" * 60)
        print(f"[{tag}] {prompt}\n")
        inputs = tok(prompt, return_tensors="pt").to(model.device)
        plen = inputs["input_ids"].shape[-1]  # prompt length, used to strip the echo below
        t0 = time.time()
        with torch.no_grad():
            out = model.generate(
                **inputs,
                max_new_tokens=args.max_new_tokens,
                do_sample=True,
                temperature=args.temperature,
                top_p=args.top_p,
                repetition_penalty=1.1,  # mild penalty to discourage repetition loops
                pad_token_id=tok.eos_token_id,  # silences the "pad token not set" warning
            )
        new_toks = out.shape[-1] - plen
        elapsed = time.time() - t0
        # Decode only the newly generated tokens, skipping the echoed prompt.
        text = tok.decode(out[0][plen:], skip_special_tokens=True).strip()
        print(text)
        print(f"\n [{new_toks} tokens in {elapsed:.1f}s — {new_toks/elapsed:.1f} tok/s]\n")


if __name__ == "__main__":
    main()