37 lines
1.2 KiB
Python
37 lines
1.2 KiB
Python
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
|
|
import torch
|
|
|
|
|
|
class EndpointHandler:
    """Serve a causal language model as a chat-style text-generation endpoint.

    The tokenizer, model and ``text-generation`` pipeline are created once in
    ``__init__``. Each call renders the request's chat messages with the
    tokenizer's chat template, runs the pipeline on the resulting prompt and
    returns only the newly generated text.
    """

    # Generation settings applied when the request's ``parameters`` object
    # does not supply (or supplies ``None`` for) the corresponding key.
    _GENERATION_DEFAULTS = {
        "max_new_tokens": 512,
        "do_sample": True,
        "temperature": 0.6,
        "top_p": 0.95,
    }

    def __init__(self, path=""):
        """Load the tokenizer, the fp16 model and the generation pipeline.

        Args:
            path: Model directory or identifier forwarded to
                ``from_pretrained`` for both the tokenizer and the model.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.float16,
            # Let the library decide device placement for the weights and
            # keep host-memory use low while loading.
            device_map="auto",
            low_cpu_mem_usage=True,
        )
        self.pipeline = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
        )

    def __call__(self, data):
        """Generate a completion for one request payload.

        Args:
            data: Request dict. The conversation is read from ``"inputs"``,
                falling back to ``"messages"`` and then to an empty list. It
                is either a list of chat-message dicts or a plain string,
                which is treated as a single user message. The optional
                ``"parameters"`` dict may override ``max_new_tokens``,
                ``do_sample``, ``temperature`` and ``top_p``; other keys are
                ignored.

        Returns:
            ``{"generated_text": <str>}`` holding the first generated
            sequence, without the prompt.
        """
        messages = data.get("inputs", data.get("messages", []))
        if isinstance(messages, str):
            # Many clients send a bare prompt string; the chat template needs
            # a list of role/content dicts, so wrap it as one user turn.
            messages = [{"role": "user", "content": messages}]

        # ``"parameters": null`` arrives as None; treat it like "not given".
        parameters = data.get("parameters") or {}
        generation_kwargs = {}
        for name, default in self._GENERATION_DEFAULTS.items():
            value = parameters.get(name)
            generation_kwargs[name] = default if value is None else value

        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        result = self.pipeline(
            prompt,
            return_full_text=False,  # return the completion only, not the prompt
            **generation_kwargs,
        )
        return {"generated_text": result[0]["generated_text"]}
|