# Custom handler.py for a Hugging Face Inference Endpoint (text-generation).
# Third-party dependencies: torch for dtype selection, transformers for the
# model, tokenizer, and generation pipeline.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
class EndpointHandler():
    """Custom inference handler for a Hugging Face Inference Endpoint.

    Loads a causal language model and exposes it through a text-generation
    pipeline; ``__call__`` receives the parsed JSON payload POSTed to the
    endpoint and returns the generated text.
    """

    def __init__(self, path=""):
        """Load the tokenizer and model from *path* and build the pipeline.

        Args:
            path: Model directory / repo id injected by the endpoint runtime.
        """
        # 1. Load the tokenizer.
        tokenizer = AutoTokenizer.from_pretrained(path)

        # 2. Load the model using accelerate, letting it handle device
        #    placement; fp16 halves VRAM (strongly recommended for a 7B model).
        model = AutoModelForCausalLM.from_pretrained(
            path,
            device_map="auto",
            torch_dtype=torch.float16,
        )

        # 3. Create the pipeline WITHOUT the 'device' argument — passing one
        #    would conflict with the accelerate device map above.
        self.pipeline = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
        )

    def __call__(self, data):
        """Handle one inference request from the client (e.g. a dashboard).

        Args:
            data: Parsed JSON payload, expected shape
                ``{"inputs": ..., "parameters": {...}}``. If ``"inputs"`` is
                absent, the remaining payload itself is used as the prompt.

        Returns:
            The pipeline's prediction (list of generated sequences).
        """
        inputs = data.pop("inputs", data)
        # Guard against an explicit `"parameters": null` in the payload,
        # which would otherwise be splatted as **None and raise TypeError.
        parameters = data.pop("parameters", {}) or {}

        # Generate prediction.
        prediction = self.pipeline(inputs, **parameters)
        return prediction