"""
|
|
DLM-NL2JSON-4B — Evaluation Script (Simplified)
|
|
|
|
Evaluates the model on the provided test set using an OpenAI-compatible API endpoint.
|
|
Measures per-category exact match accuracy and average latency.
|
|
|
|
Usage:
|
|
# Against vLLM / TensorRT-LLM served model
|
|
python eval_example.py \
|
|
--data test_data_lite_200.jsonl \
|
|
--base-url http://your-server:8006/v1 \
|
|
--model qwen3_4b_6th_norag \
|
|
--api-key token-abc123 \
|
|
--disable-thinking
|
|
|
|
# Against OpenAI API (GPT-4o baseline)
|
|
export OPENAI_API_KEY="sk-..."
|
|
python eval_example.py \
|
|
--data test_data_lite_200.jsonl \
|
|
--model gpt-4o
|
|
"""
|
|
|
|
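# Requires prompts.py alongside this script (it provides the SYS_*_DEFAULT system
# prompts) and the openai v1 client, whose openai.OpenAI(...) API is used below.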
import json, time, argparse, os
from collections import Counter, defaultdict
from typing import Dict, Any

# ── Prompts ──────────────────────────────────────────────
# Import from prompts.py (must be in the same directory)
from prompts import (
    SYS_CSM_DEFAULT,
    SYS_CREDIT_DEFAULT,
    SYS_GIS_DEFAULT,
    SYS_ALP_DEFAULT,
    SYS_CPI_DEFAULT,
)

# ── Category → (special_token, system_prompt) ────────────
TASK_MAP = {
    0:  ("<TASK_ALP>",    SYS_ALP_DEFAULT),     # ALP-A (pattern)
    1:  ("<TASK_ALP>",    SYS_ALP_DEFAULT),     # ALP-B (flow)
    2:  ("<TASK_CSM>",    SYS_CSM_DEFAULT),     # CSM (consumer spending)
    3:  ("<TASK_CREDIT>", SYS_CREDIT_DEFAULT),  # CREDIT-Income
    4:  ("<TASK_CREDIT>", SYS_CREDIT_DEFAULT),  # CREDIT-Spending
    5:  ("<TASK_CREDIT>", SYS_CREDIT_DEFAULT),  # CREDIT-Loan/Default
    6:  ("<TASK_CPI>",    SYS_CPI_DEFAULT),     # CPI (business status)
    9:  ("<TASK_GIS>",    SYS_GIS_DEFAULT),     # GIS-Inflow
    10: ("<TASK_GIS>",    SYS_GIS_DEFAULT),     # GIS-Outflow
    11: ("<TASK_GIS>",    SYS_GIS_DEFAULT),     # GIS-Consumption
}

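# The special task token is prepended to the user turn; see the messages
# built in main().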
CAT_NAMES = {
    0: "ALP-A(ptrn)", 1: "ALP-B(flow)", 2: "CSM",
    3: "CREDIT-Income", 4: "CREDIT-Spending", 5: "CREDIT-Loan",
    6: "CPI", 9: "GIS-Inflow", 10: "GIS-Outflow", 11: "GIS-Consumption",
}

# ── Required keys per category (for comparison) ─────────
REQUIRED_KEYS = {
    0:  ["base_ym", "region_nm", "ptrn", "sex_cd", "age_cd", "category"],
    1:  ["base_ym", "region_nm", "flow_cd", "sex_cd", "age_cd", "category"],
    2:  ["base_ym", "region_nm", "industry_select", "sex_cd", "age_cd", "category"],
    3:  ["base_ym", "region_nm", "job_cd", "perc_cd", "sex_cd", "age_cd", "category"],
    4:  ["base_ym", "region_nm", "job_cd", "perc_cd", "sex_cd", "age_cd", "category"],
    5:  ["base_ym", "region_nm", "job_cd", "perc_cd", "sex_cd", "age_cd", "category"],
    6:  ["base_ym", "region_nm", "bzc_cd", "cp_cd", "enp_cd", "category"],
    9:  ["region_nm", "base_ym", "region_count", "category"],
    10: ["region_nm", "base_ym", "region_count", "category"],
    11: ["region_nm", "base_ym", "industry_category", "category"],
}

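# Scoring is exact match over these keys only; extra keys in the prediction
# (e.g. the free-text "summary") are ignored by normalize() and compare() below.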
# ── Normalization helpers ────────────────────────────────
def norm_int_list(v):
    """Coerce a list of mixed-type codes into a sorted, de-duplicated int list."""
    if not isinstance(v, list):
        return v
    out = []
    for x in v:
        try:
            out.append(int(float(str(x).strip())))
        except Exception:
            continue
    return sorted(set(out))

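# Example: norm_int_list(["3", 1.0, "1", "n/a"]) -> [1, 3]
# (non-numeric entries are dropped, duplicates collapse, order is canonicalized)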
def norm_dict_of_lists(d):
    """Normalize industry_select or bzc_cd: {str_key: [int, ...]}."""
    if not isinstance(d, dict):
        return d
    return {str(k).upper() if len(str(k)) == 1 and str(k).isalpha() else str(k):
            norm_int_list(arr) if isinstance(arr, list) else arr
            for k, arr in d.items()}

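# Example: norm_dict_of_lists({"g": ["20", 10, 10], "misc": "x"})
#   -> {"G": [10, 20], "misc": "x"}
# (single-letter keys are uppercased; list values go through norm_int_list)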
def normalize(obj: Dict[str, Any], cat: int) -> Dict[str, Any]:
    """Normalize prediction/gold for fair comparison (summary excluded)."""
    o = dict(obj)
    o.pop("summary", None)

    # Numeric scalars that may arrive as strings
    for k in ["base_ym", "region_count", "category"]:
        if k in o and isinstance(o[k], str):
            try:
                o[k] = int(o[k])
            except ValueError:
                pass

    # Code lists: compared order- and type-insensitively
    for k in ["sex_cd", "age_cd", "job_cd", "perc_cd", "ptrn",
              "industry_category", "cp_cd", "enp_cd"]:
        if k in o:
            o[k] = norm_int_list(o[k])

    if "flow_cd" in o and isinstance(o["flow_cd"], list):
        o["flow_cd"] = norm_int_list(o["flow_cd"])

    for k in ["industry_select", "bzc_cd"]:
        if k in o:
            o[k] = norm_dict_of_lists(o[k])

    # Clamp region_count to the valid 1..10 range
    if "region_count" in o:
        try:
            o["region_count"] = max(1, min(10, int(o["region_count"])))
        except (ValueError, TypeError):
            pass

    return o

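# Example (category 2):
#   normalize({"base_ym": "202403", "sex_cd": ["2", "1", 1],
#              "summary": "...", "category": 2}, 2)
#   -> {"base_ym": 202403, "sex_cd": [1, 2], "category": 2}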
def extract_first_json(text: str):
    """Return the first balanced {...} span in text, or None."""
    start = text.find("{")
    if start == -1:
        return None
    depth = 0
    for i in range(start, len(text)):
        if text[i] == "{":
            depth += 1
        elif text[i] == "}":
            depth -= 1
            if depth == 0:
                return text[start:i + 1]
    return None

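# Example: extract_first_json('noise {"a": {"b": 1}} tail') -> '{"a": {"b": 1}}'
# Note: the brace counter does not skip braces inside JSON string values;
# a known simplification, acceptable for this script's single-object outputs.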
def compare(pred: Dict, gold: Dict, cat: int):
    """Exact match over the category's required keys; returns (ok, diff)."""
    req = REQUIRED_KEYS.get(cat, [])
    diff = {}
    for k in req:
        if pred.get(k, "<MISSING>") != gold.get(k, "<MISSING>"):
            diff[k] = {"pred": pred.get(k), "gold": gold.get(k)}
    return len(diff) == 0, diff

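# compare() returns (True, {}) on a full match, else (False, diff) where diff
# maps each mismatched key to {"pred": ..., "gold": ...} (shown as FAIL [keys]).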
# ── Main ─────────────────────────────────────────────────
def main():
    ap = argparse.ArgumentParser(description="DLM-NL2JSON-4B Evaluation")
    ap.add_argument("--data", required=True, help="Test JSONL file path")
    ap.add_argument("--base-url", default=None, help="OpenAI-compatible base URL")
    ap.add_argument("--model", required=True, help="Model name")
    ap.add_argument("--api-key", default=os.environ.get("OPENAI_API_KEY", ""), help="API key")
    ap.add_argument("--disable-thinking", action="store_true",
                    help="Pass chat_template_kwargs to disable Qwen3 thinking mode")
    ap.add_argument("--max-tokens", type=int, default=512)
    ap.add_argument("--per-cat", type=int, default=999, help="Max samples per category")
    args = ap.parse_args()

    import openai
    client = openai.OpenAI(
        base_url=args.base_url or None,
        api_key=args.api_key or "dummy",
        timeout=60.0,
    )

    # Load test data
    with open(args.data, encoding="utf-8") as f:
        raw = [json.loads(line) for line in f]

    # Group by category and sample
    by_cat = defaultdict(list)
    for item in raw:
        out = item["output"] if isinstance(item["output"], dict) else json.loads(item["output"])
        cat = out["category"]
        by_cat[cat].append({"input": item["input"], "gold": out})

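    # Each record is expected to look like (inferred from the fields read above):
    #   {"input": "<natural-language query>", "output": {"category": 2, ...}}
    # where "output" may be either a dict or a JSON-encoded string.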
    samples = []
    for cat in sorted(by_cat):
        items = by_cat[cat][:args.per_cat]
        samples.extend([(cat, ex) for ex in items])

    print(f"[INFO] Evaluating {len(samples)} samples across {len(by_cat)} categories\n")

    # Evaluate
    ok_counts, total_counts = Counter(), Counter()
    latency_sums = Counter()

    for idx, (cat, ex) in enumerate(samples, 1):
        user_in = ex["input"].strip()
        gold_norm = normalize(ex["gold"], cat)

        # Prepend the category's special task token to the user input
        tag, sys_prompt = TASK_MAP[cat]
        messages = [
            {"role": "system", "content": sys_prompt},
            {"role": "user", "content": f"{tag}\n{user_in}"},
        ]

        kwargs = dict(model=args.model, messages=messages,
                      max_tokens=args.max_tokens, temperature=0.0)
        if args.disable_thinking:
            # Only meaningful for servers that honor chat_template_kwargs (e.g. vLLM)
            kwargs["extra_body"] = {"chat_template_kwargs": {"enable_thinking": False}}

        t0 = time.perf_counter()
        try:
            resp = client.chat.completions.create(**kwargs)
            # content may be None; fall back to "" so parsing below fails gracefully
            gen = resp.choices[0].message.content or ""
        except Exception as e:
            dt = time.perf_counter() - t0
            total_counts[cat] += 1
            latency_sums[cat] += dt
            print(f"[{idx:04d}] {CAT_NAMES.get(cat, cat)} | ERROR: {e}")
            continue

        dt = time.perf_counter() - t0
        total_counts[cat] += 1
        latency_sums[cat] += dt

        json_str = extract_first_json(gen) or gen.strip()
        try:
            pred_obj = json.loads(json_str)
        except json.JSONDecodeError:
            print(f"[{idx:04d}] {CAT_NAMES.get(cat, cat)} | PARSE_FAIL | {dt:.2f}s")
            continue

        pred_norm = normalize(pred_obj, cat)
        ok, diff = compare(pred_norm, gold_norm, cat)
        if ok:
            ok_counts[cat] += 1

        status = "OK" if ok else f"FAIL {list(diff.keys())}"
        print(f"[{idx:04d}] {CAT_NAMES.get(cat, cat)} | {status} | {dt:.2f}s")

    # Summary
    print("\n" + "=" * 50)
    print("EVALUATION SUMMARY")
    print("=" * 50)
    total_ok = total_all = 0
    for c in sorted(total_counts):
        ok = ok_counts[c]
        tot = total_counts[c]
        acc = ok / tot if tot else 0
        avg_lat = latency_sums[c] / tot if tot else 0
        total_ok += ok
        total_all += tot
        print(f"  {CAT_NAMES.get(c, c):20s}: {ok:4d}/{tot:4d}  acc={acc:.1%}  avg={avg_lat:.3f}s")

    overall_acc = total_ok / total_all if total_all else 0
    overall_lat = sum(latency_sums.values()) / total_all if total_all else 0
    print(f"  {'OVERALL':20s}: {total_ok:4d}/{total_all:4d}  acc={overall_acc:.1%}  avg={overall_lat:.3f}s")

if __name__ == "__main__":
    main()