# Offline inference script: copy a fine-tuned Qwen3-4B checkpoint from the
# Hugging Face Hub into a local directory, then run one rewrite prompt on it.
# Standard library.
import shutil
from pathlib import Path

# Third-party.
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModelForCausalLM, AutoTokenizer
# Hub repo holding the fine-tuned checkpoint.
REPO = 'Oysiyl/qwen3-4b-unslop-good-lora-v1'
# NOTE(review): BASE is never referenced below — confirm it is still needed.
BASE = 'Qwen/Qwen3-4B'
# Local directory the model files are copied into for offline loading.
WORKDIR = Path('./local_inference_model_4b')

# The evaluation passage lives next to this script.
_script_dir = Path(__file__).resolve().parent
PROMPT = (_script_dir / 'new_eval_sample.txt').read_text(encoding='utf-8').strip()
# Ensure the local target directory exists before copying into it.
WORKDIR.mkdir(parents=True, exist_ok=True)

# Tokenizer assets shipped with the fine-tuned repo.
repo_files = [
    'tokenizer.json',
    'tokenizer_config.json',
    'chat_template.jinja',
]
# Sharded model weights plus their shard index.
weight_files = [
    'model-00001-of-00002.safetensors',
    'model-00002-of-00002.safetensors',
    'model.safetensors.index.json',
]
# Model/generation configs — previously fetched by a second loop that
# duplicated this one line-for-line; folded into the single pass below.
config_files = [
    'config.json',
    'generation_config.json',
]

# Download each file into the Hub cache, then copy it into WORKDIR so the
# model can be loaded from one self-contained local directory.
for fname in repo_files + weight_files + config_files:
    src = hf_hub_download(repo_id=REPO, filename=fname)
    shutil.copyfile(src, WORKDIR / fname)
# Single-turn chat request: ask the model to rewrite the sampled passage.
_user_content = (
    'Polish this AI passage to feel more human while preserving meaning:\n'
    + PROMPT
)
messages = [{'role': 'user', 'content': _user_content}]
# Pick the best available accelerator. The original only checked MPS, which
# silently fell back to CPU on CUDA machines; CUDA is now preferred first,
# and behavior on non-CUDA hosts is unchanged.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
print(f'Using device: {device}')

# Load tokenizer and weights from the local directory populated above.
# torch_dtype='auto' keeps the checkpoint's stored precision.
tokenizer = AutoTokenizer.from_pretrained(WORKDIR)
model = AutoModelForCausalLM.from_pretrained(WORKDIR, torch_dtype='auto')
model = model.to(device)
# Render the chat template to a plain prompt string, then tokenize it on
# the chosen device.
prompt_text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
inputs = tokenizer(prompt_text, return_tensors='pt').to(device)

# Sampled decoding; repetition_penalty nudges the model away from loops.
sampling_kwargs = {
    'max_new_tokens': 400,
    'temperature': 0.8,
    'repetition_penalty': 1.1,
    'do_sample': True,
}
outputs = model.generate(**inputs, **sampling_kwargs)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))