72 lines
2.7 KiB
Python
72 lines
2.7 KiB
Python
|
|
import torch
|
||
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
||
|
|
from typing import Dict, List, Any
|
||
|
|
|
||
|
|
class EndpointHandler:
    """Custom inference handler for the PropagationShield causal language model.

    The tokenizer and a 4-bit (NF4) quantized model are loaded once in
    ``__init__``; every request is then served by ``__call__``.
    """

    def __init__(self, path=""):
        """
        Initializes the model and tokenizer.
        `path` is automatically provided by Hugging Face (it points to your repo files).
        """
        print("🚀 Initializing PropagationShield Handler...")

        self.tokenizer = AutoTokenizer.from_pretrained(path)

        # 1. Configure 4-bit quantization to prevent OOM and System RAM limits
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )

        # 2. Load the model safely
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            quantization_config=bnb_config,
            device_map="auto",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,  # Crucial to prevent the 30GB RAM crash during boot
        )
        print("✅ PropagationShield Loaded Successfully!")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run generation for a single request.

        Args:
            data: Request payload. Note that ``"inputs"`` and ``"parameters"``
                are popped from this dict (it is mutated).
                ``data["inputs"]`` is either a list of chat messages
                (``[{"role": ..., "content": ...}, ...]``), rendered with the
                tokenizer's chat template, or any other value, which is
                converted with ``str()`` and used as an already-formatted
                prompt. If ``"inputs"`` is missing, the remaining payload
                itself is used as the input.
                ``data["parameters"]`` (optional, may be null) can carry
                ``max_new_tokens`` (default 512) and ``temperature``
                (default 0.1; a value ``<= 0`` selects greedy decoding).

        Returns:
            ``[{"generated_text": text}]`` where ``text`` is only the newly
            generated continuation, decoded without special tokens and stripped
            of surrounding whitespace.
        """
        # Parse incoming data
        inputs = data.pop("inputs", data)
        # `or {}` also covers an explicit `"parameters": null` in the JSON body.
        parameters = data.pop("parameters", None) or {}

        # JSON clients sometimes send numbers as strings; coerce so the
        # comparison below and `generate` receive proper numeric types.
        max_new_tokens = parameters.get("max_new_tokens")
        max_new_tokens = 512 if max_new_tokens is None else int(max_new_tokens)
        temperature = parameters.get("temperature")
        temperature = 0.1 if temperature is None else float(temperature)

        # 3. Format the prompt
        if isinstance(inputs, list):
            # The user sent a list of messages [{"role": "system", "content": "..."}, ...]
            prompt = self.tokenizer.apply_chat_template(
                inputs, tokenize=False, add_generation_prompt=True
            )
            # The rendered template already contains the model's special
            # tokens (e.g. BOS); letting the tokenizer add them again would
            # duplicate them.
            add_special_tokens = False
        else:
            # The user sent a raw formatted string
            prompt = str(inputs)
            add_special_tokens = True

        # 4. Tokenize (keep the attention mask so `generate` need not infer it;
        # it cannot, because pad_token_id is set to the EOS id below).
        encoded = self.tokenizer(
            prompt, return_tensors="pt", add_special_tokens=add_special_tokens
        ).to(self.model.device)
        input_ids = encoded["input_ids"]

        # 5. Generate
        generation_kwargs: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "pad_token_id": self.tokenizer.eos_token_id,
        }
        if temperature > 0.0:
            generation_kwargs["do_sample"] = True
            generation_kwargs["temperature"] = temperature
        else:
            # Greedy decoding: temperature is ignored when do_sample=False, and
            # passing it anyway only triggers a transformers warning.
            generation_kwargs["do_sample"] = False

        with torch.no_grad():
            output_ids = self.model.generate(
                input_ids,
                attention_mask=encoded.get("attention_mask"),
                **generation_kwargs,
            )

        # 6. Isolate and decode only the newly generated tokens
        generated_ids = output_ids[0][input_ids.shape[-1]:]
        generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)

        # Return in standard HF API format
        return [{"generated_text": generated_text.strip()}]
|