Files
ModelHub XC e84377754d 初始化项目,由ModelHub XC社区提供模型
Model: round-bird/georgia-sports-llama3-sft
Source: Original Platform
2026-06-16 07:46:16 +08:00

41 lines
1.2 KiB
Python

from typing import Any
import torch
from transformers import AutoModelForCausalLM, LlamaTokenizerFast, pipeline
class EndpointHandler:
    """Inference handler serving the causal LM stored at ``path`` for text generation.

    The tokenizer, model and a ``transformers`` text-generation pipeline are built
    once in ``__init__``; each call accepts either a raw prompt string or a list of
    chat messages rendered through the tokenizer's chat template.
    """

    def __init__(self, path: str = ""):
        """Load the tokenizer, model and generation pipeline.

        Args:
            path: Model directory or hub id passed to ``from_pretrained``.
        """
        self.tokenizer = LlamaTokenizerFast.from_pretrained(path)
        # Weights are loaded in bfloat16 and placed with device_map="auto" so
        # transformers decides device placement.
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        self.pipe = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
        )

    def _build_prompt(self, inputs: Any) -> Any:
        """Return the prompt handed to the pipeline for the request's ``inputs``.

        A list is treated as chat messages and rendered (as text) with the
        tokenizer's chat template plus a generation prompt; any other value is
        passed through unchanged.
        """
        if not isinstance(inputs, list):
            return inputs
        prompt = self.tokenizer.apply_chat_template(
            inputs, tokenize=False, add_generation_prompt=True
        )
        # The pipeline re-tokenizes string prompts with special tokens enabled.
        # If the chat template already rendered the BOS token and the tokenizer
        # will add one too, the model would see it twice; drop the rendered copy.
        bos = getattr(self.tokenizer, "bos_token", None)
        adds_bos = getattr(self.tokenizer, "add_bos_token", False)
        if bos and adds_bos and prompt.startswith(bos):
            prompt = prompt[len(bos):]
        return prompt

    @staticmethod
    def _generation_kwargs(parameters: dict[str, Any]) -> dict[str, Any]:
        """Pick the supported generation settings from ``parameters``, with defaults."""
        do_sample = parameters.get("do_sample", True)
        kwargs: dict[str, Any] = {
            "max_new_tokens": parameters.get("max_new_tokens", 256),
            "do_sample": do_sample,
        }
        # temperature/top_p only influence sampling; passing them alongside
        # greedy decoding only produces transformers warnings.
        if do_sample:
            kwargs["temperature"] = parameters.get("temperature", 0.7)
            kwargs["top_p"] = parameters.get("top_p", 0.9)
        return kwargs

    def __call__(self, data: dict[str, Any]) -> list[dict[str, Any]]:
        """Generate a completion for one request payload.

        Args:
            data: Request payload. ``"inputs"`` is a prompt string or a list of
                chat messages for ``apply_chat_template`` (default ``""``).
                ``"parameters"`` is an optional mapping (absent or null means
                defaults) that may set ``max_new_tokens`` (256), ``temperature``
                (0.7), ``top_p`` (0.9) and ``do_sample`` (True); other keys are
                ignored.

        Returns:
            The pipeline output, produced with ``return_full_text=False`` so the
            prompt is not echoed back.
        """
        inputs = data.get("inputs", "")
        # ``parameters`` may be missing or explicitly null in a JSON payload.
        parameters = data.get("parameters") or {}
        prompt = self._build_prompt(inputs)
        return self.pipe(
            prompt,
            return_full_text=False,
            **self._generation_kwargs(parameters),
        )