26 lines
703 B
Python
26 lines
703 B
Python
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||
import torch
|
||
|
||
|
||
|
||
# 模型路径
|
||
model_path = "./"
|
||
|
||
# 加载 tokenizer (分词器)
|
||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||
|
||
# 加载模型并移动到可用设备(GPU/CPU)
|
||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||
model = AutoModelForCausalLM.from_pretrained(model_path).to(device)
|
||
|
||
# 使用 tokenizer 编码输入的 prompt
|
||
inputs = tokenizer("你是雫梨梨吗", return_tensors="pt").to(device)
|
||
|
||
# 使用模型生成文本
|
||
outputs = model.generate(inputs["input_ids"], max_length=150)
|
||
|
||
# 解码生成的输出
|
||
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||
|
||
print(generated_text)
|