初始化项目,由ModelHub XC社区提供模型
Model: SriRamanaAtmic/AtmicIntelv1 Source: Original Platform
This commit is contained in:
111
handler.py
Normal file
111
handler.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""
|
||||
handler.py — HuggingFace Inference Endpoint handler for SriRamanaAtmic/AtmicIntelv1
|
||||
Compatible with transformers==4.51.3 (matches model's transformers_version in config.json).
|
||||
|
||||
Generation parameters (from Section A4 of technical review — do not change):
|
||||
do_sample = False (greedy decoding — matches SFT + DPO training exactly)
|
||||
max_new_tokens = 350
|
||||
repetition_penalty = 1.0 (sole repetition control)
|
||||
no_repeat_ngram_size = 0 (PERMANENTLY DISABLED — hard-coded, not overridable via API)
|
||||
temperature / top_p (REMOVED — inactive under greedy decoding)
|
||||
|
||||
Token IDs (from added_tokens.json — verified):
|
||||
<|endoftext|> = 32000 (pad_token_id)
|
||||
<|assistant|> = 32001 (appears in INPUT prompt — must NEVER be eos_token_id)
|
||||
<|end|> = 32007 (turn terminator — correct eos for generation)
|
||||
|
||||
Critical: generation_config.json in the repo contains eos_token_id=[32000, 32001, 32007].
|
||||
Token 32001 (<|assistant|>) is present in every input prompt, causing generation to stop
|
||||
at token 0. This handler explicitly overrides generation_config.json by setting
|
||||
self.model.generation_config before any generate() call.
|
||||
|
||||
Input contract:
|
||||
The caller (pipeline.py via prompt_builder.py) sends a fully-formatted Phi-3 prompt string.
|
||||
This handler does NOT apply any chat template — prompt arrives ready to tokenize.
|
||||
{"inputs": "<|system|>...<|end|>\n<|user|>...<|end|>\n<|assistant|>\n"}
|
||||
"""
|
||||
|
||||
# ── DynamicCache compatibility shim (transformers >= 4.38) ──────────────────
|
||||
# Must be first — before any other transformers import.
|
||||
import transformers.cache_utils as _cu
|
||||
if not hasattr(_cu.DynamicCache, "get_max_length"):
|
||||
_cu.DynamicCache.get_max_length = lambda self: None
|
||||
|
||||
from transformers import DynamicCache
|
||||
if not hasattr(DynamicCache, "get_max_length"):
|
||||
DynamicCache.get_max_length = lambda self: None
|
||||
# ────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
|
||||
import torch
|
||||
|
||||
|
||||
class EndpointHandler:
|
||||
def __init__(self, path=""):
|
||||
# ── Tokenizer ────────────────────────────────────────────────────
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
path,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
# ── Model ────────────────────────────────────────────────────────
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
path,
|
||||
torch_dtype=torch.bfloat16, # matches config.json torch_dtype
|
||||
device_map="auto",
|
||||
trust_remote_code=True,
|
||||
attn_implementation="eager", # avoids flash-attn dependency
|
||||
)
|
||||
self.model.eval()
|
||||
|
||||
# ── Override generation_config.json ──────────────────────────────
|
||||
# generation_config.json in the repo has eos_token_id=[32000, 32001, 32007].
|
||||
# Token 32001 is <|assistant|>, which appears in every input prompt.
|
||||
# This causes generate() to stop at token 0 — empty output.
|
||||
# We override it here so model.generate() never reads the repo file.
|
||||
self.model.generation_config = GenerationConfig(
|
||||
do_sample=False, # greedy — matches SFT+DPO training
|
||||
repetition_penalty=1.0,
|
||||
no_repeat_ngram_size=0, # permanently disabled
|
||||
eos_token_id=32007, # <|end|> only — turn terminator
|
||||
pad_token_id=32000, # <|endoftext|>
|
||||
bos_token_id=1,
|
||||
)
|
||||
|
||||
def __call__(self, data: dict) -> list:
|
||||
# ── Input: fully-formatted prompt string from prompt_builder.py ──
|
||||
inputs = data.get("inputs", "")
|
||||
parameters = data.get("parameters", {})
|
||||
|
||||
max_new_tokens = int(parameters.get("max_new_tokens", 350))
|
||||
repetition_penalty = float(parameters.get("repetition_penalty", 1.15))
|
||||
|
||||
# ── Tokenize — prompt already contains all special tokens ─────────
|
||||
tokenized = self.tokenizer(
|
||||
inputs,
|
||||
return_tensors="pt",
|
||||
truncation=True,
|
||||
max_length=3500, # leaves 596-token headroom within 4096
|
||||
add_special_tokens=False, # prompt_builder adds
|
||||
).to(self.model.device)
|
||||
|
||||
input_length = tokenized["input_ids"].shape[1]
|
||||
|
||||
# ── Generate ──────────────────────────────────────────────────────
|
||||
# generation_config on the model is already overridden in __init__.
|
||||
# kwargs here take final precedence for per-request overrides.
|
||||
with torch.inference_mode():
|
||||
output = self.model.generate(
|
||||
**tokenized,
|
||||
max_new_tokens=max_new_tokens,
|
||||
repetition_penalty=repetition_penalty,
|
||||
do_sample=False,
|
||||
no_repeat_ngram_size=0,
|
||||
eos_token_id=32007, # <|end|> — confirmed turn terminator
|
||||
pad_token_id=32000,
|
||||
)
|
||||
|
||||
# ── Decode new tokens only ────────────────────────────────────────
|
||||
new_tokens = output[0][input_length:]
|
||||
generated_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
|
||||
return [{"generated_text": generated_text}]
|
||||
Reference in New Issue
Block a user