26 lines
992 B
Python
26 lines
992 B
Python
import torch
|
|
from typing import Dict, List, Any
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
|
|
|
|
|
class EndpointHandler:
    """Inference handler serving a Phi-3 128k chat model through a
    ``transformers`` text-generation pipeline.

    The model and tokenizer are loaded once at construction time; each request
    is then handled by calling the instance with the decoded request dict.
    """

    # Hub repo the model weights are loaded from when not overridden.
    DEFAULT_MODEL_ID = "hyperspaceai/hyperEngine_phi3_128k"
    # The tokenizer is taken from the upstream Microsoft repo, not the model repo.
    DEFAULT_TOKENIZER_ID = "microsoft/Phi-3-mini-128k-instruct"

    # Used when a request carries no "generation_args": greedy decoding
    # (do_sample=False), at most 500 new tokens, and only the newly generated
    # text in the result (return_full_text=False).
    DEFAULT_GENERATION_ARGS: Dict[str, Any] = {
        "max_new_tokens": 500,
        "return_full_text": False,
        "temperature": 0.0,
        "do_sample": False,
    }

    def __init__(
        self,
        path: str = "",
        *,
        model_id: str = DEFAULT_MODEL_ID,
        tokenizer_id: str = DEFAULT_TOKENIZER_ID,
    ) -> None:
        """Load the model and tokenizer and build the generation pipeline.

        Args:
            path: Accepted for compatibility with the handler calling
                convention but not used; weights are loaded from ``model_id``.
            model_id: Repo id (or local path) passed to
                ``AutoModelForCausalLM.from_pretrained``.
            tokenizer_id: Repo id (or local path) passed to
                ``AutoTokenizer.from_pretrained``.
        """
        # NOTE(security): trust_remote_code=True executes Python code shipped
        # in the model repo; only point model_id at repos you trust.
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",
            torch_dtype="auto",
            trust_remote_code=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
        self.pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Run generation for a single request.

        Args:
            data: Request payload. The pipeline input is read from
                ``"messages"``, falling back to ``"inputs"``. Presumably this
                is a list of ``{"role": ..., "content": ...}`` chat messages;
                TODO confirm with clients. The optional ``"generation_args"``
                dict is forwarded to the pipeline as keyword arguments. When
                present it REPLACES the defaults entirely; it is not merged.
                The payload itself is not modified.

        Returns:
            The ``"generated_text"`` field of the first pipeline result.

        Raises:
            ValueError: If neither ``"messages"`` nor ``"inputs"`` is present,
                or if ``"generation_args"`` is given but is not a dict.
        """
        # .get() rather than .pop() so the caller's dict is left untouched.
        messages = data.get("messages")
        if messages is None:
            messages = data.get("inputs")
        if messages is None:
            raise ValueError("Request must contain a 'messages' (or 'inputs') field.")

        generation_args = data.get("generation_args")
        if generation_args is None:
            generation_args = self.DEFAULT_GENERATION_ARGS
        elif not isinstance(generation_args, dict):
            raise ValueError("'generation_args' must be a JSON object (dict).")

        # ** unpacking hands the pipeline a fresh kwargs dict, so the class
        # default cannot be mutated through this call.
        output = self.pipe(messages, **generation_args)
        return output[0]["generated_text"]
|