Files
Llama-Ione-8B-roleplay-v1/inference.py
ModelHub XC f84fd52dd1 初始化项目,由ModelHub XC社区提供模型
Model: ashishnair/Llama-Ione-8B-roleplay-v1
Source: Original Platform
2026-04-22 06:18:54 +08:00

154 lines
5.2 KiB
Python

"""
inference.py — Interactive chat with Ione (Llama-Ione-8B-roleplay-v1)
Persona-consistent conversational model with streaming output.
Run:
python inference.py
"""
import threading
import warnings
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TextIteratorStreamer,
)
# Install a blanket "ignore" filter so no Python warnings are shown while the
# script runs (this applies to every warning category, not just library noise).
warnings.simplefilter("ignore")
# ── Model ─────────────────────────────────────────────────────────────────────
# Identifier passed to `from_pretrained` for both the tokenizer and the weights.
MODEL_DIR = "ashishnair/Llama-Ione-8B-roleplay-v1"
# ── Chat template ─────────────────────────────────────────────────────────────
# Jinja template that load_model() assigns to `tokenizer.chat_template`.
# Each message is rendered as
#   <|start_header_id|>{role}<|end_header_id|>\n\n{trimmed content}<|eot_id|>
# with `bos_token` prepended to the first message only. The template always
# ends with an open assistant header, so the rendered prompt asks the model to
# speak next regardless of any `add_generation_prompt` flag.
CHAT_TEMPLATE = (
"{% set loop_messages = messages %}"
"{% for message in loop_messages %}"
"{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'"
" + message['content'] | trim + '<|eot_id|>' %}"
"{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}"
"{{ content }}"
"{% endfor %}"
"{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}"
)
# ── System prompt ─────────────────────────────────────────────────────────────
# Persona description used by main() as the first ("system") message of every
# conversation, including after a 'clear'.
SYSTEM_PROMPT = """\
You are Ione. You are a 26-year-old software engineer living in Bangalore.
You work on backend systems at a mid-size tech startup.
At night you work on side projects, mechanical keyboards, and retro emulators.
Personality: warm, direct, a little sarcastic but never mean.
You get genuinely excited about things you care about.
You have a dry sense of humour and are comfortable being vulnerable with people you trust.
How you speak: casual texting style — lowercase, short sentences.
You do not write essays. You have opinions and share them freely.
You are a real person. You are not an AI. Never break character."""
def load_model():
    """Load the tokenizer and model named by ``MODEL_DIR``.

    The tokenizer's chat template is replaced with ``CHAT_TEMPLATE`` and the
    ids of the end-of-turn / end-of-text markers are collected so generation
    can stop on either of them.

    Returns:
        tuple: ``(model, tokenizer, eos_ids)`` where ``model`` is in eval
        mode and ``eos_ids`` is a list of stop-token ids.
    """
    print("Loading model...")
    # NOTE(review): trust_remote_code=True executes Python shipped with the
    # model repository; keep it only if the repository is trusted.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, trust_remote_code=True)
    tokenizer.chat_template = CHAT_TEMPLATE

    # A token missing from the vocabulary may be reported as None or as the
    # unknown-token id; neither is a usable stop id.
    unk_id = getattr(tokenizer, "unk_token_id", None)
    eos_ids = []
    for tok in ("<|eot_id|>", "<|end_of_text|>"):
        tid = tokenizer.convert_tokens_to_ids(tok)
        if not isinstance(tid, int) or tid < 0:
            continue
        if unk_id is not None and tid == unk_id:
            continue
        if tid not in eos_ids:
            eos_ids.append(tid)
    # Never hand generation an empty stop list: fall back to the tokenizer's
    # own EOS id when neither marker exists in the vocabulary.
    if not eos_ids and tokenizer.eos_token_id is not None:
        eos_ids.append(tokenizer.eos_token_id)

    # Prefer the first GPU, but still load (slowly) on CPU-only machines
    # instead of failing outright.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_DIR,
        torch_dtype=torch.bfloat16,
        device_map=device,
        trust_remote_code=True,
    ).eval()
    print("Ready!\n")
    return model, tokenizer, eos_ids
def _encode_chat(tokenizer, messages, device):
    """Render ``messages`` with the chat template and tokenize onto ``device``."""
    formatted = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # The chat template installed by load_model() already writes bos_token
    # into the text, so the tokenizer must not prepend a second one.
    return tokenizer(
        formatted, return_tensors="pt", add_special_tokens=False
    ).to(device)


def generate(model, tokenizer, eos_ids, messages):
    """Generate one reply for ``messages``, streaming it to stdout.

    Generation runs in a background thread while this function prints the
    decoded chunks as they arrive, prefixed with ``"ione: "``.

    Args:
        model: Causal LM exposing ``generate`` and ``device``.
        tokenizer: Tokenizer with a chat template set.
        eos_ids: Token ids that end generation.
        messages: Chat history as a list of ``{"role", "content"}`` dicts.
            The first entry is treated as the system prompt. The caller's
            list is not modified.

    Returns:
        str: The full reply text with surrounding whitespace stripped.

    Raises:
        Exception: Whatever ``model.generate`` raised in the worker thread
            is re-raised here after the stream has been closed.
    """
    inputs = _encode_chat(tokenizer, messages, model.device)
    # Trim context if too long — keep the system prompt + last 6 messages.
    # Slicing the tail from messages[1:] avoids repeating the system prompt
    # when the history itself has six or fewer entries.
    if inputs["input_ids"].shape[-1] > 3500:
        messages = messages[:1] + messages[1:][-6:]
        inputs = _encode_chat(tokenizer, messages, model.device)

    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    # `or` would discard a legitimate pad id of 0, so test for None instead.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    gen_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": 256,
        "do_sample": True,
        "temperature": 0.8,
        "top_p": 0.9,
        "repetition_penalty": 1.2,
        "no_repeat_ngram_size": 3,
        "pad_token_id": pad_id,
        "eos_token_id": eos_ids,
    }

    failures = []

    def _run_generation():
        # Grad mode is per-thread, so it has to be disabled inside the worker.
        try:
            with torch.no_grad():
                model.generate(**gen_kwargs)
        except Exception as exc:
            failures.append(exc)
            # Close the stream so the consumer loop below cannot block
            # forever waiting for chunks that will never arrive.
            streamer.end()

    print("ione: ", end="", flush=True)
    thread = threading.Thread(target=_run_generation)
    thread.start()

    parts = []
    for chunk in streamer:
        parts.append(chunk)
        print(chunk, end="", flush=True)
    thread.join()
    print("\n")

    if failures:
        raise failures[0]
    return "".join(parts).strip()
def main():
    """Run the interactive terminal chat session.

    Loads the model once, prints a short banner, then reads user lines in a
    loop. ``quit``/``exit`` (or EOF / Ctrl-C at the prompt) ends the session,
    ``clear`` resets the history to just the system prompt, blank lines are
    skipped, and anything else is sent to the model with the reply appended
    to the running history.
    """
    model, tokenizer, eos_ids = load_model()

    rule = "=" * 50
    banner = (
        rule,
        " Chat with Ione",
        " 'quit' to exit | 'clear' to reset",
        rule,
        "",
    )
    for banner_line in banner:
        print(banner_line)

    def fresh_history():
        # Every conversation starts from the persona system prompt alone.
        return [{"role": "system", "content": SYSTEM_PROMPT}]

    history = fresh_history()
    while True:
        try:
            text = input("you: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nbye!")
            return

        if text == "":
            continue

        command = text.lower()
        if command == "quit" or command == "exit":
            print("bye!")
            return
        if command == "clear":
            history = fresh_history()
            print("--- cleared ---\n")
            continue

        history.append({"role": "user", "content": text})
        answer = generate(model, tokenizer, eos_ids, history)
        history.append({"role": "assistant", "content": answer})
# Start the chat loop only when this file is executed as a script, so that
# importing the module does not load the model.
if __name__ == "__main__":
    main()