初始化项目,由ModelHub XC社区提供模型
Model: Achiraf01/mistral-immigration-canada-final Source: Original Platform
This commit is contained in:
64
handler.py
Normal file
64
handler.py
Normal file
@@ -0,0 +1,64 @@
|
||||
# handler.py
|
||||
from typing import Any, Dict
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
||||
|
||||
class EndpointHandler:
    """Inference-endpoint handler for a causal LM (Mistral fine-tune).

    Loads the model 8-bit quantized (roughly halves VRAM, ~14 GB -> ~7 GB),
    tokenizes incoming prompts, and returns generated completions in the
    standard `[{"generated_text": ...}]` shape.
    """

    def __init__(self, path: str = ""):
        """Load tokenizer and 8-bit quantized model from *path*.

        Args:
            path: Local directory or hub id of the model repository.
        """
        # 8-bit quantization: reduces ~14 GB to ~7 GB VRAM.
        bnb_config = BitsAndBytesConfig(
            load_in_8bit=True,
        )

        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # The base model defines no pad token; reuse EOS so padding works.
        self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            quantization_config=bnb_config,
            device_map="auto",
            torch_dtype=torch.float16,
        )
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Generate a completion for one request.

        Args:
            data: Request payload:
                - "inputs": prompt string (defaults to "").
                - "parameters": optional dict of generation overrides —
                  max_new_tokens (512), temperature (0.3),
                  repetition_penalty (1.1), return_full_text (False).

        Returns:
            ``[{"generated_text": <str>}]`` — completion only, unless
            ``return_full_text`` is truthy (then prompt + completion).
        """
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", {})

        max_new_tokens = parameters.get("max_new_tokens", 512)
        temperature = parameters.get("temperature", 0.3)
        repetition_penalty = parameters.get("repetition_penalty", 1.1)
        return_full_text = parameters.get("return_full_text", False)

        tokenized = self.tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2048,
        ).to(self.model.device)
        # FIX: was hard-coded .to("cuda"). With device_map="auto" the input
        # embeddings may live on a different device (or on CPU when no GPU is
        # available); self.model.device is the correct placement target.

        # Mistral does not use token_type_ids; generate() rejects the key.
        tokenized.pop("token_type_ids", None)

        with torch.no_grad():
            output_ids = self.model.generate(
                **tokenized,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                repetition_penalty=repetition_penalty,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # Strip the prompt tokens when only the completion was requested.
        if not return_full_text:
            input_len = tokenized["input_ids"].shape[1]
            output_ids = output_ids[:, input_len:]

        generated = self.tokenizer.decode(
            output_ids[0],
            skip_special_tokens=True,
        )

        return [{"generated_text": generated}]
|
||||
Reference in New Issue
Block a user