初始化项目,由ModelHub XC社区提供模型
Model: Axis-AI/axis-ai Source: Original Platform
This commit is contained in:
75
handler.py
Normal file
75
handler.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from typing import Any, Dict, List
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
||||
|
||||
|
||||
SYSTEM_PROMPT = """You are Axis, a private personal AI assistant. You are direct, efficient, and no-nonsense. You handle emails, manage calendars, remember everything users tell you, search the web, generate images, and answer questions. Your responses are concise and helpful. You never pretend to be human. Privacy is your core value."""
|
||||
|
||||
|
||||
class EndpointHandler:
|
||||
def __init__(self, path=""):
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
path,
|
||||
trust_remote_code=True,
|
||||
use_fast=True
|
||||
)
|
||||
|
||||
if self.tokenizer.pad_token is None:
|
||||
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
path,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
||||
self.model.eval()
|
||||
|
||||
self.pipeline = pipeline(
|
||||
"text-generation",
|
||||
model=self.model,
|
||||
tokenizer=self.tokenizer,
|
||||
device_map="auto"
|
||||
)
|
||||
|
||||
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
inputs = data.get("inputs", "")
|
||||
parameters = data.get("parameters", {})
|
||||
|
||||
if isinstance(inputs, list):
|
||||
messages = inputs
|
||||
else:
|
||||
messages = [{"role": "user", "content": str(inputs)}]
|
||||
|
||||
if not any(m.get("role") == "system" for m in messages):
|
||||
messages = [{"role": "system", "content": SYSTEM_PROMPT}] + messages
|
||||
|
||||
try:
|
||||
prompt = self.tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True
|
||||
)
|
||||
except Exception:
|
||||
prompt = SYSTEM_PROMPT + "\n\nUser: " + str(inputs) + "\nAssistant:"
|
||||
|
||||
max_new_tokens = parameters.get("max_new_tokens", 512)
|
||||
temperature = parameters.get("temperature", 0.7)
|
||||
top_p = parameters.get("top_p", 0.9)
|
||||
repetition_penalty = parameters.get("repetition_penalty", 1.1)
|
||||
|
||||
output = self.pipeline(
|
||||
prompt,
|
||||
max_new_tokens=max_new_tokens,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
repetition_penalty=repetition_penalty,
|
||||
do_sample=True,
|
||||
return_full_text=False,
|
||||
pad_token_id=self.tokenizer.eos_token_id
|
||||
)
|
||||
|
||||
response_text = output[0]["generated_text"].strip()
|
||||
|
||||
return [{"generated_text": response_text}]
|
||||
Reference in New Issue
Block a user