Files
enginex-bi_series-vllm/pkgs/xformers/benchmarks/LRA/batch_fetch_results.py
2025-08-05 19:02:46 +08:00

97 lines
3.4 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import json
import logging
from pathlib import Path
from typing import Any, Dict
# Number of trailing log lines shown when a run produced no result.
ERR_TAIL = 30


def _report_error_tail(task: Path, task_name: str, tail: int = ERR_TAIL) -> None:
    """Warn that *task* has no result and dump the tail of its ``*.err`` log.

    Prints the last *tail* lines of the first ``*.err`` file found next to the
    summary json, in chronological order, framed by separator lines.
    """
    logging.warning(
        f"No result found for {task_name}, showing the error log in {task.parent}"
    )
    # Fix: next() on an empty glob used to raise StopIteration; default to None.
    err_log = next(Path(task.parent).glob("*.err"), None)
    print("*****************************************************")
    if err_log is not None:
        with open(err_log, "r") as err_file:
            # Show the last `tail` lines in their natural order (the original
            # printed them reversed, bottom-up, which was hard to read).
            for line in err_file.readlines()[-tail:]:
                print(line, end="")
    print("*****************************************************")


def _collect_results(root: Path) -> Dict[str, Any]:
    """Walk *root* and collect per-attention / per-task test accuracies.

    Each first-level directory is one attention mechanism; inside it, every
    ``*/test_eval_summary.json`` is one finished (or failed) task run.
    A task whose summary lacks ``test_accu_mean`` keeps the sentinel ``-1``
    and its error log tail is printed.
    """
    results: Dict[str, Any] = {}
    for attention in filter(lambda x: x.is_dir(), root.iterdir()):
        logging.info(f"\nFound results for {attention.stem}")
        results[attention.stem] = {}
        for task in attention.glob("*/test_eval_summary.json"):
            # NOTE(review): task.stem is the json filename stem, which is the
            # same for every task dir — entries may overwrite each other.
            # Possibly task.parent.stem was intended; preserved as-is.
            task_name = task.stem.split("__")[0]
            logging.info(f"Logs found for task: {task_name}")
            results[attention.stem][task_name] = -1
            # - collect the individual result
            with open(task, "r") as result_file:
                dct = json.load(result_file)
            if "test_accu_mean" in dct:
                results[attention.stem][task_name] = dct["test_accu_mean"]
                logging.info(
                    f"Final result found for {task_name} at epoch {dct['train_step_idx']}: "
                    f"{results[attention.stem][task_name]}"
                )
            else:
                # Bug fix: the original `break` here aborted the whole task
                # loop and skipped the error report; report and keep going.
                _report_error_tail(task, task_name)
    return results


def _add_averages(results: Dict[str, Any]) -> None:
    """Fill missing tasks with 0.0 and add a rounded ``AVG`` entry in place."""
    tasks = set(t for v in results.values() for t in v.keys())
    # -- fill in the possible gaps so every row covers every task
    for att in results.keys():
        for t in tasks:
            if t not in results[att]:
                results[att][t] = 0.0
    # -- add the average value (guard: no tasks -> avoid ZeroDivisionError)
    if tasks:
        for att in results.keys():
            results[att]["AVG"] = round(
                sum(results[att][t] for t in tasks) / len(tasks), 2
            )


def _print_table(results: Dict[str, Any]) -> None:
    """Print the results as a fixed-width, markdown-style table."""
    tasks_sort = sorted(
        set(t for v in results.values() for t in v.keys()), reverse=True
    )
    print(
        "{0:<20}".format("") + "".join("{0:<20} ".format(t[:10]) for t in tasks_sort)
    )
    for att in results.keys():
        print(
            "{0:<20}".format(att)
            + "".join("{0:<20} ".format(results[att][t]) for t in tasks_sort)
        )


def main(argv=None) -> Dict[str, Any]:
    """Collect, aggregate and display LRA results under a checkpoint root.

    Args:
        argv: optional argument list (defaults to ``sys.argv[1:]``), which
            keeps the CLI behavior unchanged while making the script testable.

    Returns:
        The collected results mapping ``attention -> task -> accuracy``.
    """
    # Get the user requests
    parser = argparse.ArgumentParser(
        "Collect results from a given batch of distributed results"
    )
    parser.add_argument("-ck", "--checkpoint_path", required=True)
    args = parser.parse_args(argv)
    logging.getLogger().setLevel(logging.INFO)
    # Go through all the data in the given repo, try to find the end results
    root = Path(args.checkpoint_path)
    results = _collect_results(root)
    logging.info(f"\nCollected results: {json.dumps(results, indent=2)}")
    # - reduction: compute the average, then format as an array, markdown style
    _add_averages(results)
    _print_table(results)
    return results


if __name__ == "__main__":
    main()