First commit
This commit is contained in:
96
pkgs/xformers/benchmarks/LRA/batch_fetch_results.py
Normal file
96
pkgs/xformers/benchmarks/LRA/batch_fetch_results.py
Normal file
@@ -0,0 +1,96 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Get the user requests
|
||||
parser = argparse.ArgumentParser(
|
||||
"Collect results from a given batch of distributed results"
|
||||
)
|
||||
parser.add_argument("-ck", "--checkpoint_path", required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
|
||||
# Go through all the data in the given repo, try to find the end results
|
||||
root = Path(args.checkpoint_path)
|
||||
|
||||
# - list all the mechanisms being benchmarked
|
||||
results: Dict[str, Any] = {}
|
||||
|
||||
for attention in filter(lambda x: x.is_dir(), root.iterdir()):
|
||||
logging.info(f"\nFound results for {attention.stem}")
|
||||
task_jsons = attention.glob("*/test_eval_summary.json")
|
||||
results[attention.stem] = {}
|
||||
|
||||
for task in task_jsons:
|
||||
task_name = task.stem.split("__")[0]
|
||||
logging.info(f"Logs found for task: {task_name}")
|
||||
results[attention.stem][task_name] = -1
|
||||
found_result = False
|
||||
|
||||
# - collect the individual results
|
||||
with open(task, "r") as result_file:
|
||||
dct = json.load(result_file)
|
||||
if "test_accu_mean" in dct:
|
||||
found_result = True
|
||||
results[attention.stem][task_name] = dct["test_accu_mean"]
|
||||
|
||||
logging.info(
|
||||
f"Final result found for {task_name} at epoch {dct['train_step_idx']}: "
|
||||
f"{results[attention.stem][task_name]}"
|
||||
)
|
||||
else:
|
||||
break
|
||||
|
||||
# - report an error if no result was found
|
||||
if not found_result:
|
||||
ERR_TAIL = 30
|
||||
|
||||
logging.warning(
|
||||
f"No result found for {task_name}, showing the error log in {task.parent}"
|
||||
)
|
||||
err_log = Path(task.parent).glob("*.err")
|
||||
print("*****************************************************")
|
||||
with open(next(err_log), "r") as err_file:
|
||||
for i, line in enumerate(reversed(err_file.readlines())):
|
||||
print(line, end="")
|
||||
if i > ERR_TAIL:
|
||||
break
|
||||
print("*****************************************************")
|
||||
|
||||
logging.info(f"\nCollected results: {json.dumps(results, indent=2)}")
|
||||
|
||||
# - reduction: compute the average
|
||||
tasks = set(t for v in results.values() for t in v.keys())
|
||||
# -- fill in the possible gaps
|
||||
for att in results.keys():
|
||||
for t in tasks:
|
||||
if t not in results[att].keys():
|
||||
results[att][t] = 0.0
|
||||
|
||||
# -- add the average value
|
||||
for att in results.keys():
|
||||
results[att]["AVG"] = round(sum(results[att][t] for t in tasks) / len(tasks), 2)
|
||||
|
||||
# - Format as an array, markdown style
|
||||
tasks_sort = sorted(
|
||||
set(t for v in results.values() for t in v.keys()), reverse=True
|
||||
)
|
||||
print(
|
||||
"{0:<20}".format("") + "".join("{0:<20} ".format(t[:10]) for t in tasks_sort)
|
||||
)
|
||||
|
||||
for att in results.keys():
|
||||
print(
|
||||
"{0:<20}".format(att)
|
||||
+ "".join("{0:<20} ".format(results[att][t]) for t in tasks_sort)
|
||||
)
|
||||
Reference in New Issue
Block a user