""" Example inference for DANTE-Mosaic-3.5B. Usage: python example_inference.py python example_inference.py --model YourOrg/DANTE-Mosaic-3.5B python example_inference.py --model ./local_path/ Run on a single A100 / RTX 4090 / H100. ~5.8 GB VRAM in BF16. """ from __future__ import annotations import argparse import time import torch from transformers import AutoModelForCausalLM, AutoTokenizer PROMPTS = [ ("MATH", "What is the derivative of f(x) = x^3 + 2x^2 - 5x + 1? Show step by step."), ("CODE", "Write a Python function that checks if a string is a palindrome. Include a docstring and edge cases."), ("LOGIC", "A bat and a ball cost $1.10 in total. The bat costs $1.00 more than the ball. How much does the ball cost? Explain."), ("ITA", "Spiega cos'è il machine learning in termini semplici, adatti a uno studente delle superiori."), ] def main(): p = argparse.ArgumentParser() p.add_argument("--model", default="./", help="HF repo id or local path to the model directory") p.add_argument("--max-new-tokens", type=int, default=256) p.add_argument("--temperature", type=float, default=0.7) p.add_argument("--top-p", type=float, default=0.9) args = p.parse_args() device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Loading {args.model} on {device} ...") t0 = time.time() tok = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained( args.model, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True, ).eval() print(f"Loaded in {time.time()-t0:.1f}s " f"({sum(p.numel() for p in model.parameters())/1e9:.2f}B params)\n") for tag, prompt in PROMPTS: print("─" * 60) print(f"[{tag}] {prompt}\n") inputs = tok(prompt, return_tensors="pt").to(model.device) plen = inputs["input_ids"].shape[-1] t0 = time.time() with torch.no_grad(): out = model.generate( **inputs, max_new_tokens=args.max_new_tokens, do_sample=True, temperature=args.temperature, top_p=args.top_p, repetition_penalty=1.1, pad_token_id=tok.eos_token_id, ) new_toks = out.shape[-1] - plen elapsed = time.time() - t0 text = tok.decode(out[0][plen:], skip_special_tokens=True).strip() print(text) print(f"\n [{new_toks} tokens in {elapsed:.1f}s — {new_toks/elapsed:.1f} tok/s]\n") if __name__ == "__main__": main()