初始化项目,由ModelHub XC社区提供模型
Model: Oysiyl/qwen3-4b-unslop-good-lora-v1 Source: Original Platform
This commit is contained in:
57
run_local_inference_new_sample_mac.py
Normal file
57
run_local_inference_new_sample_mac.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import torch
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
# --- Configuration ----------------------------------------------------------
# Hugging Face repo that holds the fine-tuned model artifacts.
REPO = 'Oysiyl/qwen3-4b-unslop-good-lora-v1'
# Base model the LoRA was merged from (kept for reference).
BASE = 'Qwen/Qwen3-4B'
# Local directory the downloaded artifacts are copied into.
WORKDIR = Path('./local_inference_model_4b')

# Evaluation prompt: plain-text file sitting next to this script.
_sample_file = Path(__file__).resolve().parent / 'new_eval_sample.txt'
PROMPT = _sample_file.read_text(encoding='utf-8').strip()

# Ensure the target directory exists before any files are copied into it.
WORKDIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Tokenizer-side artifacts to fetch from the hub.
repo_files = ['tokenizer.json', 'tokenizer_config.json', 'chat_template.jinja']

# Sharded safetensors weights plus the shard index that ties them together.
weight_files = [
    'model-00001-of-00002.safetensors',
    'model-00002-of-00002.safetensors',
    'model.safetensors.index.json',
]
|
||||
|
||||
# Config files the model loader also needs locally.  Folding them into the
# same loop removes the duplicated download/copy body the script had for them.
config_files = ['config.json', 'generation_config.json']

# Fetch every artifact via the hub cache and place a plain copy in WORKDIR so
# the model can later be loaded entirely from the local directory.
for fname in repo_files + weight_files + config_files:
    # hf_hub_download returns the cached file's path (downloading on first
    # use); copyfile materialises it under WORKDIR with the same name.
    src = hf_hub_download(repo_id=REPO, filename=fname)
    shutil.copyfile(src, WORKDIR / fname)
|
||||
|
||||
# Single-turn chat: ask the model to humanize the sampled AI-written passage.
_instruction = 'Polish this AI passage to feel more human while preserving meaning:\n'
messages = [{'role': 'user', 'content': _instruction + PROMPT}]
|
||||
|
||||
# Prefer the Apple-silicon GPU backend (Metal/MPS) when present; fall back to CPU.
if torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
print(f'Using device: {device}')
|
||||
|
||||
# Load tokenizer and weights from the local copy.  torch_dtype='auto' lets
# transformers keep the checkpoint's native dtype; the model is then moved
# onto the selected device in one chained call.
tokenizer = AutoTokenizer.from_pretrained(WORKDIR)
model = AutoModelForCausalLM.from_pretrained(WORKDIR, torch_dtype='auto').to(device)
|
||||
|
||||
# Render the chat through the model's template, leaving the assistant turn
# open (add_generation_prompt=True) so the model produces the reply.
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
inputs = tokenizer(text, return_tensors='pt').to(device)

# Sample a response; temperature and repetition penalty keep the output
# varied but coherent.  Then print the full decoded sequence.
sampling_options = dict(
    max_new_tokens=400,
    temperature=0.8,
    repetition_penalty=1.1,
    do_sample=True,
)
outputs = model.generate(**inputs, **sampling_options)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
|
||||
Reference in New Issue
Block a user