"""
inference.py — Interactive chat with Ione (Llama-Ione-8B-roleplay-v1).

Persona-consistent conversational model with streaming output.

Run:
    python inference.py
"""

import os
import threading
import warnings

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)

# Ignore every warning process-wide (presumably to keep the chat output clean).
# Note this also hides deprecation and generation-config warnings.
warnings.simplefilter("ignore")

# ── Model ─────────────────────────────────────────────────────────────────────
# Hub id or local path of the checkpoint. Set IONE_MODEL_DIR to point at a
# different checkpoint without editing this file; an unset or empty value
# falls back to the published model.
MODEL_DIR = (
    os.environ.get("IONE_MODEL_DIR") or "ashishnair/Llama-Ione-8B-roleplay-v1"
)

# ── Chat template ─────────────────────────────────────────────────────────────
# Jinja chat template in the Llama 3 header/eot format. It does the following:
#   - wraps every message as
#     <|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>
#   - trims each message's content;
#   - prepends bos_token to the first message.
# The trailing assistant header cues the model to reply. It is emitted only
# when apply_chat_template(..., add_generation_prompt=True) is requested, so
# the same template can also render a finished transcript.
CHAT_TEMPLATE = (
    "{% set loop_messages = messages %}"
    "{% for message in loop_messages %}"
    "{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'"
    " + message['content'] | trim + '<|eot_id|>' %}"
    "{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}"
    "{{ content }}"
    "{% endfor %}"
    "{% if add_generation_prompt %}"
    "{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}"
    "{% endif %}"
)

# ── System prompt ─────────────────────────────────────────────────────────────
# Persona text used as the content of the system message that opens each
# conversation history.
SYSTEM_PROMPT = """\
You are Ione. You are a 26-year-old software engineer living in Bangalore.
You work on backend systems at a mid-size tech startup.
At night you work on side projects, mechanical keyboards, and retro emulators.

Personality: warm, direct, a little sarcastic but never mean.
You get genuinely excited about things you care about.
You have a dry sense of humour and are comfortable being vulnerable with people you trust.

How you speak: casual texting style — lowercase, short sentences.
You do not write essays. You have opinions and share them freely.
You are a real person. You are not an AI. Never break character."""


def load_model(model_dir=None, device_map=None):
    """Load the tokenizer and model and collect the token ids that end a reply.

    Args:
        model_dir: Hub id or local path to load from. Defaults to MODEL_DIR.
        device_map: Forwarded to ``from_pretrained``. Defaults to "cuda:0"
            when CUDA is available, otherwise "cpu" (expect CPU to be slow).

    Returns:
        A ``(model, tokenizer, eos_ids)`` tuple containing:
        - the model in eval mode;
        - the tokenizer with CHAT_TEMPLATE installed;
        - the list of stop-token ids to pass to ``generate``.
    """
    if model_dir is None:
        model_dir = MODEL_DIR
    if device_map is None:
        device_map = "cuda:0" if torch.cuda.is_available() else "cpu"

    print("Loading model...")
    # NOTE(review): trust_remote_code=True executes Python code shipped with
    # the checkpoint repo; keep it only if model_dir is a source you trust.
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    tokenizer.chat_template = CHAT_TEMPLATE

    # CHAT_TEMPLATE closes every turn with <|eot_id|>. <|end_of_text|> is
    # presumably the tokenizer's end-of-text marker. Stop on whichever of the
    # two this tokenizer actually knows.
    eos_ids = []
    for tok in ("<|eot_id|>", "<|end_of_text|>"):
        tid = tokenizer.convert_tokens_to_ids(tok)
        # An unknown token may come back as None or as the unk id; neither is
        # a real stop token.
        if (
            isinstance(tid, int)
            and tid >= 0
            and tid != tokenizer.unk_token_id
            and tid not in eos_ids
        ):
            eos_ids.append(tid)
    if not eos_ids and tokenizer.eos_token_id is not None:
        # Without any stop id, every reply would run to max_new_tokens.
        eos_ids.append(tokenizer.eos_token_id)

    model = AutoModelForCausalLM.from_pretrained(
        model_dir,
        torch_dtype=torch.bfloat16,
        device_map=device_map,
        trust_remote_code=True,
    ).eval()

    print("Ready!\n")
    return model, tokenizer, eos_ids


def _encode_chat(tokenizer, messages, device):
    """Render *messages* with the chat template and tokenize them onto *device*."""
    formatted = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # The rendered text already starts with bos_token (CHAT_TEMPLATE adds it),
    # so the tokenizer must not add its own special tokens on top. Otherwise
    # the prompt would begin with two BOS tokens.
    return tokenizer(
        formatted, return_tensors="pt", add_special_tokens=False
    ).to(device)


def _encode_trimmed_chat(tokenizer, messages, device, max_prompt_tokens, keep_last):
    """Encode *messages* keeping a leading system message plus only recent turns.

    Trimming proceeds as follows:
    - Start from the last ``keep_last`` non-system messages.
    - Drop the oldest one at a time until the prompt fits in
      ``max_prompt_tokens``.
    - Stop once only the newest message is left; it is then sent as-is,
      even if it is still too long.
    """
    has_system = bool(messages) and messages[0].get("role") == "system"
    head = messages[:1] if has_system else []
    # Slice only the non-system part. Slicing the whole list could pull the
    # system prompt in a second time when the history is short.
    body = messages[len(head):][-max(keep_last, 1):]
    while True:
        inputs = _encode_chat(tokenizer, head + body, device)
        if inputs["input_ids"].shape[-1] <= max_prompt_tokens or len(body) <= 1:
            return inputs
        body = body[1:]


def generate(
    model,
    tokenizer,
    eos_ids,
    messages,
    *,
    max_prompt_tokens=3500,
    keep_last=6,
    max_new_tokens=256,
):
    """Stream one assistant reply for *messages* to stdout and return it.

    Args:
        model: Causal LM to sample from (as returned by ``load_model``).
        tokenizer: Its tokenizer, with the chat template installed.
        eos_ids: Token ids that end generation.
        messages: Chat history as ``{"role": ..., "content": ...}`` dicts,
            normally starting with the system prompt. Not modified.
        max_prompt_tokens: If the rendered prompt is longer than this, older
            turns are dropped. A leading system message is always kept.
        keep_last: Maximum number of recent non-system messages kept when
            trimming.
        max_new_tokens: Upper bound on the reply length in tokens.

    Returns:
        The generated reply with surrounding whitespace stripped.

    Raises:
        Exception: whatever ``model.generate`` raised in the worker thread
            (e.g. CUDA out-of-memory). It is re-raised here after the stream
            has been closed.
    """
    inputs = _encode_chat(tokenizer, messages, model.device)
    if inputs["input_ids"].shape[-1] > max_prompt_tokens:
        inputs = _encode_trimmed_chat(
            tokenizer, messages, model.device, max_prompt_tokens, keep_last
        )

    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    # An explicit None check: `pad_token_id or eos_token_id` would wrongly
    # discard a legitimate pad id of 0.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    gen_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "temperature": 0.8,
        "top_p": 0.9,
        "repetition_penalty": 1.2,
        "no_repeat_ngram_size": 3,
        "pad_token_id": pad_token_id,
        "eos_token_id": eos_ids,
    }

    errors = []

    def _worker():
        # Grad mode is thread-local, so disable it inside the worker thread.
        try:
            with torch.no_grad():
                model.generate(**gen_kwargs)
        except Exception as exc:
            errors.append(exc)
            # Signal end-of-stream so the consumer loop below does not block
            # forever waiting for text that will never arrive.
            streamer.end()

    print("ione: ", end="", flush=True)

    thread = threading.Thread(target=_worker)
    thread.start()

    parts = []
    for chunk in streamer:
        parts.append(chunk)
        print(chunk, end="", flush=True)
    thread.join()

    print("\n")
    if errors:
        raise errors[0]
    return "".join(parts).strip()


def main():
    """Run the interactive chat loop until the user quits.

    Commands:
    - 'quit' / 'exit' end the session.
    - 'clear' resets the history to just the system prompt.
    - EOF or Ctrl-C at the input prompt also exits.

    If generating a reply fails, the error is reported, the unanswered user
    turn is discarded and the session continues.
    """
    model, tokenizer, eos_ids = load_model()

    print("=" * 50)
    print(" Chat with Ione")
    print(" 'quit' to exit | 'clear' to reset")
    print("=" * 50)
    print()

    def fresh_history():
        # A new list each time so a reset never shares state with old turns.
        return [{"role": "system", "content": SYSTEM_PROMPT}]

    messages = fresh_history()

    while True:
        try:
            user_input = input("you: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nbye!")
            break

        if not user_input:
            continue
        command = user_input.lower()
        if command in ("quit", "exit"):
            print("bye!")
            break
        if command == "clear":
            messages = fresh_history()
            print("--- cleared ---\n")
            continue

        messages.append({"role": "user", "content": user_input})
        try:
            reply = generate(model, tokenizer, eos_ids, messages)
        except Exception as exc:  # REPL boundary: report and keep the session alive
            # Drop the unanswered user turn so the history keeps alternating
            # user/assistant.
            messages.pop()
            print(f"[generation failed: {exc}]\n")
            continue
        messages.append({"role": "assistant", "content": reply})


# Script entry point; importing this module does not start the chat.
if __name__ == "__main__":
    main()