import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from typing import Dict, Any, List


class EndpointHandler:
    """Inference handler serving a 4-bit (NF4) quantized causal language model.

    The tokenizer and model are loaded once from ``path`` at construction time;
    each call performs one generation request and returns only the newly
    generated text.
    """

    def __init__(self, path=""):
        """Load the tokenizer and a 4-bit quantized model from ``path``.

        Args:
            path: Model location handed to ``from_pretrained`` for both the
                tokenizer and the model.
        """
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_quant_type="nf4",
        )
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            quantization_config=quantization_config,
            device_map="auto",
            torch_dtype=torch.float16,
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate a completion for a single request payload.

        Args:
            data: Request payload. ``data["inputs"]`` is either a prompt
                string or a list of chat messages (rendered with the
                tokenizer's chat template, with a generation prompt appended).
                ``data["parameters"]`` may hold ``max_new_tokens`` (default
                512) and ``temperature`` (default 0.7; values <= 0 select
                greedy decoding, positive values are clamped to >= 0.01).

        Returns:
            ``[{"generated_text": completion}]`` where ``completion`` is the
            decoded continuation only (prompt tokens stripped, special tokens
            skipped).
        """
        inputs = data.get("inputs", "")
        # `or {}` also covers an explicit `"parameters": null` in the payload,
        # which would otherwise make `parameters.get` raise AttributeError.
        parameters = data.get("parameters") or {}

        if isinstance(inputs, list):
            text = self.tokenizer.apply_chat_template(
                inputs, tokenize=False, add_generation_prompt=True
            )
            # The rendered chat template already contains its special tokens
            # (e.g. BOS); tokenizing with add_special_tokens=True would add a
            # second BOS in front of the prompt.
            add_special_tokens = False
        else:
            text = inputs
            add_special_tokens = True

        encoded = self.tokenizer(
            text, return_tensors="pt", add_special_tokens=add_special_tokens
        ).to(self.model.device)
        # Remove token_type_ids - not used by Mistral models
        encoded.pop("token_type_ids", None)

        max_new_tokens = parameters.get("max_new_tokens", 512)
        temperature = parameters.get("temperature", 0.7)

        # Tokenizers without a pad token would make `generate` fall back to
        # EOS with a warning on every call; pass the same fallback explicitly.
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id

        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "pad_token_id": pad_token_id,
        }
        if temperature > 0:
            # Clamp tiny temperatures to avoid degenerate (near-division-by-zero)
            # logit scaling.
            gen_kwargs["do_sample"] = True
            gen_kwargs["temperature"] = max(temperature, 0.01)
        else:
            # Greedy decoding; a temperature is meaningless here and passing
            # one only triggers a generation-config validation warning.
            gen_kwargs["do_sample"] = False

        with torch.no_grad():
            outputs = self.model.generate(**encoded, **gen_kwargs)

        # Keep only the tokens produced after the prompt.
        new_tokens = outputs[0][encoded["input_ids"].shape[1]:]
        response = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
        return [{"generated_text": response}]