Files
mistral-7b-a2ui/handler.py

28 lines
812 B
Python
Raw Normal View History

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch
class EndpointHandler:
    """Custom inference handler wrapping a causal-LM text-generation pipeline.

    The model and tokenizer are loaded once in ``__init__``; each request is
    served by ``__call__``.
    """

    # Generation length used when the request does not supply a (truthy)
    # ``max_new_tokens``. Class-level so a subclass can override it.
    DEFAULT_MAX_NEW_TOKENS = 2048

    def __init__(self, path=""):
        """Load the tokenizer and model from *path* and build the pipeline.

        Args:
            path: Directory or hub id passed to ``from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.float16,
            # Let transformers decide device placement of the weights; the
            # pipeline below therefore must not be given an explicit device.
            device_map="auto",
        )
        self.pipeline = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
        )

    def __call__(self, data):
        """Run text generation for one request payload.

        Args:
            data: Request dict. ``data["inputs"]`` is the prompt (defaults to
                ``""``); ``data["parameters"]`` is an optional dict of keyword
                arguments forwarded to the pipeline (``None`` is treated as
                empty).

        Returns:
            Whatever the text-generation pipeline returns for ``inputs``.
        """
        inputs = data.get("inputs", "")
        # Copy so the caller's request dict is never mutated, and tolerate an
        # explicit ``"parameters": null`` in the JSON payload.
        parameters = dict(data.get("parameters") or {})
        # A missing, None or 0 max_new_tokens falls back to the default.
        if not parameters.get("max_new_tokens"):
            parameters["max_new_tokens"] = self.DEFAULT_MAX_NEW_TOKENS
        return self.pipeline(inputs, **parameters)