Files
broken-model-fixed/handler.py

37 lines
1.2 KiB
Python
Raw Permalink Normal View History

from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
class EndpointHandler:
    """Chat-style text generation handler around a causal LM pipeline.

    Follows the ``EndpointHandler`` convention of Hugging Face custom
    inference handlers: ``__init__`` receives the model directory and
    ``__call__`` receives the decoded request payload (a dict).
    """

    # Generation defaults used when a request does not override them
    # through its "parameters" mapping.
    DEFAULT_MAX_NEW_TOKENS = 512
    DEFAULT_TEMPERATURE = 0.6
    DEFAULT_TOP_P = 0.95

    def __init__(self, path=""):
        """Load the tokenizer, the fp16 model and a text-generation pipeline.

        Args:
            path: Model directory (or Hub id) passed to ``from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.float16,
            # Let Accelerate decide weight placement across the available
            # devices (GPUs when present) instead of loading on CPU first.
            device_map="auto",
            # Keep peak CPU RAM close to a single copy of the weights.
            low_cpu_mem_usage=True,
        )
        # No ``device=`` argument: the model is already placed by device_map.
        self.pipeline = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
        )

    def __call__(self, data):
        """Generate an assistant reply for a chat conversation.

        Args:
            data: Request payload dict. The conversation is read from
                ``"inputs"`` (falling back to ``"messages"`` when ``"inputs"``
                is absent or None). It may be a list of chat messages accepted
                by the tokenizer's chat template, or a plain string, which is
                treated as a single user turn. An optional ``"parameters"``
                mapping may override ``max_new_tokens``, ``do_sample``,
                ``temperature`` and ``top_p``.

        Returns:
            ``{"generated_text": str}`` holding only the newly generated text
            (the prompt is not echoed back).
        """
        messages = data.get("inputs")
        if messages is None:
            messages = data.get("messages", [])
        if isinstance(messages, str):
            # Bare-string payloads such as {"inputs": "Hi"} are common; wrap
            # them so the chat template receives a list of messages.
            messages = [{"role": "user", "content": messages}]

        # Treat an explicit "parameters": null the same as a missing key.
        parameters = data.get("parameters") or {}
        do_sample = parameters.get("do_sample", True)
        generate_kwargs = {
            "max_new_tokens": parameters.get(
                "max_new_tokens", self.DEFAULT_MAX_NEW_TOKENS
            ),
            "do_sample": do_sample,
        }
        if do_sample:
            # Sampling knobs only matter when sampling; transformers warns if
            # they are set while do_sample is False.
            generate_kwargs["temperature"] = parameters.get(
                "temperature", self.DEFAULT_TEMPERATURE
            )
            generate_kwargs["top_p"] = parameters.get("top_p", self.DEFAULT_TOP_P)

        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        result = self.pipeline(
            prompt,
            return_full_text=False,
            # Chat templates typically render the model's special tokens
            # (e.g. BOS) themselves; letting the tokenizer add them again when
            # the pipeline re-tokenizes the string would duplicate them.
            add_special_tokens=False,
            **generate_kwargs,
        )
        return {"generated_text": result[0]["generated_text"]}