初始化项目,由ModelHub XC社区提供模型
Model: vrutkovs/Lusterka-7B Source: Original Platform
This commit is contained in:
39
handler.py
Normal file
39
handler.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
BitsAndBytesConfig,
|
||||
pipeline,
|
||||
)
|
||||
|
||||
|
||||
class EndpointHandler:
    """Hugging Face Inference Endpoints custom handler.

    Loads a causal LM in 8-bit quantization and serves text generation
    through a `transformers` pipeline. The endpoint runtime instantiates
    this class once with the model directory, then invokes it per request.
    """

    def __init__(self, path: str = ""):
        """Load tokenizer, 8-bit quantized model, and generation pipeline.

        Args:
            path: Local directory (or repo id) holding the model weights
                and tokenizer files; the endpoint runtime supplies it.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            # 8-bit weights cut GPU memory roughly in half vs fp16.
            quantization_config=BitsAndBytesConfig(load_in_8bit=True),
            # Let accelerate place layers on available devices automatically.
            device_map="auto",
        )
        self.pipe = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
        )

    def __call__(self, data: dict[str, Any]) -> list[dict[str, Any]]:
        """Handle one inference request.

        Args:
            data: Request payload. Expected shape is
                ``{"inputs": <prompt>, "parameters": {...}}``; if ``"inputs"``
                is absent, the remaining payload itself is passed as the
                prompt source.

        Returns:
            The pipeline output, a list of dicts (``generated_text`` entries).
        """
        # Pop "parameters" BEFORE resolving "inputs": when the payload has no
        # "inputs" key the whole dict becomes the prompt source, and popping
        # in the original order left the "parameters" entry inside it.
        # `or {}` also guards an explicit `"parameters": null` payload.
        parameters = data.pop("parameters", {}) or {}
        inputs = data.pop("inputs", data)

        # Whitelist of generation knobs with conservative defaults; unknown
        # parameter keys are deliberately ignored rather than forwarded.
        gen_kwargs = {
            "max_new_tokens": parameters.get("max_new_tokens", 256),
            "temperature": parameters.get("temperature", 0.8),
            "repetition_penalty": parameters.get("repetition_penalty", 1.3),
            "no_repeat_ngram_size": parameters.get("no_repeat_ngram_size", 4),
            "do_sample": parameters.get("do_sample", True),
        }

        outputs = self.pipe(inputs, **gen_kwargs)
        return outputs
|
||||
Reference in New Issue
Block a user