初始化项目,由ModelHub XC社区提供模型
Model: QuantaSparkLabs/Chronos-3B Source: Original Platform
This commit is contained in:
130
pipeline.py
Normal file
130
pipeline.py
Normal file
@@ -0,0 +1,130 @@
|
||||
|
||||
import json, torch, numpy as np
|
||||
from sentence_transformers import SentenceTransformer, CrossEncoder
|
||||
import faiss
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
||||
|
||||
class Chronos:
|
||||
def __init__(self, model_dir="."):
|
||||
with open(f"{model_dir}/rag_config.json") as f:
|
||||
config = json.load(f)
|
||||
self.embedder = SentenceTransformer(config["embedder_model"])
|
||||
self.index = faiss.read_index(f"{model_dir}/jjk_index.faiss")
|
||||
with open(f"{model_dir}/chunks.txt", "r") as f:
|
||||
raw = f.read().split("<|CHUNK_END|>")
|
||||
self.chunks = [c.strip() for c in raw if c.strip()]
|
||||
self.reranker = CrossEncoder(f"{model_dir}/cross_encoder_model")
|
||||
bnb = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=torch.float16
|
||||
)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_dir,
|
||||
quantization_config=bnb,
|
||||
device_map="auto",
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
||||
def ask(self, question, max_tokens_casual=200, max_tokens_hist=400):
|
||||
q = question.strip().lower().rstrip("?!.")
|
||||
greetings = [
|
||||
"hi", "hello", "hey", "yo", "sup", "good morning",
|
||||
"good evening", "how are you", "hi bro", "hi there",
|
||||
"what's up", "howdy"
|
||||
]
|
||||
if any(q == g or q.startswith(g) for g in greetings):
|
||||
return "Hello! I'm Chronos, your 20th‑century historian. Ask me anything about the wars, inventions, and events that shaped the modern world. Or just chat – I'm pretty friendly! 😄"
|
||||
|
||||
identity_qs = [
|
||||
"who are you", "what is your name", "what are you",
|
||||
"are you chronos", "your name", "who is chronos",
|
||||
"what is chronos", "tell me about yourself",
|
||||
"what do you do", "who created you", "who made you"
|
||||
]
|
||||
if any(idq in q for idq in identity_qs):
|
||||
return "I'm Chronos 🕰️, a friendly historian AI. I can talk about whatever you like, but my real passion is 20th‑century history – WW1, WW2, the Cold War, and all the inventions that came out of that era. So, what's on your mind?"
|
||||
|
||||
safety_net = {
|
||||
"leader of germany during world war i": "Kaiser Wilhelm II was the German Emperor during World War I.",
|
||||
"dictator of germany during world war ii": "Adolf Hitler was the dictator of Germany during World War II.",
|
||||
"leader of germany during ww2": "Adolf Hitler was the dictator of Germany during World War II.",
|
||||
"leader of the soviet union during world war ii": "Joseph Stalin led the Soviet Union during World War II.",
|
||||
"prime minister of the uk during world war ii": "Winston Churchill was the UK Prime Minister during World War II.",
|
||||
"president of the usa during world war ii": "Franklin D. Roosevelt was the US President for most of WWII.",
|
||||
"when did world war i start": "World War I started on July 28, 1914.",
|
||||
"when did world war ii start": "World War II started on September 1, 1939.",
|
||||
"what was the holocaust": "The Holocaust was the systematic murder of six million Jews by the Nazi regime.",
|
||||
}
|
||||
for key, ans in safety_net.items():
|
||||
if key in q:
|
||||
return ans
|
||||
|
||||
history_kw = [
|
||||
"ww1", "ww2", "world war", "cold war", "great war", "nazi", "hitler", "stalin", "mussolini", "tojo",
|
||||
"holocaust", "atomic bomb", "nuclear weapon", "manhattan project", "d-day", "pearl harbor",
|
||||
"battle of", "treaty of", "versailles", "yalta", "potsdam", "munich agreement",
|
||||
"kaiser", "tsar", "emperor", "dictator", "president", "prime minister", "chancellor",
|
||||
"soviet union", "ussr", "nazi germany", "third reich", "allies", "axis",
|
||||
"invent", "radar", "penicillin", "jet engine", "computer", "transistor", "satellite", "rocket",
|
||||
"space race", "apollo", "sputnik",
|
||||
"depression", "new deal", "marshall plan", "nato", "warsaw pact", "berlin wall", "iron curtain",
|
||||
"decolonization", "civil rights", "women's suffrage", "baby boom",
|
||||
"what caused", "how did", "why did", "explain", "describe", "timeline", "significance", "result of",
|
||||
"who was", "who were", "when did", "where did", "what was the", "what is the",
|
||||
"churchill", "roosevelt", "lenin", "mao", "kennedy", "eisenhower", "montgomery", "rommel",
|
||||
"blitzkrieg", "trench warfare", "enigma", "code talkers", "kamikaze", "fat man", "little boy"
|
||||
]
|
||||
is_history = any(kw in q for kw in history_kw)
|
||||
|
||||
if not is_history:
|
||||
try:
|
||||
prompt = (
|
||||
"<|im_start|>system\n"
|
||||
"You are Chronos, a friendly and knowledgeable historian. "
|
||||
"You love chatting with people about anything, not just history. "
|
||||
"Be warm, concise, and lively.\n"
|
||||
"<|im_end|>\n"
|
||||
f"<|im_start|>user\n{question}<|im_end|>\n"
|
||||
"<|im_start|>assistant\n"
|
||||
)
|
||||
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
|
||||
outputs = self.model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=max_tokens_casual,
|
||||
temperature=0.9,
|
||||
do_sample=True,
|
||||
pad_token_id=self.tokenizer.eos_token_id
|
||||
)
|
||||
return self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True).strip()
|
||||
except:
|
||||
return "I'm not sure what to say, but I'm happy to talk about 20th‑century history if you're interested! 😊"
|
||||
|
||||
try:
|
||||
q_emb = self.embedder.encode([question]).astype("float32")
|
||||
_, indices = self.index.search(q_emb, 30)
|
||||
candidates = [self.chunks[i] for i in indices[0]]
|
||||
pairs = [(question, c) for c in candidates]
|
||||
scores = self.reranker.predict(pairs)
|
||||
if max(scores) < -4.5:
|
||||
return "I'm sorry, I don't have enough historical information to answer that question accurately."
|
||||
best = sorted(zip(scores, candidates), reverse=True)[:4]
|
||||
context = "\n\n".join([c for _, c in best])
|
||||
messages = [
|
||||
{"role": "system", "content": "You are Chronos, a historian. Use the provided context to answer accurately. If the information is not present, say 'I don't know'."},
|
||||
{"role": "user", "content": f"Context:\n{context}\n\nQuestion: {question}"}
|
||||
]
|
||||
prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
|
||||
outputs = self.model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=max_tokens_hist,
|
||||
temperature=0.7,
|
||||
do_sample=True,
|
||||
pad_token_id=self.tokenizer.eos_token_id
|
||||
)
|
||||
return self.tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True).strip()
|
||||
except:
|
||||
return "I'm sorry, something went wrong with the historical search. Please try again."
|
||||
Reference in New Issue
Block a user