111 lines
5.8 KiB
Python
111 lines
5.8 KiB
Python
|
|
"""
|
||
|
|
handler.py — HuggingFace Inference Endpoint handler for SriRamanaAtmic/AtmicIntelv1
|
||
|
|
Compatible with transformers==4.51.3 (matches model's transformers_version in config.json).
|
||
|
|
|
||
|
|
Generation parameters (from Section A4 of technical review — do not change):
|
||
|
|
do_sample = False (greedy decoding — matches SFT + DPO training exactly)
|
||
|
|
max_new_tokens = 350
|
||
|
|
repetition_penalty = 1.0 (sole repetition control)
|
||
|
|
no_repeat_ngram_size = 0 (PERMANENTLY DISABLED — hard-coded, not overridable via API)
|
||
|
|
temperature / top_p (REMOVED — inactive under greedy decoding)
|
||
|
|
|
||
|
|
Token IDs (from added_tokens.json — verified):
|
||
|
|
<|endoftext|> = 32000 (pad_token_id)
|
||
|
|
<|assistant|> = 32001 (appears in INPUT prompt — must NEVER be eos_token_id)
|
||
|
|
<|end|> = 32007 (turn terminator — correct eos for generation)
|
||
|
|
|
||
|
|
Critical: generation_config.json in the repo contains eos_token_id=[32000, 32001, 32007].
|
||
|
|
Token 32001 (<|assistant|>) is present in every input prompt, causing generation to stop
|
||
|
|
at token 0. This handler explicitly overrides generation_config.json by setting
|
||
|
|
self.model.generation_config before any generate() call.
|
||
|
|
|
||
|
|
Input contract:
|
||
|
|
The caller (pipeline.py via prompt_builder.py) sends a fully-formatted Phi-3 prompt string.
|
||
|
|
This handler does NOT apply any chat template — prompt arrives ready to tokenize.
|
||
|
|
{"inputs": "<|system|>...<|end|>\n<|user|>...<|end|>\n<|assistant|>\n"}
|
||
|
|
"""
|
||
|
|
|
||
|
|
# ── DynamicCache compatibility shim (transformers >= 4.38) ──────────────────
# Must be first — before any other transformers import.
#
# If the installed DynamicCache class has no `get_max_length` attribute, a
# stub returning None is attached. The check is made twice: once through the
# `transformers.cache_utils` module and once through the top-level
# `transformers` export.
import transformers.cache_utils as _cu

if not hasattr(_cu.DynamicCache, "get_max_length"):
    setattr(_cu.DynamicCache, "get_max_length", lambda self: None)

from transformers import DynamicCache

if not hasattr(DynamicCache, "get_max_length"):
    setattr(DynamicCache, "get_max_length", lambda self: None)
# ────────────────────────────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
|
||
|
|
import torch
|
||
|
|
|
||
|
|
|
||
|
|
class EndpointHandler:
    """Inference Endpoint handler: tokenize a ready-made prompt, generate greedily.

    ``__init__`` loads the tokenizer and model from ``path`` and replaces the
    model's ``generation_config`` so that only ``<|end|>`` (32007) terminates
    generation. ``__call__`` serves one request: it tokenizes the prompt as-is
    (no chat template, no extra special tokens), runs ``generate()`` and
    returns only the newly generated text.

    The class-level constants below are the single source of truth for token
    IDs and generation defaults; both ``__init__`` and ``__call__`` read them
    so the config object and the per-request kwargs cannot disagree.
    """

    EOS_TOKEN_ID = 32007  # <|end|> — turn terminator
    PAD_TOKEN_ID = 32000  # <|endoftext|>
    BOS_TOKEN_ID = 1
    DEFAULT_MAX_NEW_TOKENS = 350
    # Must equal the value placed in generation_config below. The per-request
    # fallback used to be 1.15, which silently overrode the 1.0 configured in
    # __init__ because generate() kwargs take precedence over the config.
    DEFAULT_REPETITION_PENALTY = 1.0
    MAX_INPUT_TOKENS = 3500  # leaves 596-token headroom within 4096

    def __init__(self, path=""):
        """Load tokenizer and model from ``path`` and pin the generation config.

        Args:
            path: Directory or repo identifier passed straight to
                ``from_pretrained`` for both the tokenizer and the model.
        """
        # ── Tokenizer ────────────────────────────────────────────────────
        self.tokenizer = AutoTokenizer.from_pretrained(
            path,
            trust_remote_code=True,
        )

        # ── Model ────────────────────────────────────────────────────────
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,   # matches config.json torch_dtype
            device_map="auto",
            trust_remote_code=True,
            attn_implementation="eager",  # avoids flash-attn dependency
        )
        self.model.eval()

        # ── Override generation_config.json ──────────────────────────────
        # generation_config.json in the repo has eos_token_id=[32000, 32001, 32007].
        # Token 32001 is <|assistant|>, which appears in every input prompt.
        # This causes generate() to stop at token 0 — empty output.
        # We override it here so model.generate() never reads the repo file.
        self.model.generation_config = GenerationConfig(
            do_sample=False,              # greedy — matches SFT+DPO training
            repetition_penalty=self.DEFAULT_REPETITION_PENALTY,
            no_repeat_ngram_size=0,       # permanently disabled
            eos_token_id=self.EOS_TOKEN_ID,
            pad_token_id=self.PAD_TOKEN_ID,
            bos_token_id=self.BOS_TOKEN_ID,
        )

    def __call__(self, data: dict) -> list:
        """Generate a completion for one fully-formatted prompt.

        Args:
            data: Request payload. ``data["inputs"]`` is the prompt string
                (defaults to ``""``). ``data["parameters"]`` is an optional
                mapping; only ``max_new_tokens`` and ``repetition_penalty``
                are read from it. A missing or null ``parameters`` entry is
                treated as empty.

        Returns:
            A one-element list ``[{"generated_text": <str>}]`` containing the
            decoded new tokens only (prompt tokens are sliced off, special
            tokens are skipped).
        """
        # ── Input: fully-formatted prompt string from prompt_builder.py ──
        prompt = data.get("inputs", "")
        # `or {}` also covers an explicit `"parameters": null` in the payload,
        # which `.get("parameters", {})` would pass through as None.
        parameters = data.get("parameters") or {}

        max_new_tokens = int(
            parameters.get("max_new_tokens", self.DEFAULT_MAX_NEW_TOKENS)
        )
        repetition_penalty = float(
            parameters.get("repetition_penalty", self.DEFAULT_REPETITION_PENALTY)
        )

        # ── Tokenize — prompt already contains all special tokens ─────────
        tokenized = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.MAX_INPUT_TOKENS,
            add_special_tokens=False,     # prompt_builder adds them
        ).to(self.model.device)

        # Remember the prompt length so the prompt can be sliced off the output.
        input_length = tokenized["input_ids"].shape[1]

        # ── Generate ──────────────────────────────────────────────────────
        # generation_config on the model is already overridden in __init__.
        # kwargs here take final precedence for per-request overrides.
        # do_sample / no_repeat_ngram_size / eos / pad are deliberately NOT
        # read from `parameters`, so callers cannot override them.
        with torch.inference_mode():
            output = self.model.generate(
                **tokenized,
                max_new_tokens=max_new_tokens,
                repetition_penalty=repetition_penalty,
                do_sample=False,
                no_repeat_ngram_size=0,
                eos_token_id=self.EOS_TOKEN_ID,  # <|end|> — confirmed turn terminator
                pad_token_id=self.PAD_TOKEN_ID,
            )

        # ── Decode new tokens only ────────────────────────────────────────
        new_tokens = output[0][input_length:]
        generated_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
        return [{"generated_text": generated_text}]
|