Files
ModelHub XC 0dcbf77363 初始化项目,由ModelHub XC社区提供模型
Model: fireballoon/baichuan-vicuna-chinese-7b
Source: Original Platform
2026-06-08 19:55:17 +08:00

84 lines
2.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from threading import Thread
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
model_id = "fireballoon/baichuan-vicuna-chinese-7b"

# Pick the GPU when one is visible, otherwise run on CPU.
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running on device: {torch_device}")
print(f"CPU threads: {torch.get_num_threads()}")

# On GPU the weights are loaded in float16 and moved to CUDA; on CPU no dtype is
# requested, so transformers uses its default load dtype.
_on_gpu = torch_device == "cuda"
_load_kwargs = {"torch_dtype": torch.float16} if _on_gpu else {}
model = AutoModelForCausalLM.from_pretrained(model_id, **_load_kwargs)
if _on_gpu:
    model = model.cuda()

# Slow (non-fast) tokenizer, as loaded by the original script.
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
def _build_prompt(history):
    """Render Gradio chat ``history`` as one Vicuna-style prompt string.

    Every turn except the last is rendered as a completed exchange
    (`` USER: <q> ASSISTANT: <a> </s>``). The last turn contributes only its
    user message, followed by an open `` ASSISTANT:`` for the model to continue.
    Message text is whitespace-stripped. A missing (``None``) message is rendered
    as an empty string, so a turn whose reply was never filled in does not make
    every later request crash on ``None.strip()``.
    """
    instruction = (
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions."
    )
    context = ''.join(
        f" USER: {(turn[0] or '').strip()} ASSISTANT: {(turn[1] or '').strip()} </s>"
        for turn in history[:-1]
    )
    return instruction + context + f" USER: {(history[-1][0] or '').strip()} ASSISTANT:"


def run_generation(history, *args, **kwargs):
    """Stream the model's reply to the newest user message in ``history``.

    Args:
        history: Gradio Chatbot value, a list of ``[user_message, bot_message]``
            pairs. The last pair holds the message to answer; its reply slot is
            overwritten in place.
        *args, **kwargs: Accepted for call-site compatibility and ignored.

    Yields:
        ``history`` itself after each streamed chunk, with ``history[-1][1]``
        holding the reply generated so far.

    Returns:
        ``history`` with the complete reply (the generator's return value).
    """
    prompt = _build_prompt(history)
    # Send the ids to wherever the model lives instead of hard-coding .cuda().
    # The loader falls back to CPU when CUDA is unavailable, and .cuda() would
    # then raise before any text is produced.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    print()
    print(prompt)
    print('##', input_ids.size())

    # generate() runs on a worker thread so the UI is not blocked, and text is
    # pulled from the streamer here. The streamer timeout makes iteration raise
    # instead of hanging forever if the generation thread dies.
    # NOTE(review): on CPU a 7B model may need more than 10 s per token and trip
    # this timeout; consider a larger value for the CPU path.
    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=2048,
        do_sample=True,
        temperature=0.7,
        repetition_penalty=1.1,
        top_p=0.85,
    )
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    # Grow the reply in place and yield after each chunk so the Chatbot updates live.
    history[-1][1] = ""
    print("")
    for new_text in streamer:
        history[-1][1] += new_text
        print(new_text, end="", flush=True)
        yield history
    # The streamer only finishes once generate() has signalled the end of output,
    # so this join should return promptly. It makes the worker's lifetime explicit.
    thread.join()
    print('</s>')
    return history
def reset_textbox():
    """Return a Gradio component update that sets a textbox's value to the empty string."""
    empty_update = gr.update(value='')
    return empty_update
def user(user_message, history):
    """Record a submitted message as a new, unanswered chat turn.

    Returns a pair:
      * an update that clears the textbox and makes it non-interactive while the
        reply is generated;
      * a new history list with ``[user_message, None]`` appended (the input
        list itself is not mutated).
    """
    new_turn = [user_message, None]
    locked_textbox = gr.update(value="", interactive=False)
    return locked_textbox, history + [new_turn]


with gr.Blocks() as demo:
    gr.Markdown(
        "# Baichuan Vicuna Chinese\n"
        f"[{model_id}](https://huggingface.co/{model_id})使用中英双语sharegpt数据全参数微调的对话模型基于baichuan-7b"
    )
    chatbot = gr.Chatbot().style(height=600)
    msg = gr.Textbox()
    clear = gr.ClearButton([msg, chatbot])

    # Event pipeline on Enter:
    #   1. record the user turn (queue=False, so it bypasses the queue),
    #   2. stream the model reply into the chatbot,
    #   3. re-enable the textbox afterwards.
    submit_event = msg.submit(fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False)
    response = submit_event.then(fn=run_generation, inputs=chatbot, outputs=chatbot)
    response.then(fn=lambda: gr.update(interactive=True), inputs=None, outputs=[msg], queue=False)
# Enable the request queue (needed for the streaming generator callback), then
# serve on all network interfaces. Blocks.queue() returns the Blocks instance,
# so the two calls chain.
demo.queue().launch(server_name='0.0.0.0')