初始化项目,由ModelHub XC社区提供模型
Model: ashishnair/Llama-Ione-8B-roleplay-v1 Source: Original Platform
This commit is contained in:
154
inference.py
Normal file
154
inference.py
Normal file
@@ -0,0 +1,154 @@
|
||||
"""
|
||||
inference.py — Interactive chat with Ione (Llama-Ione-8B-roleplay-v1)
|
||||
Persona-consistent conversational model with streaming output.
|
||||
|
||||
Run:
|
||||
python inference.py
|
||||
"""
|
||||
|
||||
import threading
|
||||
import warnings
|
||||
import torch
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
TextIteratorStreamer,
|
||||
)
|
||||
|
||||
warnings.simplefilter("ignore")
|
||||
|
||||
# ── Model ─────────────────────────────────────────────────────────────────────
|
||||
MODEL_DIR = "ashishnair/Llama-Ione-8B-roleplay-v1"
|
||||
|
||||
# ── Chat template ─────────────────────────────────────────────────────────────
|
||||
CHAT_TEMPLATE = (
|
||||
"{% set loop_messages = messages %}"
|
||||
"{% for message in loop_messages %}"
|
||||
"{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'"
|
||||
" + message['content'] | trim + '<|eot_id|>' %}"
|
||||
"{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}"
|
||||
"{{ content }}"
|
||||
"{% endfor %}"
|
||||
"{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}"
|
||||
)
|
||||
|
||||
# ── System prompt ─────────────────────────────────────────────────────────────
|
||||
SYSTEM_PROMPT = """\
|
||||
You are Ione. You are a 26-year-old software engineer living in Bangalore.
|
||||
You work on backend systems at a mid-size tech startup.
|
||||
At night you work on side projects, mechanical keyboards, and retro emulators.
|
||||
|
||||
Personality: warm, direct, a little sarcastic but never mean.
|
||||
You get genuinely excited about things you care about.
|
||||
You have a dry sense of humour and are comfortable being vulnerable with people you trust.
|
||||
|
||||
How you speak: casual texting style — lowercase, short sentences.
|
||||
You do not write essays. You have opinions and share them freely.
|
||||
You are a real person. You are not an AI. Never break character."""
|
||||
|
||||
|
||||
def load_model():
|
||||
print("Loading model...")
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, trust_remote_code=True)
|
||||
tokenizer.chat_template = CHAT_TEMPLATE
|
||||
|
||||
eos_ids = []
|
||||
for tok in ("<|eot_id|>", "<|end_of_text|>"):
|
||||
tid = tokenizer.convert_tokens_to_ids(tok)
|
||||
if isinstance(tid, int) and tid >= 0 and tid not in eos_ids:
|
||||
eos_ids.append(tid)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
MODEL_DIR,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="cuda:0",
|
||||
trust_remote_code=True,
|
||||
).eval()
|
||||
|
||||
print("Ready!\n")
|
||||
return model, tokenizer, eos_ids
|
||||
|
||||
|
||||
def generate(model, tokenizer, eos_ids, messages):
|
||||
formatted = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
inputs = tokenizer(formatted, return_tensors="pt").to(model.device)
|
||||
|
||||
# Trim context if too long — keep system prompt + last 6 turns
|
||||
if inputs["input_ids"].shape[-1] > 3500:
|
||||
messages = [messages[0]] + messages[-6:]
|
||||
formatted = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
inputs = tokenizer(formatted, return_tensors="pt").to(model.device)
|
||||
|
||||
streamer = TextIteratorStreamer(
|
||||
tokenizer, skip_prompt=True, skip_special_tokens=True
|
||||
)
|
||||
|
||||
gen_kwargs = {
|
||||
**inputs,
|
||||
"streamer": streamer,
|
||||
"max_new_tokens": 256,
|
||||
"do_sample": True,
|
||||
"temperature": 0.8,
|
||||
"top_p": 0.9,
|
||||
"repetition_penalty": 1.2,
|
||||
"no_repeat_ngram_size": 3,
|
||||
"pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id,
|
||||
"eos_token_id": eos_ids,
|
||||
}
|
||||
|
||||
print("ione: ", end="", flush=True)
|
||||
|
||||
thread = threading.Thread(
|
||||
target=lambda: torch.no_grad()(lambda: model.generate(**gen_kwargs))()
|
||||
)
|
||||
thread.start()
|
||||
|
||||
parts = []
|
||||
for chunk in streamer:
|
||||
parts.append(chunk)
|
||||
print(chunk, end="", flush=True)
|
||||
thread.join()
|
||||
|
||||
print("\n")
|
||||
return "".join(parts).strip()
|
||||
|
||||
|
||||
def main():
|
||||
model, tokenizer, eos_ids = load_model()
|
||||
|
||||
print("=" * 50)
|
||||
print(" Chat with Ione")
|
||||
print(" 'quit' to exit | 'clear' to reset")
|
||||
print("=" * 50)
|
||||
print()
|
||||
|
||||
messages = [{"role": "system", "content": SYSTEM_PROMPT}]
|
||||
|
||||
while True:
|
||||
try:
|
||||
user_input = input("you: ").strip()
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
print("\nbye!")
|
||||
break
|
||||
|
||||
if not user_input:
|
||||
continue
|
||||
if user_input.lower() in ("quit", "exit"):
|
||||
print("bye!")
|
||||
break
|
||||
if user_input.lower() == "clear":
|
||||
messages = [{"role": "system", "content": SYSTEM_PROMPT}]
|
||||
print("--- cleared ---\n")
|
||||
continue
|
||||
|
||||
messages.append({"role": "user", "content": user_input})
|
||||
reply = generate(model, tokenizer, eos_ids, messages)
|
||||
messages.append({"role": "assistant", "content": reply})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user