初始化项目,由ModelHub XC社区提供模型
Model: dataslab/DLM-NL2JSON-4B Source: Original Platform
This commit is contained in:
259
eval/eval_example.py
Normal file
259
eval/eval_example.py
Normal file
@@ -0,0 +1,259 @@
|
||||
"""
|
||||
DLM-NL2JSON-4B — Evaluation Script (Simplified)
|
||||
|
||||
Evaluates the model on the provided test set using an OpenAI-compatible API endpoint.
|
||||
Measures per-category exact match accuracy and average latency.
|
||||
|
||||
Usage:
|
||||
# Against vLLM / TensorRT-LLM served model
|
||||
python eval_example.py \
|
||||
--data test_data_lite_200.jsonl \
|
||||
--base-url http://your-server:8006/v1 \
|
||||
--model qwen3_4b_6th_norag \
|
||||
--api-key token-abc123 \
|
||||
--disable-thinking
|
||||
|
||||
# Against OpenAI API (GPT-4o baseline)
|
||||
export OPENAI_API_KEY="sk-..."
|
||||
python eval_example.py \
|
||||
--data test_data_lite_200.jsonl \
|
||||
--model gpt-4o
|
||||
"""
|
||||
|
||||
import json, re, time, argparse, os
|
||||
from collections import Counter
|
||||
from typing import Dict, Any, List
|
||||
|
||||
# ── Prompts ──────────────────────────────────────────────
|
||||
# Import from prompts.py (must be in the same directory)
|
||||
from prompts import (
|
||||
SYS_CSM_DEFAULT,
|
||||
SYS_CREDIT_DEFAULT,
|
||||
SYS_GIS_DEFAULT,
|
||||
SYS_ALP_DEFAULT,
|
||||
SYS_CPI_DEFAULT,
|
||||
)
|
||||
|
||||
# ── Category → (special_token, system_prompt) ────────────
|
||||
# Routing table: dataset category id → (task special token, system prompt).
# main() prepends the token to the user message so the served model routes
# the request to the right task family; prompts come from prompts.py.
# NOTE(review): ids 7 and 8 are absent — presumably unused by the test set;
# confirm, since main() indexes TASK_MAP[cat] with no fallback.
TASK_MAP = {
    0: ("<TASK_ALP>", SYS_ALP_DEFAULT),  # ALP-A (pattern)
    1: ("<TASK_ALP>", SYS_ALP_DEFAULT),  # ALP-B (flow)
    2: ("<TASK_CSM>", SYS_CSM_DEFAULT),  # CSM (consumer spending)
    3: ("<TASK_CREDIT>", SYS_CREDIT_DEFAULT),  # CREDIT-Income
    4: ("<TASK_CREDIT>", SYS_CREDIT_DEFAULT),  # CREDIT-Spending
    5: ("<TASK_CREDIT>", SYS_CREDIT_DEFAULT),  # CREDIT-Loan/Default
    6: ("<TASK_CPI>", SYS_CPI_DEFAULT),  # CPI (business status)
    9: ("<TASK_GIS>", SYS_GIS_DEFAULT),  # GIS-Inflow
    10: ("<TASK_GIS>", SYS_GIS_DEFAULT),  # GIS-Outflow
    11: ("<TASK_GIS>", SYS_GIS_DEFAULT),  # GIS-Consumption
}
|
||||
|
||||
# Human-readable label per category id, used in per-sample log lines and
# in the summary table (falls back to the raw id via .get elsewhere).
CAT_NAMES = {
    0: "ALP-A(ptrn)",
    1: "ALP-B(flow)",
    2: "CSM",
    3: "CREDIT-Income",
    4: "CREDIT-Spending",
    5: "CREDIT-Loan",
    6: "CPI",
    9: "GIS-Inflow",
    10: "GIS-Outflow",
    11: "GIS-Consumption",
}
|
||||
|
||||
# ── Required keys per category (for comparison) ─────────
|
||||
# Keys that must match exactly between prediction and gold for each
# category; the free-text "summary" field is excluded by normalize().
REQUIRED_KEYS = {
    # ALP: pattern search vs. flow query
    0: ["base_ym", "region_nm", "ptrn", "sex_cd", "age_cd", "category"],
    1: ["base_ym", "region_nm", "flow_cd", "sex_cd", "age_cd", "category"],
    # CSM: consumer spending by industry
    2: ["base_ym", "region_nm", "industry_select", "sex_cd", "age_cd", "category"],
    # CREDIT: income / spending / loan all share one schema
    3: ["base_ym", "region_nm", "job_cd", "perc_cd", "sex_cd", "age_cd", "category"],
    4: ["base_ym", "region_nm", "job_cd", "perc_cd", "sex_cd", "age_cd", "category"],
    5: ["base_ym", "region_nm", "job_cd", "perc_cd", "sex_cd", "age_cd", "category"],
    # CPI: business status
    6: ["base_ym", "region_nm", "bzc_cd", "cp_cd", "enp_cd", "category"],
    # GIS: inflow / outflow / consumption
    9: ["region_nm", "base_ym", "region_count", "category"],
    10: ["region_nm", "base_ym", "region_count", "category"],
    11: ["region_nm", "base_ym", "industry_category", "category"],
}
|
||||
|
||||
|
||||
# ── Normalization helpers ────────────────────────────────
|
||||
def norm_int_list(v):
    """Coerce a list of mixed-type codes into a sorted, deduplicated int list.

    Each element is stripped, parsed through float (so "3.0" → 3), and
    discarded silently if unparseable. Non-list inputs pass through
    unchanged so callers can normalize opportunistically.
    """
    if not isinstance(v, list):
        return v
    parsed = set()
    for item in v:
        try:
            parsed.add(int(float(str(item).strip())))
        except Exception:
            # Unparseable entry — drop it rather than fail the comparison.
            pass
    return sorted(parsed)
|
||||
|
||||
|
||||
def norm_dict_of_lists(d):
    """Normalize industry_select or bzc_cd: {str_key: [int, ...]}.

    Single-letter alphabetic keys are upper-cased (industry section codes);
    list values go through norm_int_list. Non-dict inputs pass through
    unchanged.
    """
    if not isinstance(d, dict):
        return d
    normalized = {}
    for raw_key, value in d.items():
        key = str(raw_key)
        if len(key) == 1 and key.isalpha():
            key = key.upper()
        if isinstance(value, list):
            value = norm_int_list(value)
        normalized[key] = value
    return normalized
|
||||
|
||||
|
||||
def normalize(obj: Dict[str, Any], cat: int) -> Dict[str, Any]:
    """Canonicalize a prediction or gold object for fair comparison.

    Drops the free-text "summary" field, coerces numeric-string scalars to
    int, normalizes code lists and dict-of-list fields, and clamps
    region_count to [1, 10]. The input mapping is not mutated; *cat* is kept
    for interface symmetry with compare().
    """
    result = dict(obj)
    result.pop("summary", None)

    # Scalar fields that may arrive as numeric strings.
    for key in ("base_ym", "region_count", "category"):
        value = result.get(key)
        if isinstance(value, str):
            try:
                result[key] = int(value)
            except ValueError:
                pass  # leave non-numeric strings untouched

    # Plain int-list code fields.
    for key in ("sex_cd", "age_cd", "job_cd", "perc_cd", "ptrn",
                "industry_category", "cp_cd", "enp_cd"):
        if key in result:
            result[key] = norm_int_list(result[key])

    # flow_cd is only normalized when it is already a list.
    flow = result.get("flow_cd")
    if isinstance(flow, list):
        result["flow_cd"] = norm_int_list(flow)

    # Dict-of-list fields ({section: [codes]}).
    for key in ("industry_select", "bzc_cd"):
        if key in result:
            result[key] = norm_dict_of_lists(result[key])

    # Clamp region_count into its valid range.
    if "region_count" in result:
        try:
            result["region_count"] = max(1, min(10, int(result["region_count"])))
        except (ValueError, TypeError):
            pass

    return result
|
||||
|
||||
|
||||
def extract_first_json(text: str):
    """Return the first balanced ``{...}`` substring of *text*, or None.

    Brace depth is tracked only OUTSIDE JSON string literals, with
    backslash-escape handling, so braces or escaped quotes inside string
    values (e.g. ``{"s": "}"}``) no longer truncate the extracted object.
    Returns None when no opening brace exists or the braces never balance.
    """
    start = text.find("{")
    if start == -1:
        return None
    depth = 0
    in_string = False
    escaped = False
    for i in range(start, len(text)):
        ch = text[i]
        if in_string:
            # Inside a string literal: braces are data, not structure.
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
            continue
        if ch == '"':
            in_string = True
        elif ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return text[start:i + 1]
    return None
|
||||
|
||||
|
||||
def compare(pred: Dict, gold: Dict, cat: int):
    """Exact-match comparison over the category's required keys.

    Returns ``(is_match, diff)`` where *diff* maps each mismatched key to
    its predicted and gold values. A key missing on one side counts as a
    mismatch (distinct sentinel vs. real value).
    """
    sentinel = "<MISSING>"
    diff = {
        key: {"pred": pred.get(key), "gold": gold.get(key)}
        for key in REQUIRED_KEYS.get(cat, [])
        if pred.get(key, sentinel) != gold.get(key, sentinel)
    }
    return not diff, diff
|
||||
|
||||
|
||||
# ── Main ─────────────────────────────────────────────────
|
||||
def main() -> None:
    """Run the evaluation loop: query the endpoint per sample, compare
    normalized JSON against gold, and print per-category accuracy/latency.
    """
    ap = argparse.ArgumentParser(description="DLM-NL2JSON-4B Evaluation")
    ap.add_argument("--data", required=True, help="Test JSONL file path")
    ap.add_argument("--base-url", default=None, help="OpenAI-compatible base URL")
    ap.add_argument("--model", required=True, help="Model name")
    ap.add_argument("--api-key", default=os.environ.get("OPENAI_API_KEY", ""), help="API key")
    ap.add_argument("--disable-thinking", action="store_true",
                    help="Pass chat_template_kwargs to disable Qwen3 thinking mode")
    ap.add_argument("--max-tokens", type=int, default=512)
    ap.add_argument("--per-cat", type=int, default=999, help="Max samples per category")
    args = ap.parse_args()

    # Imported lazily so the script only needs the openai package when run.
    import openai
    client = openai.OpenAI(
        base_url=args.base_url or None,   # None → SDK default (api.openai.com)
        api_key=args.api_key or "dummy",  # local vLLM servers accept any token
        timeout=60.0,
    )

    # Load test data: one JSON object per line.
    with open(args.data, encoding="utf-8") as f:
        raw = [json.loads(line) for line in f]

    # Group by category and sample. "output" may be a dict or a JSON string.
    from collections import defaultdict
    by_cat = defaultdict(list)
    for item in raw:
        out = item["output"] if isinstance(item["output"], dict) else json.loads(item["output"])
        cat = out["category"]
        by_cat[cat].append({"input": item["input"], "gold": out})

    # Cap each category at --per-cat samples, iterate in category order.
    samples = []
    for cat in sorted(by_cat):
        items = by_cat[cat][:args.per_cat]
        samples.extend([(cat, ex) for ex in items])

    print(f"[INFO] Evaluating {len(samples)} samples across {len(by_cat)} categories\n")

    # Evaluate. Counters keyed by category id; latency_sums accumulates
    # float seconds (Counter supports float values).
    ok_counts, total_counts = Counter(), Counter()
    latency_sums = Counter()

    for idx, (cat, ex) in enumerate(samples, 1):
        user_in = ex["input"].strip()
        gold_norm = normalize(ex["gold"], cat)

        # Route via the category's special token + system prompt.
        tag, sys_prompt = TASK_MAP[cat]
        messages = [
            {"role": "system", "content": sys_prompt},
            {"role": "user", "content": f"{tag}\n{user_in}"},
        ]

        # temperature=0 for deterministic, reproducible outputs.
        kwargs = dict(model=args.model, messages=messages,
                      max_tokens=args.max_tokens, temperature=0.0)
        if args.disable_thinking:
            # vLLM-specific passthrough to the Qwen3 chat template.
            kwargs["extra_body"] = {"chat_template_kwargs": {"enable_thinking": False}}

        t0 = time.perf_counter()
        try:
            resp = client.chat.completions.create(**kwargs)
            gen = resp.choices[0].message.content
        except Exception as e:
            # API errors still count toward totals (and hence accuracy).
            dt = time.perf_counter() - t0
            total_counts[cat] += 1
            latency_sums[cat] += dt
            print(f"[{idx:04d}] {CAT_NAMES.get(cat, cat)} | ERROR: {e}")
            continue

        dt = time.perf_counter() - t0
        total_counts[cat] += 1
        latency_sums[cat] += dt

        # Prefer the first balanced JSON object; fall back to the raw text.
        json_str = extract_first_json(gen) or gen.strip()
        try:
            pred_obj = json.loads(json_str)
        except json.JSONDecodeError:
            # Unparseable output counts as a failure for this category.
            print(f"[{idx:04d}] {CAT_NAMES.get(cat, cat)} | PARSE_FAIL | {dt:.2f}s")
            continue

        pred_norm = normalize(pred_obj, cat)
        ok, diff = compare(pred_norm, gold_norm, cat)
        if ok:
            ok_counts[cat] += 1

        status = "OK" if ok else f"FAIL {list(diff.keys())}"
        print(f"[{idx:04d}] {CAT_NAMES.get(cat, cat)} | {status} | {dt:.2f}s")

    # Summary: per-category exact-match accuracy and mean latency.
    print("\n" + "=" * 50)
    print("EVALUATION SUMMARY")
    print("=" * 50)
    total_ok = total_all = 0
    for c in sorted(total_counts):
        ok = ok_counts[c]
        tot = total_counts[c]
        acc = ok / tot if tot else 0
        avg_lat = latency_sums[c] / tot if tot else 0
        total_ok += ok
        total_all += tot
        print(f" {CAT_NAMES.get(c, c):20s}: {ok:4d}/{tot:4d} acc={acc:.1%} avg={avg_lat:.3f}s")

    # Micro-averaged overall numbers (weighted by per-category counts).
    overall_acc = total_ok / total_all if total_all else 0
    overall_lat = sum(latency_sums.values()) / total_all if total_all else 0
    print(f" {'OVERALL':20s}: {total_ok:4d}/{total_all:4d} acc={overall_acc:.1%} avg={overall_lat:.3f}s")
|
||||
|
||||
|
||||
# Script entry point (no side effects on import).
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user