Model: Achiraf01/mistral-immigration-canada-final
Source: Original Platform

# handler.py
from typing import Any, Dict

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig


class EndpointHandler:
    def __init__(self, path: str = ""):
        # 8-bit quantization: shrinks the model from ~14 GB (fp16) to ~7 GB of VRAM
        bnb_config = BitsAndBytesConfig(
            load_in_8bit=True,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # Mistral's tokenizer ships without a pad token; reuse EOS so padding works
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            quantization_config=bnb_config,
            device_map="auto",
            torch_dtype=torch.float16,
        )
        self.model.eval()
    def __call__(self, data: Dict[str, Any]) -> Any:
        # Request format: {"inputs": str, "parameters": {...}}
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", {})
        max_new_tokens = parameters.get("max_new_tokens", 512)
        temperature = parameters.get("temperature", 0.3)
        repetition_penalty = parameters.get("repetition_penalty", 1.1)
        return_full_text = parameters.get("return_full_text", False)

        tokenized = self.tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2048,
        ).to("cuda")
        # ✅ FIX: Mistral does not use token_type_ids; drop them if the tokenizer emits any
        tokenized.pop("token_type_ids", None)

        with torch.no_grad():
            output_ids = self.model.generate(
                **tokenized,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                repetition_penalty=repetition_penalty,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # Strip the prompt tokens when return_full_text=False
        if not return_full_text:
            input_len = tokenized["input_ids"].shape[1]
            output_ids = output_ids[:, input_len:]

        generated = self.tokenizer.decode(
            output_ids[0],
            skip_special_tokens=True,
        )
        return [{"generated_text": generated}]
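

# --- Usage sketch (illustrative, not part of the original handler) ---
# A minimal local smoke test, assuming a CUDA GPU with enough VRAM and that
# the weights resolve from Achiraf01/mistral-immigration-canada-final. The
# payload mirrors the request shape __call__ parses above; the example
# question itself is made up.
if __name__ == "__main__":
    handler = EndpointHandler(path="Achiraf01/mistral-immigration-canada-final")
    response = handler({
        "inputs": "What documents are required for a Canadian study permit?",
        "parameters": {"max_new_tokens": 128, "temperature": 0.3},
    })
    print(response[0]["generated_text"])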