# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
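
"""Collect results from a batch of distributed benchmark runs.

Usage sketch (script and path names are illustrative):

    python collect_results.py --checkpoint_path /checkpoints/my_sweep

Assumed on-disk layout, inferred from the glob patterns used below:

    <checkpoint_path>/<attention_mechanism>/<run_dir>/test_eval_summary.json

Each summary JSON is expected to hold at least "test_accu_mean" and
"train_step_idx"; failed runs leave a *.err log in the same run directory.
"""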

import argparse
import json
import logging
from pathlib import Path
from typing import Any, Dict

if __name__ == "__main__":
    # Get the user requests
    parser = argparse.ArgumentParser(
        description="Collect results from a given batch of distributed results"
    )
    parser.add_argument("-ck", "--checkpoint_path", required=True)
    args = parser.parse_args()

    logging.getLogger().setLevel(logging.INFO)

    # Go through all the data in the given repo, try to find the end results
    root = Path(args.checkpoint_path)

    # - list all the mechanisms being benchmarked
    results: Dict[str, Any] = {}
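    # results maps attention mechanism -> task name -> test accuracy, e.g.
    # {"some_attention": {"some_task": 0.82}} (names and value illustrative)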

    for attention in filter(lambda x: x.is_dir(), root.iterdir()):
        logging.info(f"\nFound results for {attention.stem}")
        task_jsons = attention.glob("*/test_eval_summary.json")
        results[attention.stem] = {}

for task in task_jsons:
|
|
task_name = task.stem.split("__")[0]
|
|
logging.info(f"Logs found for task: {task_name}")
|
|
results[attention.stem][task_name] = -1
|
|
found_result = False

            # - collect the individual results
            with open(task, "r") as result_file:
                dct = json.load(result_file)
                if "test_accu_mean" in dct:
                    found_result = True
                    results[attention.stem][task_name] = dct["test_accu_mean"]

                    logging.info(
                        f"Final result found for {task_name} at epoch {dct['train_step_idx']}: "
                        f"{results[attention.stem][task_name]}"
                    )

            # - report an error if no result was found
            if not found_result:
                ERR_TAIL = 30

                logging.warning(
                    f"No result found for {task_name}, showing the error log in {task.parent}"
                )
                err_log = next(Path(task.parent).glob("*.err"), None)
                print("*****************************************************")
                if err_log is None:
                    print(f"No .err log found in {task.parent}")
                else:
                    with open(err_log, "r") as err_file:
                        # Only show the tail of the error log
                        for line in err_file.readlines()[-ERR_TAIL:]:
                            print(line, end="")
                print("*****************************************************")

    logging.info(f"\nCollected results: {json.dumps(results, indent=2)}")

    # - reduction: compute the average
    tasks = set(t for v in results.values() for t in v.keys())

    # -- fill in the possible gaps
    for att in results.keys():
        for t in tasks:
            if t not in results[att].keys():
                results[att][t] = 0.0
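
    # Note: tasks that ran but produced no result keep their -1 placeholder,
    # so they lower the average more than a task that is missing entirely (0.0)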

    # -- add the average value
    for att in results.keys():
        results[att]["AVG"] = round(sum(results[att][t] for t in tasks) / len(tasks), 2)

    # - Format as an array, markdown style
    tasks_sort = sorted(
        set(t for v in results.values() for t in v.keys()), reverse=True
    )
    print(
        "{0:<20}".format("") + "".join("{0:<20} ".format(t[:10]) for t in tasks_sort)
    )

    for att in results.keys():
        print(
            "{0:<20}".format(att)
            + "".join("{0:<20} ".format(results[att][t]) for t in tasks_sort)
        )
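
    # Example output (names and values illustrative):
    #                     some_task            AVG
    # some_attention      0.82                 0.82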