# axis-ai/handler.py (Python, 75 lines, 2.5 KiB)

import logging
from typing import Any, Dict, List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Default "Axis" persona, inserted as the system message whenever a request
# does not already contain one. Adjacent literals are joined at compile time
# into a single line of text (no embedded newlines).
SYSTEM_PROMPT = (
    "You are Axis, a private personal AI assistant. "
    "You are direct, efficient, and no-nonsense. "
    "You handle emails, manage calendars, remember everything users tell you, "
    "search the web, generate images, and answer questions. "
    "Your responses are concise and helpful. "
    "You never pretend to be human. "
    "Privacy is your core value."
)
class EndpointHandler:
    """Serve a causal language model as the "Axis" chat assistant.

    Follows the custom-handler interface of Hugging Face Inference Endpoints:
    ``__init__(path)`` loads the artifacts once, and ``__call__(data)`` answers
    one request.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self, path: str = "") -> None:
        """Load the tokenizer, model and text-generation pipeline.

        Args:
            path: Model location passed to ``from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            path,
            trust_remote_code=True,
            use_fast=True,
        )
        # Give the tokenizer a pad token (EOS) when it does not define one.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        self.model.eval()  # inference only
        # NOTE(review): the model is already placed via device_map above, so
        # this device_map is probably redundant for a pre-loaded model — confirm.
        self.pipeline = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            device_map="auto",
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate one assistant reply for a request payload.

        Args:
            data: Request body.
                ``data["inputs"]`` is either a prompt (converted with ``str``)
                or a list of chat messages, i.e. dicts with ``"role"`` and
                ``"content"``.
                ``data["parameters"]`` optionally overrides ``max_new_tokens``
                (512), ``temperature`` (0.7), ``top_p`` (0.9),
                ``repetition_penalty`` (1.1) and ``do_sample`` (True).
                A ``temperature`` <= 0 selects greedy decoding.

        Returns:
            ``[{"generated_text": reply}]``. ``reply`` contains only the newly
            generated text (the prompt is not echoed), stripped of surrounding
            whitespace.
        """
        messages = self._build_messages(data.get("inputs"))
        prompt = self._render_prompt(messages)
        output = self.pipeline(prompt, **self._generation_kwargs(data.get("parameters")))
        return [{"generated_text": output[0]["generated_text"].strip()}]

    @staticmethod
    def _build_messages(inputs: Any) -> List[Dict[str, Any]]:
        """Normalize request ``inputs`` into a chat message list."""
        if inputs is None:
            inputs = ""  # a JSON null must not become the literal prompt "None"
        if isinstance(inputs, list):
            messages = inputs
        else:
            messages = [{"role": "user", "content": str(inputs)}]
        # Inject the default persona only when the caller supplied no system turn.
        if not any(m.get("role") == "system" for m in messages):
            messages = [{"role": "system", "content": SYSTEM_PROMPT}] + messages
        return messages

    def _render_prompt(self, messages: List[Dict[str, Any]]) -> str:
        """Render ``messages`` with the tokenizer's chat template, or a plain fallback."""
        try:
            return self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
        except Exception:  # broad on purpose: template errors vary by tokenizer
            # The tokenizer may lack a chat template, or its template may reject
            # this message sequence. Log the failure rather than hide it.
            self._logger.warning(
                "apply_chat_template failed; using plain-text fallback prompt",
                exc_info=True,
            )
            return self._fallback_prompt(messages)

    @staticmethod
    def _fallback_prompt(messages: List[Dict[str, Any]]) -> str:
        """Render ``messages`` as a plain ``User:``/``Assistant:`` transcript.

        System turns are emitted verbatim and followed by a blank line. Other
        turns are written as ``"<Role>: <content>"``, one per line. The prompt
        ends with ``"Assistant:"`` so the model continues as the assistant.
        """
        parts = []
        for message in messages:
            role = str(message.get("role") or "user")
            content = message.get("content")
            text = "" if content is None else str(content)
            if role == "system":
                parts.append(f"{text}\n\n")
            else:
                parts.append(f"{role.capitalize()}: {text}\n")
        parts.append("Assistant:")
        return "".join(parts)

    def _generation_kwargs(self, parameters: Any) -> Dict[str, Any]:
        """Build keyword arguments for the pipeline call from request ``parameters``."""
        params = parameters or {}  # tolerate a missing or null "parameters" field

        def pick(key: str, default: Any) -> Any:
            # A null value means "use the default", not "pass None through".
            value = params.get(key)
            return default if value is None else value

        temperature = pick("temperature", 0.7)
        # Sampling with temperature <= 0 is rejected by transformers, so treat it
        # as a request for greedy decoding.
        do_sample = bool(pick("do_sample", True)) and temperature > 0
        kwargs: Dict[str, Any] = {
            "max_new_tokens": pick("max_new_tokens", 512),
            "repetition_penalty": pick("repetition_penalty", 1.1),
            "do_sample": do_sample,
            "return_full_text": False,
            "pad_token_id": self.tokenizer.eos_token_id,
        }
        if do_sample:
            # Sampling-only knobs; omitted for greedy decoding.
            kwargs["temperature"] = temperature
            kwargs["top_p"] = pick("top_p", 0.9)
        return kwargs