# handler.py
from typing import Any, Dict

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig


class EndpointHandler:
    def __init__(self, path: str = ""):
        # 8-bit quantization: cuts VRAM usage from ~14 GB to ~7 GB
        bnb_config = BitsAndBytesConfig(
            load_in_8bit=True,
        )

        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # The tokenizer has no pad token by default; reuse EOS for padding
        self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            quantization_config=bnb_config,
            device_map="auto",
            torch_dtype=torch.float16,
        )
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> Any:
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", {})

        max_new_tokens = parameters.get("max_new_tokens", 512)
        temperature = parameters.get("temperature", 0.3)
        repetition_penalty = parameters.get("repetition_penalty", 1.1)
        return_full_text = parameters.get("return_full_text", False)

        tokenized = self.tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2048,
        ).to("cuda")

        # ✅ Fix: Mistral does not use token_type_ids, drop them before generate()
        tokenized.pop("token_type_ids", None)

        with torch.no_grad():
            output_ids = self.model.generate(
                **tokenized,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                repetition_penalty=repetition_penalty,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # Strip the prompt from the output when return_full_text is False
        if not return_full_text:
            input_len = tokenized["input_ids"].shape[1]
            output_ids = output_ids[:, input_len:]

        generated = self.tokenizer.decode(
            output_ids[0],
            skip_special_tokens=True,
        )
        return [{"generated_text": generated}]
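

# Minimal local smoke test (a sketch, not part of the Inference Endpoints
# contract). The model path below is a placeholder assumption: on the
# endpoint, __init__ receives the mounted repository path automatically,
# while locally you would point it at a directory or Hub repo containing
# the model weights. The request dict mirrors what __call__ expects above.
if __name__ == "__main__":
    handler = EndpointHandler(path="path/to/model")  # placeholder path
    result = handler(
        {
            "inputs": "Explain what a custom handler.py does in one sentence.",
            "parameters": {"max_new_tokens": 64, "temperature": 0.3},
        }
    )
    print(result[0]["generated_text"])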