初始化项目,由ModelHub XC社区提供模型
Model: OdaxAI/DANTE-Mosaic-3.5B Source: Original Platform
This commit is contained in:
73
example_inference.py
Normal file
73
example_inference.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""
|
||||
Example inference for DANTE-Mosaic-3.5B.
|
||||
|
||||
Usage:
|
||||
python example_inference.py
|
||||
python example_inference.py --model YourOrg/DANTE-Mosaic-3.5B
|
||||
python example_inference.py --model ./local_path/
|
||||
|
||||
Run on a single A100 / RTX 4090 / H100. ~5.8 GB VRAM in BF16.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import time
|
||||
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
PROMPTS = [
|
||||
("MATH", "What is the derivative of f(x) = x^3 + 2x^2 - 5x + 1? Show step by step."),
|
||||
("CODE", "Write a Python function that checks if a string is a palindrome. Include a docstring and edge cases."),
|
||||
("LOGIC", "A bat and a ball cost $1.10 in total. The bat costs $1.00 more than the ball. How much does the ball cost? Explain."),
|
||||
("ITA", "Spiega cos'è il machine learning in termini semplici, adatti a uno studente delle superiori."),
|
||||
]
|
||||
|
||||
|
||||
def main():
|
||||
p = argparse.ArgumentParser()
|
||||
p.add_argument("--model", default="./",
|
||||
help="HF repo id or local path to the model directory")
|
||||
p.add_argument("--max-new-tokens", type=int, default=256)
|
||||
p.add_argument("--temperature", type=float, default=0.7)
|
||||
p.add_argument("--top-p", type=float, default=0.9)
|
||||
args = p.parse_args()
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
print(f"Loading {args.model} on {device} ...")
|
||||
t0 = time.time()
|
||||
tok = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.model,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
trust_remote_code=True,
|
||||
).eval()
|
||||
print(f"Loaded in {time.time()-t0:.1f}s "
|
||||
f"({sum(p.numel() for p in model.parameters())/1e9:.2f}B params)\n")
|
||||
|
||||
for tag, prompt in PROMPTS:
|
||||
print("─" * 60)
|
||||
print(f"[{tag}] {prompt}\n")
|
||||
inputs = tok(prompt, return_tensors="pt").to(model.device)
|
||||
plen = inputs["input_ids"].shape[-1]
|
||||
t0 = time.time()
|
||||
with torch.no_grad():
|
||||
out = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=args.max_new_tokens,
|
||||
do_sample=True,
|
||||
temperature=args.temperature,
|
||||
top_p=args.top_p,
|
||||
repetition_penalty=1.1,
|
||||
pad_token_id=tok.eos_token_id,
|
||||
)
|
||||
new_toks = out.shape[-1] - plen
|
||||
elapsed = time.time() - t0
|
||||
text = tok.decode(out[0][plen:], skip_special_tokens=True).strip()
|
||||
print(text)
|
||||
print(f"\n [{new_toks} tokens in {elapsed:.1f}s — {new_toks/elapsed:.1f} tok/s]\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user