41 lines
1.2 KiB
Python
41 lines
1.2 KiB
Python
from typing import Any
|
|
import torch
|
|
from transformers import AutoModelForCausalLM, LlamaTokenizerFast, pipeline
|
|
|
|
|
|
class EndpointHandler:
    """Serve a causal language model through a ``text-generation`` pipeline.

    The constructor loads the tokenizer and model from ``path`` and wraps
    them in a pipeline; calling the instance runs one generation request.
    """

    def __init__(self, path: str = ""):
        """Load the tokenizer, the model and the generation pipeline.

        Args:
            path: Location passed to both ``from_pretrained`` calls.
        """
        self.tokenizer = LlamaTokenizerFast.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        self.pipe = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
        )

    def __call__(self, data: dict[str, Any]) -> list[dict[str, Any]]:
        """Generate text for a single request payload.

        Args:
            data: Request payload. ``data["inputs"]`` is either a prompt
                string or a list of chat messages; a list is rendered to a
                prompt with the tokenizer's chat template. The optional
                ``data["parameters"]`` mapping may carry ``max_new_tokens``
                (default 256), ``temperature`` (default 0.7), ``top_p``
                (default 0.9) and ``do_sample`` (default True). A missing
                or ``None`` ``parameters`` entry means "use the defaults".

        Returns:
            Whatever the pipeline returns for the prompt. The prompt itself
            is excluded from the generated text (``return_full_text=False``).
        """
        inputs = data.get("inputs", "")
        # ``or {}`` also covers an explicit ``"parameters": null`` in the
        # payload, which ``data.get("parameters", {})`` would pass through
        # as None and then crash on ``.get``.
        parameters = data.get("parameters") or {}

        if isinstance(inputs, list):
            prompt = self.tokenizer.apply_chat_template(
                inputs, tokenize=False, add_generation_prompt=True
            )
        else:
            prompt = inputs

        do_sample = parameters.get("do_sample", True)
        temperature = parameters.get("temperature", 0.7)
        # Clients commonly send temperature=0 to ask for deterministic
        # output; sampling with a non-positive temperature is rejected by
        # the generation code, so treat it as a request for greedy decoding.
        if do_sample and temperature is not None and temperature <= 0:
            do_sample = False

        generation_kwargs: dict[str, Any] = {
            "max_new_tokens": parameters.get("max_new_tokens", 256),
            "do_sample": do_sample,
            "return_full_text": False,
        }
        # Sampling-only knobs are forwarded only when sampling is enabled;
        # with greedy decoding they have no effect and only trigger warnings.
        if do_sample:
            generation_kwargs["temperature"] = temperature
            generation_kwargs["top_p"] = parameters.get("top_p", 0.9)

        return self.pipe(prompt, **generation_kwargs)
|