61 lines
2.5 KiB
Python
61 lines
2.5 KiB
Python
|
|
import torch
|
||
|
|
from typing import Dict, List, Any
|
||
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
|
||
|
|
|
||
|
|
# Choose the half-precision dtype used to load the model: bfloat16 on CUDA
# devices with compute capability 8.0 or newer, float16 otherwise. The
# availability check comes first because querying the device capability
# raises on a machine without CUDA, which would make this module unimportable.
dtype = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
    else torch.float16
)
|
||
|
|
|
||
|
|
|
||
|
|
class EndpointHandler:
    """Inference handler wrapping a chat-style ``text-generation`` pipeline.

    The tokenizer and model are loaded once at construction. Each call turns a
    request payload into a batch of chats, generates one reply per chat, and
    returns the prompt/response pairs.
    """

    def __init__(self, path=""):
        """Load the tokenizer and model from *path* and build the pipeline.

        Args:
            path: Model directory or hub identifier given to ``from_pretrained``.
        """
        tokenizer = AutoTokenizer.from_pretrained(path)
        # Pad on the left so that, in a padded batch, generation continues
        # directly after the end of every prompt.
        tokenizer.padding_side = "left"

        # Reuse the end-of-sequence token when the tokenizer defines no pad token.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(path, device_map="auto", torch_dtype=dtype)
        self.pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)

    @staticmethod
    def _as_chat_batch(inputs: Any) -> List[List[Dict[str, Any]]]:
        """Normalize request inputs to a list of chats (each a list of messages).

        Raises:
            ValueError: If *inputs* is not a string, a list of strings, or a
                list of message dictionaries.
        """
        if isinstance(inputs, str):
            return [[{"role": "user", "content": inputs}]]
        if isinstance(inputs, list):
            if all(isinstance(item, str) for item in inputs):
                return [[{"role": "user", "content": item}] for item in inputs]
            if all(isinstance(item, dict) for item in inputs):
                # A list of message dicts is a single conversation. Wrap it so
                # the pipeline always receives (and returns) a batch.
                return [inputs]
        raise ValueError("Inputs must be a string or a list of strings or dictionaries.")

    def __call__(self, data: Any) -> List[Dict[str, str]]:
        """Generate a reply for every prompt in the request payload.

        Args:
            data: Request dictionary. ``"inputs"`` holds a string, a list of
                strings (one prompt each), or a list of message dictionaries
                (one conversation). The optional ``"parameters"`` dictionary
                is forwarded to the pipeline as keyword arguments.

        Returns:
            One ``{"prompt": ..., "response": ...}`` dictionary per chat, where
            the prompt is the content of the last message that was sent.

        Raises:
            ValueError: If the inputs have an unsupported type.
        """
        # Read without popping so the caller's payload is left untouched.
        inputs = data.get("inputs", data)
        parameters = data.get("parameters") or {}

        chats = self._as_chat_batch(inputs)
        if not chats:
            return []

        # Always passing a batch of chats keeps the output shape uniform:
        # one list of generations per chat.
        predictions = self.pipeline(chats, **parameters)

        responses = []
        for chat, prediction in zip(chats, predictions):
            generated = prediction[0]["generated_text"]
            # By default the pipeline returns the whole chat with the reply
            # appended; with return_full_text=False it returns only the new
            # text as a plain string.
            reply = generated if isinstance(generated, str) else generated[-1]["content"]
            responses.append({"prompt": chat[-1]["content"], "response": reply})
        return responses
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # Manual smoke test: load the model, send two prompts, print the replies.
    import json

    from dotenv import load_dotenv

    # Pull settings from a local .env file into the environment before the
    # model is loaded.
    load_dotenv()

    endpoint = EndpointHandler(path="CNCL-Penn-State/CrPO-sft-llama-3.1-8b-instruct")

    # Shape of a request as the endpoint would receive it.
    prompts = [
        "Come up with an original and creative use for the following object: rope",
        "Finish the sentence with an original and creative ending: When I got on the school bus....",
    ]
    sampling = {"top_p": 0.95, "temperature": 0.7, "max_new_tokens": 256, "do_sample": True}
    payload = {"inputs": prompts, "parameters": sampling}

    # Run the handler as if the payload had arrived in a request, then show
    # the result as formatted JSON.
    result = endpoint(payload)
    print(json.dumps(result, indent=2))
|