# Source file: 28 lines, 812 B (Python)
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
|
import torch
|
|
|
|
class EndpointHandler:
    """Serve a causal language model through a ``transformers``
    text-generation pipeline.

    The handler is constructed once with the model location and is then
    called with request payloads of the form
    ``{"inputs": ..., "parameters": {...}}``.
    """

    # Generation length used when a request does not supply a truthy
    # ``max_new_tokens`` value.
    DEFAULT_MAX_NEW_TOKENS = 2048

    def __init__(self, path=""):
        """Load the tokenizer and model from *path* and build the pipeline.

        Args:
            path: Model location forwarded to ``from_pretrained`` for both
                the tokenizer and the model.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # NOTE(review): float16 weights assume a GPU host; on a CPU-only
        # deployment half precision may be unsupported or slow -- confirm
        # the target hardware.
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.float16,
            # Let transformers decide where the weights are placed.
            device_map="auto",
        )
        # Device placement was already handled by ``device_map`` above, so no
        # ``device`` argument is given to the pipeline.
        self.pipeline = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
        )

    def __call__(self, data):
        """Run text generation for one request payload.

        Args:
            data: Request mapping. ``"inputs"`` (default ``""``) is passed to
                the pipeline as the prompt. ``"parameters"`` (optional, may be
                ``None``) is a mapping of keyword arguments forwarded to the
                pipeline.

        Returns:
            Whatever the text-generation pipeline returns for *inputs*
            (presumably a list of dicts with ``"generated_text"`` -- see the
            transformers pipeline docs).
        """
        inputs = data.get("inputs", "")
        # Copy so that adding the default below never mutates the caller's
        # payload. An explicit ``"parameters": null`` is treated like a
        # missing key instead of crashing on ``None.get``.
        parameters = dict(data.get("parameters") or {})

        # A missing, None, or 0 ``max_new_tokens`` falls back to the default.
        if not parameters.get("max_new_tokens"):
            parameters["max_new_tokens"] = self.DEFAULT_MAX_NEW_TOKENS

        return self.pipeline(inputs, **parameters)
|