Files

26 lines
729 B
Python
Raw Permalink Normal View History

from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for causal-LM text generation.

    Loads a tokenizer and model from *path* at startup and serves requests
    through a ``text-generation`` pipeline.
    """

    def __init__(self, path: str = ""):
        """Load tokenizer, model, and generation pipeline from *path*.

        Parameters
        ----------
        path : str
            Local directory (or hub id) of the model repository; the endpoint
            runtime passes the model directory here.
        """
        # Prefer the fast (Rust) tokenizer.  NOTE: the original code also
        # passed ``tokenizer_class=None``, which is not a documented kwarg of
        # ``AutoTokenizer.from_pretrained`` (the selector kwarg is
        # ``tokenizer_type``) and had no effect, so it is dropped.
        self.tokenizer = AutoTokenizer.from_pretrained(path, use_fast=True)
        # Half-precision weights with automatic device placement
        # (GPU when available, per ``device_map="auto"``).
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        self.pipe = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
        )

    def __call__(self, data):
        """Handle one inference request.

        Parameters
        ----------
        data : dict
            Standard Inference Endpoints payload:
            ``{"inputs": <prompt or list of prompts>, "parameters": {...}}``.
            ``parameters`` is optional and is forwarded to the pipeline,
            allowing callers to override generation settings.

        Returns
        -------
        list
            Generated sequences as produced by the ``text-generation``
            pipeline.
        """
        # ``get`` instead of the original ``pop`` — do not mutate the
        # caller's request dict.  Fall back to the raw payload for
        # non-standard requests, as the original did.
        inputs = data.get("inputs", data)
        # Forward caller-supplied generation parameters; keep the original
        # 512-token default when none (or no override) is given.
        params = dict(data.get("parameters") or {})
        params.setdefault("max_new_tokens", 512)
        return self.pipe(inputs, **params)