初始化项目,由ModelHub XC社区提供模型
Model: CNCL-Penn-State/CrPO-sft-llama-3.1-8b-instruct Source: Original Platform
This commit is contained in:
61
handler.py
Normal file
61
handler.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import torch
|
||||
from typing import Dict, List, Any
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
|
||||
|
||||
# get dtype
|
||||
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16
|
||||
|
||||
|
||||
class EndpointHandler:
|
||||
def __init__(self, path=""):
|
||||
# load the model
|
||||
tokenizer = AutoTokenizer.from_pretrained(path)
|
||||
tokenizer.padding_side = "left"
|
||||
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(path, device_map="auto", torch_dtype=dtype)
|
||||
# create inference pipeline
|
||||
self.pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
|
||||
|
||||
def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
|
||||
inputs = data.pop("inputs", data)
|
||||
parameters = data.pop("parameters", None)
|
||||
|
||||
if isinstance(inputs, str):
|
||||
inputs = [{"role": "user", "content": inputs}]
|
||||
elif isinstance(inputs, list):
|
||||
if all(isinstance(i, str) for i in inputs):
|
||||
inputs = [[{"role": "user", "content": i}] for i in inputs]
|
||||
elif all(isinstance(i, dict) for i in inputs):
|
||||
# assume the list is already in the correct format
|
||||
pass
|
||||
else:
|
||||
raise ValueError("Inputs must be a string or a list of strings or dictionaries.")
|
||||
# pass inputs with all kwargs in data
|
||||
if parameters is not None:
|
||||
predictions = self.pipeline(inputs, **parameters)
|
||||
else:
|
||||
predictions = self.pipeline(inputs)
|
||||
# postprocess the prediction
|
||||
responses = [{"prompt": prediction[0]["generated_text"][0]["content"], "response": prediction[0]["generated_text"][1]["content"]} for prediction in predictions]
|
||||
return responses
|
||||
|
||||
if __name__ == "__main__":
|
||||
import json
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
# Example usage
|
||||
handler = EndpointHandler(path="CNCL-Penn-State/CrPO-sft-llama-3.1-8b-instruct")
|
||||
input_data = {
|
||||
"inputs": ["Come up with an original and creative use for the following object: rope",
|
||||
"Finish the sentence with an original and creative ending: When I got on the school bus...."],
|
||||
"parameters": {"top_p": 0.95, "temperature": 0.7, "max_new_tokens": 256, "do_sample": True}
|
||||
}
|
||||
|
||||
# Simulate receiving data from a request
|
||||
response = handler(input_data)
|
||||
|
||||
# Print the response
|
||||
print(json.dumps(response, indent=2))
|
||||
Reference in New Issue
Block a user