# handler.py

from typing import Any, Dict

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig


class EndpointHandler:
    def __init__(self, path: str = ""):
        # 8-bit quantization: cuts VRAM from ~14 GB down to ~7 GB
        bnb_config = BitsAndBytesConfig(
            load_in_8bit=True,
        )

        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # The tokenizer ships without a pad token; reuse EOS for padding
        self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            quantization_config=bnb_config,
            device_map="auto",
            torch_dtype=torch.float16,
        )
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> Any:
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", {})

        max_new_tokens = parameters.get("max_new_tokens", 512)
        temperature = parameters.get("temperature", 0.3)
        repetition_penalty = parameters.get("repetition_penalty", 1.1)
        return_full_text = parameters.get("return_full_text", False)

        tokenized = self.tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2048,
        ).to("cuda")

        # ✅ FIX: Mistral does not use token_type_ids; drop them if the
        # tokenizer produced any, otherwise generate() rejects the kwarg
        tokenized.pop("token_type_ids", None)

        with torch.no_grad():
            output_ids = self.model.generate(
                **tokenized,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                repetition_penalty=repetition_penalty,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # Strip the prompt tokens when return_full_text=False
        if not return_full_text:
            input_len = tokenized["input_ids"].shape[1]
            output_ids = output_ids[:, input_len:]

        generated = self.tokenizer.decode(
            output_ids[0],
            skip_special_tokens=True,
        )

        return [{"generated_text": generated}]
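

# --- Usage sketch (illustrative; not part of the handler contract) ---
# A minimal local smoke test, assuming a CUDA GPU, bitsandbytes installed,
# and a Mistral-style checkpoint under "./model" (hypothetical path). On
# Inference Endpoints the platform instantiates EndpointHandler itself and
# invokes it with the parsed request payload, so this block never runs there.
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")
    result = handler(
        {
            "inputs": "Summarize 8-bit quantization in one sentence.",
            "parameters": {"max_new_tokens": 64, "temperature": 0.7},
        }
    )
    print(result[0]["generated_text"])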