"""Load a local causal language model and generate a continuation for one prompt.

Run as a script from the directory that holds the model and tokenizer files
(``model_path`` is the current working directory); the decoded generation is
printed to stdout.
"""

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Directory containing the model weights and tokenizer files.
model_path = "./"

# Prompt to feed the model and the cap on the TOTAL sequence length
# (prompt tokens + newly generated tokens, per `generate(max_length=...)`).
prompt = "你是雫梨梨吗"
max_length = 150

# Load the tokenizer.
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Load the model and move it to the available device (GPU if present, else CPU).
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(model_path).to(device)

# Encode the prompt with the tokenizer, as PyTorch tensors on the model's device.
inputs = tokenizer(prompt, return_tensors="pt").to(device)

# Generate text with the model.
# Pass the tokenizer's attention mask explicitly: with input_ids alone,
# generate() has to guess the mask and emits a warning. Only input_ids and
# attention_mask are forwarded (not **inputs), so extra tokenizer outputs such
# as token_type_ids cannot reach models that reject them.
# Some causal-LM tokenizers define no pad token; fall back to EOS in that case,
# which is what generate() would otherwise do itself (with a warning).
pad_token_id = (
    tokenizer.pad_token_id
    if tokenizer.pad_token_id is not None
    else tokenizer.eos_token_id
)
outputs = model.generate(
    inputs["input_ids"],
    attention_mask=inputs.get("attention_mask"),
    pad_token_id=pad_token_id,
    max_length=max_length,
)

# Decode the first (only) generated sequence. For a decoder-only model this
# sequence begins with the prompt tokens, so the printed text includes the prompt.
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

print(generated_text)