Files
qwen3-1.7b-gsm8k-sft/scripts/evaluate.py

139 lines
4.0 KiB
Python
Raw Permalink Normal View History

#!/usr/bin/env python3
from __future__ import annotations
import os
import argparse
import json
from inspect_ai.log._log import EvalLog, EvalMetric, EvalSample
from inspect_ai import eval as inspect_eval # type: ignore # noqa: E402
from inspect_ai.util._display import init_display_type # noqa: E402
import inspect_evals.gsm8k # noqa: F401, E402 (registers task definitions)
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the evaluation script.

    Options are read from ``sys.argv``.

    Returns:
        A namespace with the attributes ``model_path``, ``limit``,
        ``json_output_file``, ``templates_dir``, ``max_connections``,
        ``max_tokens`` and ``gpu_memory_utilization``.
    """
    parser = argparse.ArgumentParser(description="Run Inspect AI eval without banners.")
    parser.add_argument(
        "--model-path",
        type=str,
        default="final_model",
        help="Path to the Hugging Face model (directory or model identifier).",
    )
    # 150 samples is a good limit for this task: keep it as is, or lower it
    # when you want faster test runs.
    parser.add_argument(
        "--limit",
        type=int,
        default=150,
        help="Optional limit for number of samples to evaluate (-1 disables the limit).",
    )
    parser.add_argument(
        "--json-output-file",
        type=str,
        default=None,
        help="Optional path to output the metrics as a separate JSON file.",
    )
    parser.add_argument(
        "--templates-dir",
        type=str,
        default="templates/",
        help="Directory containing the chat template (.jinja) files.",
    )
    # Raise --max-connections for faster tests as long as no errors show up;
    # lower it if you run into problems with vLLM.
    parser.add_argument(
        "--max-connections",
        type=int,
        default=2,
        help="Maximum number of concurrent connections used for the evaluation.",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=4000,
        help="Maximum number of tokens to generate per completion.",
    )
    parser.add_argument(
        "--gpu-memory-utilization",
        type=float,
        default=0.3,
        help="Value passed to the vLLM model as gpu_memory_utilization.",
    )
    return parser.parse_args()
def main() -> None:
    """Run the GSM8K Inspect AI evaluation and optionally save its metrics.

    The ``inspect_evals/gsm8k`` task is evaluated against the model
    ``vllm/<model_path>`` using the options returned by ``parse_args()``.
    When ``--json-output-file`` is given, the metrics of the single score are
    written to that file as a JSON object mapping metric name to value.

    Raises:
        RuntimeError: If a JSON output file was requested but the evaluation
            did not return exactly one log that has results with exactly one
            score.
    """
    args = parse_args()

    init_display_type("plain")

    # A limit of None or -1 means "no limit": the keyword is then left out
    # of the eval call entirely.
    other_kwargs = {}
    if args.limit is not None and args.limit != -1:
        other_kwargs["limit"] = args.limit

    model_args = {
        "gpu_memory_utilization": args.gpu_memory_utilization,
        **template_kwargs(args),
    }

    eval_out = inspect_eval(
        "inspect_evals/gsm8k",
        model=f"vllm/{args.model_path}",
        model_args=model_args,
        score_display=False,
        log_realtime=False,
        log_format="json",
        # Very large timeouts, presumably so long generations are never cut
        # off (units as defined by Inspect AI -- TODO confirm).
        timeout=18000000,
        attempt_timeout=18000000,
        max_tokens=args.max_tokens,
        max_connections=args.max_connections,
        **other_kwargs,
    )

    if args.json_output_file is None:
        return

    # Validate explicitly instead of with `assert`, which is stripped when
    # Python runs with -O.
    if len(eval_out) != 1:
        raise RuntimeError(f"Expected exactly one eval log, got {len(eval_out)}.")
    eval_log = eval_out[0]
    results = eval_log.results
    if results is None:
        # No results to export; report the log status instead of failing
        # with an AttributeError on `.scores`.
        raise RuntimeError(
            f"Evaluation produced no results (status: {eval_log.status!r}); "
            "cannot write metrics."
        )
    if len(results.scores) != 1:
        raise RuntimeError(
            f"Expected exactly one score, got {len(results.scores)}: {results.scores}"
        )

    metrics = {name: metric.value for name, metric in results.scores[0].metrics.items()}
    with open(args.json_output_file, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)
def model_type(args: argparse.Namespace) -> str:
    """Infer the model family from the model path or its ``config.json``.

    The lower-cased ``args.model_path`` is searched for a family name first.
    If none is found, ``<model_path>/config.json`` is read and the first entry
    of its ``architectures`` list is searched instead.

    Args:
        args: Parsed options; only ``args.model_path`` is read.

    Returns:
        One of ``"qwen"``, ``"llama"``, ``"gemma"`` or ``"smollm"``.

    Raises:
        ValueError: If the architecture matches none of the known families.
        OSError: If the path gives no match and ``config.json`` cannot be
            opened (e.g. ``FileNotFoundError``).
        KeyError, IndexError: If ``config.json`` has no usable
            ``architectures`` entry.
    """
    path_lower = args.model_path.lower()
    # First match wins, so the order matters for paths naming several families.
    for family in ("qwen", "llama", "gemma", "smollm"):
        if family in path_lower:
            return family

    # The path did not reveal the family: fall back to the model's config.
    config_path = os.path.join(args.model_path, "config.json")
    # Explicit encoding so decoding does not depend on the locale default.
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    architecture = config["architectures"][0].lower()
    for family in ("gemma", "llama", "qwen", "smollm"):
        if family in architecture:
            return family

    raise ValueError(
        f"Cannot determine model type: unsupported architecture "
        f"{architecture!r} in {config_path}"
    )
def template_kwargs(args) -> dict:
    """Return the model argument that selects the chat template.

    The model family reported by ``model_type(args)`` is mapped to a template
    file name, which is joined onto ``args.templates_dir``.

    Returns:
        A dict with the single key ``'chat_template'`` holding that path.

    Raises:
        ValueError: If the model family has no associated template.
    """
    family = model_type(args)
    template_by_family = {
        'qwen': 'qwen3.jinja',
        'llama': 'llama3.jinja',
        'gemma': 'gemma3.jinja',
        'smollm': 'smollm.jinja',
    }
    if family not in template_by_family:
        raise ValueError(family)
    template_path = os.path.join(args.templates_dir, template_by_family[family])
    return {'chat_template': template_path}
# Run the evaluation only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()