139 lines
4.0 KiB
Python
139 lines
4.0 KiB
Python
#!/usr/bin/env python3
|
|
from __future__ import annotations
|
|
import os
|
|
|
|
import argparse
|
|
import json
|
|
|
|
from inspect_ai.log._log import EvalLog, EvalMetric, EvalSample
|
|
from inspect_ai import eval as inspect_eval # type: ignore # noqa: E402
|
|
from inspect_ai.util._display import init_display_type # noqa: E402
|
|
|
|
import inspect_evals.gsm8k # noqa: F401, E402 (registers task definitions)
|
|
|
|
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Argument strings to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``, which is what the script entry point relies
            on; passing an explicit list allows programmatic use.

    Returns:
        Namespace with the attributes ``model_path``, ``limit``,
        ``json_output_file``, ``templates_dir``, ``max_connections``,
        ``max_tokens`` and ``gpu_memory_utilization``.
    """
    parser = argparse.ArgumentParser(description="Run Inspect AI eval without banners.")
    parser.add_argument(
        "--model-path",
        type=str,
        default="final_model",
        help="Path to the Hugging Face model (directory or model identifier).",
    )
    # This is a good limit for this task; keep it as is (or use less if you
    # want faster tests).
    parser.add_argument(
        "--limit",
        type=int,
        default=150,
        help="Optional limit for number of samples to evaluate (-1 disables the limit).",
    )
    parser.add_argument(
        "--json-output-file",
        type=str,
        default=None,
        help="Optional path to output the metrics as a separate JSON file.",
    )
    parser.add_argument(
        "--templates-dir",
        type=str,
        default="templates/",
        help="Directory containing the chat template (.jinja) files.",
    )
    # You can raise --max-connections for faster tests if you do not receive
    # errors (or lower it if you have issues with vllm).
    parser.add_argument(
        "--max-connections",
        type=int,
        default=2,
        help="Value passed to Inspect AI as max_connections.",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=4000,
        help="Value passed to Inspect AI as max_tokens.",
    )
    parser.add_argument(
        "--gpu-memory-utilization",
        type=float,
        default=0.3,
        help="Value passed to the vLLM model as gpu_memory_utilization.",
    )
    return parser.parse_args(argv)
|
|
|
|
def main() -> None:
    """Run the GSM8K Inspect AI evaluation and optionally dump its metrics.

    Parses the command line, evaluates the ``inspect_evals/gsm8k`` task with
    the ``vllm/<model-path>`` model, and, when ``--json-output-file`` is
    given, writes a ``{metric name: metric value}`` mapping to that file.

    Raises:
        RuntimeError: If a JSON output file was requested but the evaluation
            did not yield exactly one log carrying exactly one score entry.
    """
    args = parse_args()

    init_display_type("plain")

    eval_kwargs = {}
    # A limit of None or -1 means "no limit": leave the argument out entirely.
    if args.limit is not None and args.limit != -1:
        eval_kwargs["limit"] = args.limit

    model_args = {
        "gpu_memory_utilization": args.gpu_memory_utilization,
        **template_kwargs(args),
    }

    eval_logs = inspect_eval(
        "inspect_evals/gsm8k",
        model=f"vllm/{args.model_path}",
        model_args=model_args,
        score_display=False,
        log_realtime=False,
        log_format="json",
        # Very large values so slow generations are effectively never cut off.
        # NOTE(review): units are presumably seconds; confirm with Inspect AI docs.
        timeout=18000000,
        attempt_timeout=18000000,
        max_tokens=args.max_tokens,
        max_connections=args.max_connections,
        **eval_kwargs,
    )

    if args.json_output_file is None:
        return

    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if len(eval_logs) != 1:
        raise RuntimeError(f"Expected exactly one eval log, got {len(eval_logs)}: {eval_logs!r}")
    log = eval_logs[0]
    # `results` is absent when the run did not complete successfully.
    if log.results is None:
        raise RuntimeError(f"Evaluation produced no results (status: {log.status!r}).")
    scores = log.results.scores
    if len(scores) != 1:
        raise RuntimeError(f"Expected exactly one score entry, got {len(scores)}: {scores!r}")

    metrics = {name: metric.value for name, metric in scores[0].metrics.items()}

    with open(args.json_output_file, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)
|
|
|
|
def model_type(args) -> str:
    """Classify the model as 'qwen', 'llama', 'gemma' or 'smollm'.

    The lower-cased ``args.model_path`` is searched for a family name first.
    Only when none is found is ``<model_path>/config.json`` read and its first
    ``architectures`` entry searched instead.

    Args:
        args: Object exposing a ``model_path`` string attribute.

    Returns:
        The name of the first matching family.

    Raises:
        ValueError: If the lower-cased architecture name contains none of the
            known family names (the architecture is the exception argument).

    Errors raised while opening or parsing ``config.json`` propagate unchanged.
    """
    lowered_path = args.model_path.lower()
    # Path-based detection; the tuple order is the match priority.
    for family in ("qwen", "llama", "gemma", "smollm"):
        if family in lowered_path:
            return family

    config_file = os.path.join(args.model_path, "config.json")
    with open(config_file, "r") as handle:
        architecture = json.load(handle)["architectures"][0].lower()

    # Architecture-based detection uses its own, different priority order.
    detected = next(
        (family for family in ("gemma", "llama", "qwen", "smollm") if family in architecture),
        None,
    )
    if detected is None:
        raise ValueError(architecture)
    return detected
|
|
|
|
def template_kwargs(args) -> dict:
    """Build the model argument that selects the chat template file.

    Args:
        args: Object exposing ``templates_dir`` plus whatever ``model_type``
            reads from it.

    Returns:
        A single-entry dict mapping ``'chat_template'`` to the template file
        path inside ``args.templates_dir``.

    Raises:
        ValueError: If ``model_type`` returns a family with no known template.
    """
    # Template file name for each model family.
    template_by_family = {
        'qwen': 'qwen3.jinja',
        'llama': 'llama3.jinja',
        'gemma': 'gemma3.jinja',
        'smollm': 'smollm.jinja',
    }
    family = model_type(args)
    if family not in template_by_family:
        raise ValueError(family)
    template_path = os.path.join(args.templates_dir, template_by_family[family])
    return {'chat_template': template_path}
|
|
|
|
# Run the evaluation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
|