from typing import Any

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)


class EndpointHandler:
    """Inference-endpoint handler serving a causal LM through a
    ``text-generation`` pipeline, with the model loaded in 8-bit
    (bitsandbytes) to cut GPU memory use.
    """

    def __init__(self, path: str = ""):
        """Load tokenizer, quantized model, and generation pipeline.

        Args:
            path: Local model directory or hub repo id, forwarded to
                ``from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            # 8-bit weight quantization via bitsandbytes.
            quantization_config=BitsAndBytesConfig(load_in_8bit=True),
            # Let accelerate place the weights on available devices.
            device_map="auto",
        )
        self.pipe = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
        )

    def __call__(self, data: dict[str, Any]) -> list[dict[str, Any]]:
        """Handle one inference request.

        Args:
            data: Request payload. ``data["inputs"]`` is the prompt; if
                the key is absent the whole payload is used as the prompt
                (standard HF endpoint contract). ``data["parameters"]``
                optionally carries generation keyword overrides.

        Returns:
            The pipeline's output, a list of ``{"generated_text": ...}``
            dicts.
        """
        inputs = data.pop("inputs", data)
        # Popping AFTER the fallback above deliberately strips
        # "parameters" out of the payload when it doubles as the prompt.
        parameters = data.pop("parameters", {})
        # Defaults first, caller overrides second. Unlike a fixed
        # parameters.get(...) list, this also forwards any extra
        # generation kwargs (top_p, num_beams, stop sequences, ...)
        # instead of silently dropping them; behavior for the five
        # original keys is unchanged.
        gen_kwargs: dict[str, Any] = {
            "max_new_tokens": 256,
            "temperature": 0.8,
            "repetition_penalty": 1.3,
            "no_repeat_ngram_size": 4,
            "do_sample": True,
            **parameters,
        }
        return self.pipe(inputs, **gen_kwargs)