Files
DANTE-Mosaic-3.5B/example_inference.py
ModelHub XC b0ba87406b 初始化项目,由ModelHub XC社区提供模型
Model: OdaxAI/DANTE-Mosaic-3.5B
Source: Original Platform
2026-05-14 15:44:10 +08:00

74 lines
2.6 KiB
Python

"""
Example inference for DANTE-Mosaic-3.5B.
Usage:
python example_inference.py
python example_inference.py --model YourOrg/DANTE-Mosaic-3.5B
python example_inference.py --model ./local_path/
Run on a single A100 / RTX 4090 / H100. ~5.8 GB VRAM in BF16.
"""
from __future__ import annotations
import argparse
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Demo prompts, processed in list order by main(). Each entry is a
# (tag, prompt) pair; the tag is printed as a label next to the prompt.
PROMPTS = [
    (
        "MATH",
        "What is the derivative of f(x) = x^3 + 2x^2 - 5x + 1? "
        "Show step by step.",
    ),
    (
        "CODE",
        "Write a Python function that checks if a string is a palindrome. "
        "Include a docstring and edge cases.",
    ),
    (
        "LOGIC",
        "A bat and a ball cost $1.10 in total. "
        "The bat costs $1.00 more than the ball. "
        "How much does the ball cost? Explain.",
    ),
    (
        # Italian prompt: exercises non-English generation.
        "ITA",
        "Spiega cos'è il machine learning in termini semplici, "
        "adatti a uno studente delle superiori.",
    ),
]
def main() -> None:
    """Load the model and print a generated completion for each entry in PROMPTS.

    Command-line options:
        --model: Hugging Face repo id or local model directory (default "./").
        --max-new-tokens: upper bound on generated tokens per prompt.
        --temperature: sampling temperature. A value <= 0 switches to greedy
            decoding, because sampling requires a strictly positive temperature.
        --top-p: nucleus-sampling threshold (only used when sampling).
        --repetition-penalty: passed through to ``generate`` (default 1.1).

    Each prompt is tokenized as raw text (no chat template). A continuation is
    generated and the decoded new tokens are printed, together with the
    new-token count and throughput.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="./",
                        help="HF repo id or local path to the model directory")
    parser.add_argument("--max-new-tokens", type=int, default=256)
    parser.add_argument("--temperature", type=float, default=0.7,
                        help="sampling temperature; <= 0 means greedy decoding")
    parser.add_argument("--top-p", type=float, default=0.9)
    parser.add_argument("--repetition-penalty", type=float, default=1.1)
    args = parser.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Loading {args.model} on {device} ...")
    t0 = time.perf_counter()
    tok = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    ).eval()
    n_params = sum(param.numel() for param in model.parameters())
    print(f"Loaded in {time.perf_counter() - t0:.1f}s "
          f"({n_params / 1e9:.2f}B params)\n")

    # transformers rejects do_sample=True with a non-positive temperature,
    # so a temperature <= 0 is treated as a request for greedy decoding.
    if args.temperature > 0:
        decoding = {
            "do_sample": True,
            "temperature": args.temperature,
            "top_p": args.top_p,
        }
    else:
        decoding = {"do_sample": False}

    for tag, prompt in PROMPTS:
        print("─" * 60)
        print(f"[{tag}] {prompt}\n")
        inputs = tok(prompt, return_tensors="pt").to(model.device)
        prompt_len = inputs["input_ids"].shape[-1]
        t0 = time.perf_counter()
        with torch.inference_mode():
            out = model.generate(
                **inputs,
                max_new_tokens=args.max_new_tokens,
                repetition_penalty=args.repetition_penalty,
                pad_token_id=tok.eos_token_id,
                **decoding,
            )
        elapsed = time.perf_counter() - t0
        # The returned sequence begins with the prompt tokens; count and decode
        # only the continuation.
        new_toks = out.shape[-1] - prompt_len
        text = tok.decode(out[0][prompt_len:], skip_special_tokens=True).strip()
        print(text)
        print(f"\n [{new_toks} tokens in {elapsed:.1f}s — {new_toks / elapsed:.1f} tok/s]\n")
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()