AtmicIntelv1/handler.py
ModelHub XC 1d9bdf418f Initialize project; model provided by the ModelHub XC community
Model: SriRamanaAtmic/AtmicIntelv1
Source: Original Platform
2026-06-04 13:22:34 +08:00

111 lines
5.8 KiB
Python

"""
handler.py: Hugging Face Inference Endpoint handler for SriRamanaAtmic/AtmicIntelv1.
Compatible with transformers==4.51.3 (matches the model's transformers_version in config.json).

Generation parameters (from Section A4 of the technical review; do not change):
    do_sample            = False  (greedy decoding; matches SFT + DPO training exactly)
    max_new_tokens       = 350
    repetition_penalty   = 1.0    (sole repetition control)
    no_repeat_ngram_size = 0      (PERMANENTLY DISABLED; hard-coded, not overridable via API)
    temperature / top_p           (REMOVED; inactive under greedy decoding)
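
The per-request override policy above can be written as one small pure function. This is a
minimal sketch for illustration: `resolve_generation_kwargs` and `LOCKED` are hypothetical
names, not part of this handler or of transformers. Only max_new_tokens and
repetition_penalty are read from the request; every other key is ignored.

```python
from typing import Any, Dict, Optional

# Hypothetical constant: settings a caller may NOT override (Section A4 values).
LOCKED: Dict[str, Any] = {
    "do_sample": False,         # greedy decoding
    "no_repeat_ngram_size": 0,  # permanently disabled
    "eos_token_id": 32007,      # <|end|>: turn terminator
    "pad_token_id": 32000,      # <|endoftext|>
}


def resolve_generation_kwargs(parameters: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    # Hypothetical helper: build generate() kwargs from a request's "parameters".
    params = parameters or {}  # tolerate a missing or null "parameters" field
    kwargs: Dict[str, Any] = {
        # The only two values a caller may change; coerced like the handler does.
        "max_new_tokens": int(params.get("max_new_tokens", 350)),
        "repetition_penalty": float(params.get("repetition_penalty", 1.0)),
    }
    # Any other request key (do_sample, temperature, top_p, ...) is never read,
    # so the locked values always win.
    kwargs.update(LOCKED)
    return kwargs


print(resolve_generation_kwargs({"max_new_tokens": "200", "do_sample": True, "temperature": 0.7}))
# -> {'max_new_tokens': 200, 'repetition_penalty': 1.0, 'do_sample': False, 'no_repeat_ngram_size': 0, 'eos_token_id': 32007, 'pad_token_id': 32000}
```

Keeping the locked keys in a single mapping makes it harder for a later edit to expose them
through the API by accident.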

Token IDs (from added_tokens.json; verified):
    <|endoftext|> = 32000  (pad_token_id)
    <|assistant|> = 32001  (appears in the INPUT prompt; must NEVER be an eos_token_id)
    <|end|>       = 32007  (turn terminator; the correct eos for generation)

Critical: generation_config.json in the repo contains eos_token_id=[32000, 32001, 32007].
Token 32001 (<|assistant|>) is present in every input prompt, causing generation to stop
before the first new token (empty output). This handler explicitly overrides
generation_config.json by setting self.model.generation_config before any generate() call.

Input contract:
    The caller (pipeline.py via prompt_builder.py) sends a fully formatted Phi-3 prompt string.
    This handler does NOT apply any chat template; the prompt arrives ready to tokenize.
    {"inputs": "<|system|>...<|end|>\n<|user|>...<|end|>\n<|assistant|>\n"}
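
prompt_builder.py is not part of this file. A minimal sketch of a builder that satisfies the
contract above, assuming the standard Phi-3 layout with a newline after each role tag (the
contract elides the turn contents, so that detail is an assumption); `build_phi3_prompt` is a
hypothetical name. Because the handler tokenizes with add_special_tokens=False, the string is
tokenized exactly as built.

```python
def build_phi3_prompt(system: str, user: str) -> str:
    # Hypothetical builder: one system turn, one user turn, then the open
    # assistant cue the model completes. Role-tag newlines are an assumption.
    return (
        f"<|system|>\n{system}<|end|>\n"
        f"<|user|>\n{user}<|end|>\n"
        "<|assistant|>\n"
    )


prompt = build_phi3_prompt("You are a helpful assistant.", "Hello!")
payload = {"inputs": prompt}  # request body shape from the input contract
print(repr(payload["inputs"]))
```

The string ends with the open <|assistant|> cue, so the model's output starts with the reply
itself, and the reply ends when the model emits <|end|> (32007).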
"""
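# ── DynamicCache compatibility shim (transformers >= 4.38) ──────────────────
# Must run before any other transformers import.
# Remote Phi-3 modeling code (loaded via trust_remote_code=True) may call
# DynamicCache.get_max_length(), which newer transformers releases removed.
# A DynamicCache has no fixed capacity, so returning None restores the old
# behaviour. transformers.DynamicCache is the same class object as
# cache_utils.DynamicCache, so one patch covers both import paths.
import transformers.cache_utils as _cu

if not hasattr(_cu.DynamicCache, "get_max_length"):
    _cu.DynamicCache.get_max_length = lambda self: None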
# ────────────────────────────────────────────────────────────────────────────
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
import torch


class EndpointHandler:
    def __init__(self, path=""):
        # ── Tokenizer ────────────────────────────────────────────────────
        self.tokenizer = AutoTokenizer.from_pretrained(
            path,
            trust_remote_code=True,
        )
        # ── Model ────────────────────────────────────────────────────────
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,   # matches config.json torch_dtype
            device_map="auto",
            trust_remote_code=True,
            attn_implementation="eager",  # avoids flash-attn dependency
        )
        self.model.eval()
        # ── Override generation_config.json ──────────────────────────────
        # generation_config.json in the repo has eos_token_id=[32000, 32001, 32007].
        # Token 32001 is <|assistant|>, which appears in every input prompt.
        # This causes generate() to stop before the first new token (empty output).
        # We override it here so model.generate() never uses the repo file's values.
        self.model.generation_config = GenerationConfig(
            do_sample=False,         # greedy: matches SFT+DPO training
            repetition_penalty=1.0,
            no_repeat_ngram_size=0,  # permanently disabled
            eos_token_id=32007,      # <|end|> only: turn terminator
            pad_token_id=32000,      # <|endoftext|>
            bos_token_id=1,
        )

    def __call__(self, data: dict) -> list:
        # ── Input: fully formatted prompt string from prompt_builder.py ──
        inputs = data.get("inputs", "")
        parameters = data.get("parameters") or {}  # tolerate "parameters": null
        max_new_tokens = int(parameters.get("max_new_tokens", 350))
        # Default must equal the locked value in the module docstring (1.0).
        repetition_penalty = float(parameters.get("repetition_penalty", 1.0))
        # ── Tokenize: prompt already contains all special tokens ──────────
        # Note: with right-side truncation (the transformers default), a prompt
        # longer than max_length loses its trailing <|assistant|> cue.
        tokenized = self.tokenizer(
            inputs,
            return_tensors="pt",
            truncation=True,
            max_length=3500,           # leaves 596 tokens of headroom within 4096
            add_special_tokens=False,  # prompt_builder already adds them
        ).to(self.model.device)
        input_length = tokenized["input_ids"].shape[1]
        # ── Generate ──────────────────────────────────────────────────────
        # generation_config on the model is already overridden in __init__;
        # the kwargs below take final precedence for each request.
        with torch.inference_mode():
            output = self.model.generate(
                **tokenized,
                max_new_tokens=max_new_tokens,
                repetition_penalty=repetition_penalty,
                do_sample=False,
                no_repeat_ngram_size=0,
                eos_token_id=32007,  # <|end|>: confirmed turn terminator
                pad_token_id=32000,
            )
        # ── Decode new tokens only ────────────────────────────────────────
        new_tokens = output[0][input_length:]
        generated_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
        return [{"generated_text": generated_text}]