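"""Custom text-generation inference handler.

Loads a causal language model with 8-bit (bitsandbytes) quantization and serves
it through a transformers text-generation pipeline. The __init__/__call__
interface matches the convention used by custom inference-endpoint handlers;
that reading is inferred from the code itself, not from accompanying docs.
"""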
from typing import Any

import torch

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load the tokenizer and the causal LM from the model path handed to
        # the handler, quantizing weights to 8-bit via bitsandbytes and letting
        # device_map="auto" spread layers across the available devices.
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            quantization_config=BitsAndBytesConfig(load_in_8bit=True),
            device_map="auto",
        )
        # Reuse the loaded model and tokenizer in a text-generation pipeline.
        self.pipe = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
        )

    def __call__(self, data: dict[str, Any]) -> list[dict[str, Any]]:
        # The request payload carries the prompt under "inputs" and optional
        # generation overrides under "parameters".
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {})

        # Fall back to these defaults for any parameter the caller omits.
        gen_kwargs = {
            "max_new_tokens": parameters.get("max_new_tokens", 256),
            "temperature": parameters.get("temperature", 0.8),
            "repetition_penalty": parameters.get("repetition_penalty", 1.3),
            "no_repeat_ngram_size": parameters.get("no_repeat_ngram_size", 4),
            "do_sample": parameters.get("do_sample", True),
        }

        # The pipeline returns a list of {"generated_text": ...} dicts.
        outputs = self.pipe(inputs, **gen_kwargs)
        return outputs
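
# Minimal local smoke test -- a sketch, not part of the handler contract.
# The model directory "./model", the prompt, and the parameter values below are
# illustrative assumptions; in production the serving runtime constructs
# EndpointHandler with the deployed model path and passes each request payload.
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")
    result = handler(
        {
            "inputs": "Write a short poem about the sea.",
            "parameters": {"max_new_tokens": 64, "temperature": 0.7},
        }
    )
    print(result)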