75 lines
2.5 KiB
Python
75 lines
2.5 KiB
Python
from typing import Any, Dict, List
|
|
import torch
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
|
|
|
|
|
# Default system message; EndpointHandler prepends it to any conversation
# that does not already contain a message with role "system".
SYSTEM_PROMPT = (
    "You are Axis, a private personal AI assistant. "
    "You are direct, efficient, and no-nonsense. "
    "You handle emails, manage calendars, remember everything users tell you, "
    "search the web, generate images, and answer questions. "
    "Your responses are concise and helpful. "
    "You never pretend to be human. "
    "Privacy is your core value."
)
|
|
|
|
|
|
class EndpointHandler:
    """Inference-endpoint handler that serves a causal LM as a chat assistant.

    The handler is constructed once with the model location and then called
    once per request with the decoded request payload.
    """

    def __init__(self, path=""):
        """Load the tokenizer and model from *path* and build the pipeline.

        Args:
            path: Model directory or identifier handed to ``from_pretrained``.
        """
        # NOTE(security): trust_remote_code=True allows Python code shipped
        # with the checkpoint to run at load time; only point `path` at a
        # repository you trust.
        self.tokenizer = AutoTokenizer.from_pretrained(
            path,
            trust_remote_code=True,
            use_fast=True,
        )

        # Some checkpoints ship without a pad token; reuse EOS so one is
        # always defined.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        self.model.eval()

        # The model instance was already dispatched by device_map="auto"
        # above, so no device_map is passed again here.
        self.pipeline = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
        )

    @staticmethod
    def _fallback_prompt(messages) -> str:
        """Render *messages* as a plain ``User:/Assistant:`` transcript.

        System messages form a header separated from the turns by a blank
        line; the transcript ends with an open ``Assistant:`` cue.
        """
        system_parts = []
        turns = []
        for message in messages:
            if isinstance(message, dict):
                role = str(message.get("role", "user"))
                content = str(message.get("content", ""))
            else:
                # Best effort: treat a bare item as a user utterance.
                role, content = "user", str(message)
            if role == "system":
                system_parts.append(content)
            else:
                turns.append(f"{role.capitalize()}: {content}")
        turns.append("Assistant:")
        transcript = "\n".join(turns)
        if system_parts:
            return "\n".join(system_parts) + "\n\n" + transcript
        return transcript

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate one assistant reply for a request payload.

        Args:
            data: Request payload. ``data["inputs"]`` is either a list of
                ``{"role": ..., "content": ...}`` messages or any other value,
                which is stringified and sent as a single user message.
                ``data["parameters"]`` may carry ``max_new_tokens`` (512),
                ``temperature`` (0.7), ``top_p`` (0.9) and
                ``repetition_penalty`` (1.1); defaults in parentheses.

        Returns:
            A one-element list ``[{"generated_text": <reply>}]`` with
            surrounding whitespace stripped from the reply.
        """
        inputs = data.get("inputs", "")
        # A JSON `"parameters": null` arrives as None; treat it as "no overrides".
        parameters = data.get("parameters") or {}

        if isinstance(inputs, list):
            messages = inputs
        else:
            messages = [{"role": "user", "content": str(inputs)}]

        # Only inject the default system prompt when the caller sent none.
        if not any(m.get("role") == "system" for m in messages):
            messages = [{"role": "system", "content": SYSTEM_PROMPT}] + messages

        try:
            prompt = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
        except Exception:
            # Deliberate best effort (e.g. tokenizer without a chat template):
            # record why, then build a plain transcript from the same messages
            # so multi-turn input and caller-supplied system prompts survive.
            import logging

            logging.getLogger(__name__).warning(
                "apply_chat_template failed; using plain-text fallback prompt",
                exc_info=True,
            )
            prompt = self._fallback_prompt(messages)

        temperature = parameters.get("temperature", 0.7)
        # Sampling requires a strictly positive temperature; a non-positive
        # value is interpreted as a request for greedy decoding.
        sample = temperature is None or temperature > 0

        generation_kwargs = {
            "max_new_tokens": parameters.get("max_new_tokens", 512),
            "repetition_penalty": parameters.get("repetition_penalty", 1.1),
            "do_sample": sample,
            "return_full_text": False,
            "pad_token_id": self.tokenizer.eos_token_id,
        }
        if sample:
            # Sampling-only knobs are omitted for greedy decoding.
            generation_kwargs["temperature"] = temperature
            generation_kwargs["top_p"] = parameters.get("top_p", 0.9)

        output = self.pipeline(prompt, **generation_kwargs)
        response_text = output[0]["generated_text"].strip()
        return [{"generated_text": response_text}]