init
This commit is contained in:
295
transformers/utils/add_dates.py
Normal file
295
transformers/utils/add_dates.py
Normal file
@@ -0,0 +1,295 @@
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
from datetime import date
|
||||
from typing import Optional
|
||||
|
||||
from huggingface_hub import paper_info
|
||||
|
||||
|
||||
# Repository root, derived from the current working directory.
# NOTE(review): splitting on "utils" assumes the script is invoked from inside the
# repo's `utils/` directory and that no parent directory name contains "utils" —
# confirm the expected invocation path.
ROOT = os.getcwd().split("utils")[0]
# Directory holding the English model-doc markdown cards.
DOCS_PATH = os.path.join(ROOT, "docs/source/en/model_doc")
# Directory holding the per-model source packages.
MODELS_PATH = os.path.join(ROOT, "src/transformers/models")

# License/doc-builder preamble prepended to model cards that lack one; the trailing
# "-->" marker is what `insert_dates` anchors the dates line to.
COPYRIGHT_DISCLAIMER = """<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->"""

# Papers that exist on arXiv but are not mirrored on hf.co/papers, keyed by the
# model-doc file name (so their release dates cannot be fetched via `paper_info`).
ARXIV_PAPERS_NOT_IN_HF_PAPERS = {
    "gemma3n.md": "2506.06644",
    "xmod.md": "2205.06266",
}
|
||||
|
||||
|
||||
def get_modified_cards() -> list[str]:
    """Get the list of model names from modified files in docs/source/en/model_doc/"""

    diff_output = subprocess.check_output(["git", "diff", "--name-only", "upstream/main"], text=True)

    model_names = []
    for changed_file in diff_output.strip().split("\n"):
        if not changed_file:
            continue
        # Only markdown cards under the model_doc directory are relevant.
        if not (changed_file.startswith("docs/source/en/model_doc/") and changed_file.endswith(".md")):
            continue
        name = os.path.splitext(os.path.basename(changed_file))[0]
        # These two docs are not real model cards.
        if name not in ("auto", "timm_wrapper"):
            model_names.append(name)

    return model_names
|
||||
|
||||
|
||||
def get_paper_link(model_card: Optional[str], path: Optional[str]) -> str:
    """Get the first paper link from the model card content.

    Reads the card at `path` (or resolves `model_card` under DOCS_PATH) and returns the
    first Hugging Face / arXiv paper URL found, or the sentinel "No_paper" when none is
    present. Unrecognized candidate links are printed for manual review.
    """

    if model_card is not None and not model_card.endswith(".md"):
        model_card = f"{model_card}.md"
    file_path = path or os.path.join(DOCS_PATH, f"{model_card}")
    model_card = os.path.basename(file_path)
    with open(file_path, "r", encoding="utf-8") as f:
        content = f.read()

    # Known paper hosts, in order of preference.
    known_patterns = (
        r"https://huggingface\.co/papers/\d+\.\d+",
        r"https://arxiv\.org/abs/\d+\.\d+",
        r"https://arxiv\.org/pdf/\d+\.\d+",
    )
    paper_ids = []
    for known_pattern in known_patterns:
        paper_ids.extend(re.findall(known_pattern, content))

    if paper_ids:
        return paper_ids[0]

    # No known paper host found: collect every other https link as a candidate.
    other_paper_links = []
    for raw_link in re.findall(r"https://[^\s\)]+", content):
        cleaned = raw_link.rstrip(".,;!?)")
        if "huggingface.co" not in cleaned and "github.com" not in cleaned:
            other_paper_links.append(cleaned)

    # Deduplicate while preserving first-seen order.
    other_paper_links = list(dict.fromkeys(other_paper_links))

    if other_paper_links:
        print(f"No Hugging Face or Arxiv papers found. The possible paper links found in {model_card}:")
        for link in other_paper_links:
            print(f" - {link}")

    return "No_paper"
|
||||
|
||||
|
||||
def get_first_commit_date(model_name: str) -> str:
    """Get the first commit date of the model's init file or model.md. This date is
    considered as the date the model was added to HF transformers.

    Args:
        model_name (`str`): The model name, with or without a trailing `.md` suffix.

    Returns:
        `str`: The date in `YYYY-MM-DD` format.
    """

    # Normalize: accept both "bert" and "bert.md".
    if model_name.endswith(".md"):
        model_name = model_name[: -len(".md")]

    # Doc files may use dashes while source directories use underscores.
    model_name_src = model_name.replace("-", "_")
    file_path = os.path.join(MODELS_PATH, model_name_src, "__init__.py")

    # If the init file is not found (only true for legacy models), the doc's first commit date is used
    if not os.path.exists(file_path):
        file_path = os.path.join(DOCS_PATH, f"{model_name}.md")

    # Check if file exists in upstream/main
    result_main = subprocess.check_output(
        ["git", "ls-tree", "upstream/main", "--", file_path], text=True, stderr=subprocess.DEVNULL
    )
    if not result_main:
        # File does not exist in upstream/main (new model), use today's date
        final_date = date.today().isoformat()
    else:
        # File exists in upstream/main: `--reverse` lists commits oldest-first, so the
        # first output line is the commit that added the file.
        final_date = subprocess.check_output(
            ["git", "log", "--reverse", "--pretty=format:%ad", "--date=iso", file_path], text=True
        )
    # Keep only the YYYY-MM-DD prefix of the first line.
    return final_date.strip().split("\n")[0][:10]
|
||||
|
||||
|
||||
def get_release_date(link: str) -> str:
    """Return the ISO release date (YYYY-MM-DD) for a paper link.

    Falls back to the literal placeholder "{release_date}" when the paper is only on
    arXiv, the lookup fails, or the link is not a recognized paper URL.

    Bug fix: the original fell off the end of the function (returning `None`) when
    `paper_info` raised or when the link matched neither host; a trailing return now
    guarantees a `str` result.
    """
    if link.startswith("https://huggingface.co/papers/"):
        link = link.replace("https://huggingface.co/papers/", "")
        try:
            info = paper_info(link)
            return info.published_at.date().isoformat()
        except Exception as e:
            print(f"Error fetching release date for the paper https://huggingface.co/papers/{link}: {e}")
    elif link.startswith("https://arxiv.org/abs/") or link.startswith("https://arxiv.org/pdf/"):
        print(f"This paper {link} is not yet available in Hugging Face papers, skipping the release date attachment.")

    # Lookup failed or unrecognized link: keep the placeholder so a later run can fill it in.
    return r"{release_date}"
|
||||
|
||||
|
||||
def replace_paper_links(file_path: str) -> bool:
    """Replace arxiv links with huggingface links if valid, and replace hf.co with huggingface.co

    Returns True when the file was rewritten, False when nothing changed.
    """

    with open(file_path, "r", encoding="utf-8") as f:
        content = f.read()

    model_card = os.path.basename(file_path)
    original_content = content

    # Normalize the short hf.co host to the canonical one.
    content = content.replace("https://hf.co/", "https://huggingface.co/")

    # Collect every arXiv paper id referenced (abs and pdf forms).
    arxiv_ids = re.findall(r"https://arxiv\.org/abs/(\d+\.\d+)", content)
    arxiv_ids.extend(re.findall(r"https://arxiv\.org/pdf/(\d+\.\d+)", content))

    for paper_id in arxiv_ids:
        try:
            # Raises when the paper is not mirrored on huggingface.
            paper_info(paper_id)
        except Exception:
            print(f"Paper {paper_id} for {model_card} is not available on huggingface, keeping the arxiv link")
            continue

        old_link = f"https://arxiv.org/abs/{paper_id}"
        if old_link not in content:
            old_link = f"https://arxiv.org/pdf/{paper_id}"
        new_link = f"https://huggingface.co/papers/{paper_id}"
        content = content.replace(old_link, new_link)
        print(f"Replaced {old_link} with {new_link}")

    if content == original_content:
        return False

    with open(file_path, "w", encoding="utf-8") as f:
        f.write(content)
    return True
|
||||
|
||||
|
||||
def insert_dates(model_card_list: list[str]):
    """Insert release and commit dates into model cards.

    For each card: normalize paper links, ensure a copyright preamble exists, then add
    or refresh the line `*This model was released on ... and added to Hugging Face
    Transformers on ...*` right below the preamble's `-->` marker. Files are rewritten
    in place.
    """

    for model_card in model_card_list:
        # Accept names with or without the .md suffix.
        if not model_card.endswith(".md"):
            model_card = f"{model_card}.md"

        # These two docs are not real model cards.
        if model_card == "auto.md" or model_card == "timm_wrapper.md":
            continue

        file_path = os.path.join(DOCS_PATH, model_card)

        # First replace arxiv paper links with hf paper link if possible
        links_replaced = replace_paper_links(file_path)
        if links_replaced:
            print(f"Updated paper links in {model_card}")

        # Group 1 captures the release date (free-form: may be the "{release_date}"
        # placeholder or "None"); group 2 captures the strict YYYY-MM-DD commit date.
        pattern = (
            r"\n\*This model was released on (.*) and added to Hugging Face Transformers on (\d{4}-\d{2}-\d{2})\.\*"
        )

        # Check if the copyright disclaimer sections exists, if not, add one with 2025
        with open(file_path, "r", encoding="utf-8") as f:
            content = f.read()
        markers = list(re.finditer(r"-->", content))  # Dates info is placed right below this marker
        if len(markers) == 0:
            print(f"No marker found in {model_card}. Adding copyright disclaimer to the top.")

            # Add copyright disclaimer to the very top of the file
            content = COPYRIGHT_DISCLAIMER + "\n\n" + content
            with open(file_path, "w", encoding="utf-8") as f:
                f.write(content)
            # Re-scan: the disclaimer just added contains the "-->" marker.
            markers = list(re.finditer(r"-->", content))

        hf_commit_date = get_first_commit_date(model_name=model_card)

        paper_link = get_paper_link(model_card=model_card, path=file_path)
        release_date = ""
        if not (paper_link == "No_paper" or paper_link == "blog"):
            release_date = get_release_date(paper_link)
        else:
            # No paper to date the release from: leave a placeholder for a later run.
            release_date = r"{release_date}"

        match = re.search(pattern, content)

        # If the dates info line already exists, preserve the existing release date unless it's a placeholder, and update the HF commit date if needed
        if match:
            existing_release_date = match.group(1)  # The release date part
            existing_hf_date = match.group(2)  # The existing HF date part
            # Keep a previously recorded real date; only overwrite placeholders.
            release_date = (
                release_date
                if (existing_release_date == r"{release_date}" or existing_release_date == "None")
                else existing_release_date
            )
            if existing_hf_date != hf_commit_date or existing_release_date != release_date:
                old_line = match.group(0)  # Full matched line
                new_line = f"\n*This model was released on {release_date} and added to Hugging Face Transformers on {hf_commit_date}.*"

                content = content.replace(old_line, new_line)
                with open(file_path, "w", encoding="utf-8") as f:
                    f.write(content)

        # If the dates info line does not exist, add it
        else:
            # Insert immediately after the first "-->" marker (end of the preamble).
            insert_index = markers[0].end()

            date_info = f"\n*This model was released on {release_date} and added to Hugging Face Transformers on {hf_commit_date}.*"
            content = content[:insert_index] + date_info + content[insert_index:]
            with open(file_path, "w", encoding="utf-8") as f:
                f.write(content)
            print(f"Added {model_card} release and commit dates.")
|
||||
|
||||
|
||||
def get_all_model_cards():
    """Get all model cards from the docs path"""

    excluded = ("auto", "timm_wrapper")
    model_cards = []
    for entry in os.listdir(DOCS_PATH):
        if not entry.endswith(".md"):
            continue
        stem = os.path.splitext(entry)[0]
        # Skip docs that are not real model cards.
        if stem not in excluded:
            model_cards.append(stem)
    return sorted(model_cards)
|
||||
|
||||
|
||||
def main(all=False, auto=True, models=None):
    """Dispatch to the requested processing mode and insert dates into the cards.

    Args:
        all (`bool`): Process every model card in the docs directory.
        auto (`bool`): Process cards modified relative to upstream/main.
        models (`list[str]` or `None`): Explicit card names to process.

    Bug fix: the CLI sets `auto=True` via `parser.set_defaults`, so `auto` is True even
    when `--models` is passed; the original `elif auto:` branch therefore shadowed the
    `--models` branch and the explicit list was ignored. Explicit `models` now takes
    precedence over `auto`.
    """
    if all:
        model_cards = get_all_model_cards()
        print(f"Processing all {len(model_cards)} model cards from docs directory")
    elif models:
        model_cards = models
        print(f"Processing specified model cards: {model_cards}")
    else:
        model_cards = get_modified_cards()
        if not model_cards:
            print("No modified model cards found.")
            return
        print(f"Processing modified model cards: {model_cards}")

    insert_dates(model_cards)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Add release and commit dates to model cards")
    # The three modes are mutually exclusive on the command line.
    group = parser.add_mutually_exclusive_group(required=False)
    group.add_argument(
        "--auto", action="store_true", help="Automatically process modified model cards from git status"
    )
    group.add_argument("--models", nargs="+", help="Specify model cards to process (without .md extension)")
    group.add_argument("--all", action="store_true", help="Process all model cards in the docs directory")

    # NOTE(review): this forces `auto=True` even when `--models` or `--all` is passed
    # (mutual exclusion only blocks combining the *flags*, not the default), so `main`
    # must give `all`/`models` precedence over `auto`.
    parser.set_defaults(auto=True)
    args = parser.parse_args()

    main(args.all, args.auto, args.models)
|
||||
308
transformers/utils/add_pipeline_model_mapping_to_test.py
Normal file
308
transformers/utils/add_pipeline_model_mapping_to_test.py
Normal file
@@ -0,0 +1,308 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""A script to add and/or update the attribute `pipeline_model_mapping` in model test files.
|
||||
|
||||
This script will be (mostly) used in the following 2 situations:
|
||||
|
||||
- run within a (scheduled) CI job to:
|
||||
- check if model test files in the library have updated `pipeline_model_mapping`,
|
||||
- and/or update test files and (possibly) open a GitHub pull request automatically
|
||||
- being run by a `transformers` member to quickly check and update some particular test file(s)
|
||||
|
||||
This script is **NOT** intended to be run (manually) by community contributors.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import inspect
|
||||
import os
|
||||
import re
|
||||
import unittest
|
||||
|
||||
from get_test_info import get_test_classes
|
||||
|
||||
from tests.test_pipeline_mixin import pipeline_test_mapping
|
||||
|
||||
|
||||
# Lazy cache for `get_mapping_for_task`: task name -> model mapping (None = not yet computed).
PIPELINE_TEST_MAPPING = {task: None for task in pipeline_test_mapping}


# DO **NOT** add item to this set (unless the reason is approved)
TEST_FILE_TO_IGNORE = {
    "tests/models/esm/test_modeling_esmfold.py",  # The pipeline test mapping is added to `test_modeling_esm.py`
}
|
||||
|
||||
|
||||
def get_mapping_for_task(task):
    """Get mappings defined in `XXXPipelineTests` for the task `task`."""
    # Serve from the cache when already computed.
    cached = PIPELINE_TEST_MAPPING[task]
    if cached is not None:
        return cached

    pipeline_test_class = pipeline_test_mapping[task]["test"]
    raw_mapping = getattr(pipeline_test_class, "model_mapping", None)
    # Materialize the (possibly lazy) mapping into a plain dict.
    mapping = dict(raw_mapping.items()) if raw_mapping is not None else None

    # Cache the result (including None) for subsequent calls.
    PIPELINE_TEST_MAPPING[task] = mapping
    return mapping
|
||||
|
||||
|
||||
def get_model_for_pipeline_test(test_class, task):
    """Get the model architecture(s) related to the test class `test_class` for a pipeline `task`."""
    mapping = get_mapping_for_task(task)
    if mapping is None:
        return None

    # All model classes in a test class must share one configuration class.
    config_classes = list({model_cls.config_class for model_cls in test_class.all_model_classes})
    if len(config_classes) != 1:
        raise ValueError("There should be exactly one configuration class from `test_class.all_model_classes`.")

    model_class = mapping.get(config_classes[0], None)
    # Rarely a list/tuple of model classes; order them by name for determinism.
    if isinstance(model_class, (tuple, list)):
        model_class = sorted(model_class, key=lambda cls: cls.__name__)

    return model_class
|
||||
|
||||
|
||||
def get_pipeline_model_mapping(test_class):
    """Get `pipeline_model_mapping` for `test_class`."""
    pairs = []
    for task in pipeline_test_mapping:
        model = get_model_for_pipeline_test(test_class, task)
        if model is not None:
            pairs.append((task, model))
    # Deterministic ordering by task name.
    pairs.sort(key=lambda item: item[0])
    return dict(pairs)
|
||||
|
||||
|
||||
def get_pipeline_model_mapping_string(test_class):
    """Get `pipeline_model_mapping` for `test_class` as a string (to be added to the test file).

    This will be a 1-line string. After this is added to a test file, `make style` will format it beautifully.
    """
    mapping = get_pipeline_model_mapping(test_class)
    if not mapping:
        return ""

    entries = []
    for task, model_classes in mapping.items():
        if isinstance(model_classes, (tuple, list)):
            # A list/tuple of model classes -> render as a tuple literal.
            value = "(" + ", ".join(cls.__name__ for cls in model_classes) + ")"
        else:
            # A single model class.
            value = model_classes.__name__
        entries.append(f'"{task}": {value}')

    body = "{" + ", ".join(entries) + "}"
    return f"pipeline_model_mapping = {body} if is_torch_available() else {{}}"
|
||||
|
||||
|
||||
def is_valid_test_class(test_class):
    """Restrict to `XXXModelTesterMixin` and should be a subclass of `unittest.TestCase`."""
    if not issubclass(test_class, unittest.TestCase):
        return False
    # Only direct bases are inspected, and the name must match exactly.
    base_names = {base.__name__ for base in test_class.__bases__}
    return "ModelTesterMixin" in base_names
|
||||
|
||||
|
||||
def find_test_class(test_file):
    """Find a test class in `test_file` to which we will add `pipeline_model_mapping`."""
    candidates = [cls for cls in get_test_classes(test_file) if is_valid_test_class(cls)]

    # Prefer a class that already defines `pipeline_model_mapping`.
    for candidate in candidates:
        if getattr(candidate, "pipeline_model_mapping", None) is not None:
            return candidate

    # Otherwise take the class with the shortest name (just a heuristic).
    if candidates:
        return min(candidates, key=lambda cls: (len(cls.__name__), cls.__name__))
    return None
|
||||
|
||||
|
||||
def find_block_ending(lines, start_idx, indent_level):
    """Return the index (inclusive) of the last line of the block starting at `start_idx`.

    A line belongs to the block if it is indented deeper than `indent_level`, or is a
    lone closing ")" at exactly `indent_level`. The first line always belongs.
    """
    end_idx = start_idx
    for offset, line in enumerate(lines[start_idx:]):
        indent = len(line) - len(line.lstrip())
        continues_block = indent > indent_level or (indent == indent_level and line.strip() == ")")
        if offset == 0 or continues_block:
            end_idx = start_idx + offset
        else:
            # Outside the definition block of `pipeline_model_mapping`
            break

    return end_idx
|
||||
|
||||
|
||||
def add_pipeline_model_mapping(test_class, overwrite=False):
    """Add `pipeline_model_mapping` to `test_class`.

    Rewrites the test class's source file in place: inserts (or replaces, when
    `overwrite=True`) the `pipeline_model_mapping = ...` attribute after the
    `all_model_classes` / `all_generative_model_classes` block, and adds
    `PipelineTesterMixin` to the class bases when missing.

    NOTE(review): the early exits return the tuple `("", -1)` while the final path
    returns just the string `line_to_add` — inconsistent return shapes; the only
    visible caller ignores the return value, but this should be unified.
    """
    if getattr(test_class, "pipeline_model_mapping", None) is not None:
        if not overwrite:
            return "", -1

    line_to_add = get_pipeline_model_mapping_string(test_class)
    if len(line_to_add) == 0:
        return "", -1
    line_to_add = line_to_add + "\n"

    # The code defined the class `test_class`
    class_lines, class_start_line_no = inspect.getsourcelines(test_class)
    # `inspect` gives the code for an object, including decorator(s) if any.
    # We (only) need the exact line of the class definition.
    for idx, line in enumerate(class_lines):
        if line.lstrip().startswith("class "):
            class_lines = class_lines[idx:]
            class_start_line_no += idx
            break
    class_end_line_no = class_start_line_no + len(class_lines) - 1

    # The index in `class_lines` that starts the definition of `all_model_classes`, `all_generative_model_classes` or
    # `pipeline_model_mapping`. This assumes they are defined in such order, and we take the start index of the last
    # block that appears in a `test_class`.
    start_idx = None
    # The indent level of the line at `class_lines[start_idx]` (if defined)
    indent_level = 0
    # To record if `pipeline_model_mapping` is found in `test_class`.
    def_line = None
    for idx, line in enumerate(class_lines):
        if line.strip().startswith("all_model_classes = "):
            indent_level = len(line) - len(line.lstrip())
            start_idx = idx
        elif line.strip().startswith("all_generative_model_classes = "):
            indent_level = len(line) - len(line.lstrip())
            start_idx = idx
        elif line.strip().startswith("pipeline_model_mapping = "):
            indent_level = len(line) - len(line.lstrip())
            start_idx = idx
            def_line = line
            break

    if start_idx is None:
        # None of the three anchor attributes exist: nothing to do.
        return "", -1
    # Find the ending index (inclusive) of the above found block.
    end_idx = find_block_ending(class_lines, start_idx, indent_level)

    # Extract `is_xxx_available()` from existing blocks: some models require specific libraries like `timm` and use
    # `is_timm_available()` instead of `is_torch_available()`.
    # Keep leading and trailing whitespaces
    r = re.compile(r"\s(is_\S+?_available\(\))\s")
    for line in class_lines[start_idx : end_idx + 1]:
        backend_condition = r.search(line)
        if backend_condition is not None:
            # replace the leading and trailing whitespaces to the space character " ".
            target = " " + backend_condition[0][1:-1] + " "
            line_to_add = r.sub(target, line_to_add)
            break

    if def_line is None:
        # `pipeline_model_mapping` is not defined. The target index is set to the ending index (inclusive) of
        # `all_model_classes` or `all_generative_model_classes`.
        target_idx = end_idx
    else:
        # `pipeline_model_mapping` is defined. The target index is set to be one **BEFORE** its start index.
        target_idx = start_idx - 1
        # mark the lines of the currently existing `pipeline_model_mapping` to be removed.
        for idx in range(start_idx, end_idx + 1):
            # These lines are going to be removed before writing to the test file.
            class_lines[idx] = None  # noqa

    # Make sure the test class is a subclass of `PipelineTesterMixin`.
    parent_classes = [x.__name__ for x in test_class.__bases__]
    if "PipelineTesterMixin" not in parent_classes:
        # Put `PipelineTesterMixin` just before `unittest.TestCase`
        _parent_classes = [x for x in parent_classes if x != "TestCase"] + ["PipelineTesterMixin"]
        if "TestCase" in parent_classes:
            # Here we **assume** the original string is always with `unittest.TestCase`.
            _parent_classes.append("unittest.TestCase")
        parent_classes = ", ".join(_parent_classes)
        for idx, line in enumerate(class_lines):
            # Find the ending of the declaration of `test_class`
            if line.strip().endswith("):"):
                # mark the lines of the declaration of `test_class` to be removed
                for _idx in range(idx + 1):
                    class_lines[_idx] = None  # noqa
                break
        # Add the new, one-line, class declaration for `test_class`
        class_lines[0] = f"class {test_class.__name__}({parent_classes}):\n"

    # Add indentation
    line_to_add = " " * indent_level + line_to_add
    # Insert `pipeline_model_mapping` to `class_lines`.
    # (The line at `target_idx` should be kept by definition!)
    class_lines = class_lines[: target_idx + 1] + [line_to_add] + class_lines[target_idx + 1 :]
    # Remove the lines that are marked to be removed
    class_lines = [x for x in class_lines if x is not None]

    # Move from test class to module (in order to write to the test file)
    module_lines = inspect.getsourcelines(inspect.getmodule(test_class))[0]
    # Be careful with the 1-off between line numbers and array indices
    module_lines = module_lines[: class_start_line_no - 1] + class_lines + module_lines[class_end_line_no:]
    code = "".join(module_lines)

    # NOTE(review): "moddule_file" is a typo for "module_file" (local name only).
    moddule_file = inspect.getsourcefile(test_class)
    with open(moddule_file, "w", encoding="UTF-8", newline="\n") as fp:
        fp.write(code)

    return line_to_add
|
||||
|
||||
|
||||
def add_pipeline_model_mapping_to_test_file(test_file, overwrite=False):
    """Add `pipeline_model_mapping` to `test_file`."""
    target_class = find_test_class(test_file)
    # `find_test_class` returns None when no suitable test class exists.
    if target_class is not None:
        add_pipeline_model_mapping(target_class, overwrite=overwrite)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--test_file", type=str, help="A path to the test file, starting with the repository's `tests` directory."
    )
    parser.add_argument(
        "--all",
        action="store_true",
        help="If to check and modify all test files.",
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="If to overwrite a test class if it has already defined `pipeline_model_mapping`.",
    )
    args = parser.parse_args()

    # Exactly one of `--test_file` / `--all` must be given.
    if not args.all and not args.test_file:
        raise ValueError("Please specify either `test_file` or pass `--all` to check/modify all test files.")
    elif args.all and args.test_file:
        raise ValueError("Only one of `--test_file` and `--all` could be specified.")

    test_files = []
    if args.test_file:
        test_files = [args.test_file]
    else:
        # NOTE(review): `glob.glob` without `recursive=True` treats "**" like "*",
        # i.e. one directory level — which matches the flat tests/models/<model>/
        # layout, but confirm no deeper nesting is expected.
        pattern = os.path.join("tests", "models", "**", "test_modeling_*.py")
        for test_file in glob.glob(pattern):
            test_files.append(test_file)

    for test_file in test_files:
        # NOTE(review): membership compares raw glob paths; on Windows the
        # os.sep-joined paths would not match the POSIX entries in
        # `TEST_FILE_TO_IGNORE` — presumably this only runs on POSIX CI.
        if test_file in TEST_FILE_TO_IGNORE:
            print(f"[SKIPPED] {test_file} is skipped as it is in `TEST_FILE_TO_IGNORE` in the file {__file__}.")
            continue
        add_pipeline_model_mapping_to_test_file(test_file, overwrite=args.overwrite)
|
||||
220
transformers/utils/check_bad_commit.py
Normal file
220
transformers/utils/check_bad_commit.py
Normal file
@@ -0,0 +1,220 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
def create_script(target_test):
    """Create a python script to be run by `git bisect run` to determine if `target_test` passes or fails.

    If a test is not found in a commit, the script exits with code `0` (i.e. `Success`).
    Exit codes of the generated script: `0` = test passed or absent ("good" commit),
    `2` = test failed ("bad" commit), `-1` = pytest itself could not run.

    Args:
        target_test (`str`): The test to check.

    Side effect:
        Writes the generated script to `target_script.py` in the current directory
        (returns nothing).
    """

    # `{target_test}` is interpolated into the generated source; `{{...}}` stays a
    # literal brace pair so the inner script formats `result.stderr` at its own runtime.
    script = f"""
import os
import subprocess

result = subprocess.run(
    ["python3", "-m", "pytest", "-v", "-rfEp", f"{target_test}"],
    capture_output = True,
    text=True,
)
print(result.stdout)

if f"PASSED {target_test}" in result.stdout:
    print("test passed")
    exit(0)
elif len(result.stderr) > 0:
    if "ERROR: file or directory not found: " in result.stderr:
        print("test file or directory not found in this commit")
        exit(0)
    elif "ERROR: not found: " in result.stderr:
        print("test not found in this commit")
        exit(0)
    else:
        print(f"pytest failed to run: {{result.stderr}}")
        exit(-1)
elif f"FAILED {target_test}" in result.stdout:
    print("test failed")
    exit(2)

exit(0)
"""

    with open("target_script.py", "w") as fp:
        fp.write(script.strip())
|
||||
|
||||
|
||||
def find_bad_commit(target_test, start_commit, end_commit):
    """Find (backward) the earliest commit between `start_commit` and `end_commit` at which `target_test` fails.

    Args:
        target_test (`str`): The test to check.
        start_commit (`str`): The latest commit.
        end_commit (`str`): The earliest commit.

    Returns:
        `str`: The earliest commit at which `target_test` fails, or `None` when
        `git bisect` produced no "first bad commit" line.
    """

    # Single-commit range: nothing to bisect.
    if start_commit == end_commit:
        return start_commit

    # Writes `target_script.py`, the per-commit pass/fail probe for `git bisect run`.
    create_script(target_test=target_test)

    # NOTE(review): commit hashes are interpolated unquoted into a bash script —
    # fine for CI-provided SHAs, but not safe for arbitrary user input.
    bash = f"""
git bisect reset
git bisect start {start_commit} {end_commit}
git bisect run python3 target_script.py
"""

    with open("run_git_bisect.sh", "w") as fp:
        fp.write(bash.strip())

    # `check=False`: bisect failures are detected by inspecting stderr below.
    result = subprocess.run(
        ["bash", "run_git_bisect.sh"],
        check=False,
        capture_output=True,
        text=True,
    )
    print(result.stdout)

    if "error: bisect run failed" in result.stderr:
        index = result.stderr.find("error: bisect run failed")
        bash_error = result.stderr[index:]

        error_msg = f"Error when running git bisect:\nbash error: {bash_error}"

        # Surface the pytest failure message emitted by `target_script.py`, if any.
        pattern = "pytest failed to run: .+"
        pytest_errors = re.findall(pattern, result.stdout)
        if len(pytest_errors) > 0:
            pytest_error = pytest_errors[0]
            index = pytest_error.find("pytest failed to run: ")
            index += len("pytest failed to run: ")
            pytest_error = pytest_error[index:]
            error_msg += f"pytest error: {pytest_error}"

        raise ValueError(error_msg)

    # Parse the commit hash from git bisect's "<sha> is the first bad commit" line.
    pattern = r"(.+) is the first bad commit"
    commits = re.findall(pattern, result.stdout)

    bad_commit = None
    if len(commits) > 0:
        bad_commit = commits[0]

    print(f"Between `start_commit` {start_commit} and `end_commit` {end_commit}")
    print(f"bad_commit: {bad_commit}\n")

    return bad_commit
|
||||
|
||||
|
||||
def get_commit_info(commit):
    """Get information for a commit via `api.github.com`.

    Args:
        commit (`str`): The commit sha to look up.

    Returns:
        `dict`: A mapping with keys `commit`, `pr_number`, `author` and `merged_by`; any of the last
        three may be `None` if the information cannot be determined.
    """
    pr_number = None
    author = None
    merged_author = None

    # PR(s) associated with this commit, if any. Unauthenticated calls are rate-limited; callers
    # should cache results (see the note in `__main__`).
    # A timeout is set so a stalled GitHub API call cannot hang the whole bisect job.
    url = f"https://api.github.com/repos/huggingface/transformers/commits/{commit}/pulls"
    pr_info_for_commit = requests.get(url, timeout=60).json()

    if len(pr_info_for_commit) > 0:
        pr_number = pr_info_for_commit[0]["number"]

        url = f"https://api.github.com/repos/huggingface/transformers/pulls/{pr_number}"
        pr_for_commit = requests.get(url, timeout=60).json()
        author = pr_for_commit["user"]["login"]
        if pr_for_commit["merged_by"] is not None:
            merged_author = pr_for_commit["merged_by"]["login"]

    if author is None:
        url = f"https://api.github.com/repos/huggingface/transformers/commits/{commit}"
        commit_info = requests.get(url, timeout=60).json()
        # GitHub returns `"author": null` when the commit email is not linked to a GitHub account,
        # in which case subscripting it would raise `TypeError` — leave `author` as `None` instead.
        if commit_info.get("author") is not None:
            author = commit_info["author"]["login"]

    return {"commit": commit, "pr_number": pr_number, "author": author, "merged_by": merged_author}
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--start_commit", type=str, required=True, help="The latest commit hash to check.")
    parser.add_argument("--end_commit", type=str, required=True, help="The earliest commit hash to check.")
    parser.add_argument("--test", type=str, help="The test to check.")
    parser.add_argument("--file", type=str, help="The report file.")
    parser.add_argument("--output_file", type=str, required=True, help="The path of the output file.")
    args = parser.parse_args()

    print(f"start_commit: {args.start_commit}")
    print(f"end_commit: {args.end_commit}")

    # `get_commit_info` uses `requests.get()` to request info. via `api.github.com` without using token.
    # If there are many new failed tests in a workflow run, this script may fail at some point with `KeyError` at
    # `pr_number = pr_info_for_commit[0]["number"]` due to the rate limit.
    # Let's cache the commit info. and reuse them whenever possible.
    commit_info_cache = {}

    # XOR check via a set: `{args.test is None, args.file is None}` has 2 elements only when exactly
    # one of `--test` / `--file` was given (one flag is `None` and the other is not).
    if len({args.test is None, args.file is None}) != 2:
        raise ValueError("Exactly one argument `test` or `file` must be specified.")

    if args.test is not None:
        # Single-test mode: bisect one test and write "<test>\n<bad commit>" to the output file.
        commit = find_bad_commit(target_test=args.test, start_commit=args.start_commit, end_commit=args.end_commit)
        with open(args.output_file, "w", encoding="UTF-8") as fp:
            fp.write(f"{args.test}\n{commit}")
    elif os.path.isfile(args.file):
        # Report mode: bisect every failed test listed in the JSON report (per model).
        with open(args.file, "r", encoding="UTF-8") as fp:
            reports = json.load(fp)

        for model in reports:
            # TODO: make this script able to deal with both `single-gpu` and `multi-gpu` via a new argument.
            reports[model].pop("multi-gpu", None)
            failed_tests = reports[model]["single-gpu"]

            failed_tests_with_bad_commits = []
            for test in failed_tests:
                commit = find_bad_commit(target_test=test, start_commit=args.start_commit, end_commit=args.end_commit)
                info = {"test": test, "commit": commit}

                # Reuse cached commit info to avoid hitting the GitHub API rate limit (see note above).
                if commit in commit_info_cache:
                    commit_info = commit_info_cache[commit]
                else:
                    commit_info = get_commit_info(commit)
                    commit_info_cache[commit] = commit_info

                info.update(commit_info)
                failed_tests_with_bad_commits.append(info)

            # If no single-gpu test failures, remove the key
            if len(failed_tests_with_bad_commits) > 0:
                reports[model]["single-gpu"] = failed_tests_with_bad_commits
            else:
                reports[model].pop("single-gpu", None)

        # remove the models without any test failure
        reports = {k: v for k, v in reports.items() if len(v) > 0}

        with open(args.output_file, "w", encoding="UTF-8") as fp:
            json.dump(reports, fp, ensure_ascii=False, indent=4)
|
||||
49
transformers/utils/check_build.py
Normal file
49
transformers/utils/check_build.py
Normal file
@@ -0,0 +1,49 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import importlib
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
# Test all the extensions added in the setup
|
||||
FILES_TO_FIND = [
|
||||
"kernels/rwkv/wkv_cuda.cu",
|
||||
"kernels/rwkv/wkv_op.cpp",
|
||||
"kernels/falcon_mamba/selective_scan_with_ln_interface.py",
|
||||
"kernels/falcon_mamba/__init__.py",
|
||||
"kernels/__init__.py",
|
||||
"models/graphormer/algos_graphormer.pyx",
|
||||
]
|
||||
|
||||
|
||||
def test_custom_files_are_present(transformers_path):
|
||||
# Test all the extensions added in the setup
|
||||
for file in FILES_TO_FIND:
|
||||
if not (transformers_path / file).exists():
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # `--check_lib` switches the check from the local build directory to the installed package.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--check_lib", action="store_true", help="Whether to check the build or the actual package.")
    cli_args = arg_parser.parse_args()

    if cli_args.check_lib:
        # Inspect the installed `transformers` package.
        transformers_path = Path(importlib.import_module("transformers").__file__).parent
    else:
        # Inspect the local build artifacts.
        transformers_path = Path.cwd() / "build/lib/transformers"

    if not test_custom_files_are_present(transformers_path):
        raise ValueError("The built release does not contain the custom files. Fix this before going further!")
|
||||
531
transformers/utils/check_config_attributes.py
Normal file
531
transformers/utils/check_config_attributes.py
Normal file
@@ -0,0 +1,531 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import os
|
||||
import re
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.utils import direct_transformers_import
|
||||
|
||||
|
||||
# All paths are set with the intent you should run this script from the root of the repo with the command
|
||||
# python utils/check_config_docstrings.py
|
||||
PATH_TO_TRANSFORMERS = "src/transformers"
|
||||
|
||||
|
||||
# This is to make sure the transformers module imported is the one in the repo.
|
||||
transformers = direct_transformers_import(PATH_TO_TRANSFORMERS)
|
||||
|
||||
CONFIG_MAPPING = transformers.models.auto.configuration_auto.CONFIG_MAPPING
|
||||
|
||||
SPECIAL_CASES_TO_ALLOW = {
|
||||
"xLSTMConfig": ["add_out_norm", "chunkwise_kernel", "sequence_kernel", "step_kernel"],
|
||||
"Ernie4_5Config": ["tie_word_embeddings"],
|
||||
"Ernie4_5_MoeConfig": ["tie_word_embeddings"],
|
||||
"Lfm2Config": ["full_attn_idxs", "tie_word_embeddings"],
|
||||
# used internally during generation to provide the custom logit processors with their necessary information
|
||||
"DiaConfig": [
|
||||
"delay_pattern",
|
||||
],
|
||||
# 'max_position_embeddings' is not used in modeling file, but needed for eval frameworks like Huggingface's lighteval (https://github.com/huggingface/lighteval/blob/af24080ea4f16eaf1683e353042a2dfc9099f038/src/lighteval/models/base_model.py#L264).
|
||||
# periods and offsets are not used in modeling file, but used in the configuration file to define `layers_block_type` and `layers_num_experts`.
|
||||
"BambaConfig": [
|
||||
"attn_layer_indices",
|
||||
],
|
||||
"Dots1Config": ["max_window_layers"],
|
||||
"JambaConfig": [
|
||||
"max_position_embeddings",
|
||||
"attn_layer_offset",
|
||||
"attn_layer_period",
|
||||
"expert_layer_offset",
|
||||
"expert_layer_period",
|
||||
],
|
||||
"Qwen2Config": ["use_sliding_window", "max_window_layers"],
|
||||
"Qwen2MoeConfig": ["use_sliding_window"],
|
||||
"Qwen2VLTextConfig": ["use_sliding_window", "max_window_layers"],
|
||||
"Qwen2_5_VLTextConfig": ["use_sliding_window", "max_window_layers"],
|
||||
"Qwen2_5OmniTextConfig": ["use_sliding_window", "max_window_layers"],
|
||||
"Qwen2_5OmniTalkerConfig": ["use_sliding_window", "max_window_layers"],
|
||||
"Qwen3Config": ["max_window_layers", "use_sliding_window"], # now use `layer_types` instead
|
||||
"Qwen3MoeConfig": ["max_window_layers", "use_sliding_window"],
|
||||
# `cache_implementation` should be in the default generation config, but we don't yet support per-model
|
||||
# generation configs (TODO joao)
|
||||
"Gemma2Config": ["tie_word_embeddings", "cache_implementation"],
|
||||
"Cohere2Config": ["cache_implementation"],
|
||||
# Dropout with this value was declared but never used
|
||||
"Phi3Config": ["embd_pdrop"],
|
||||
# used to compute the property `self.chunk_length`
|
||||
"EncodecConfig": ["overlap"],
|
||||
# used to compute `frame_rate`
|
||||
"XcodecConfig": ["sample_rate", "audio_channels"],
|
||||
# used to compute the property `self.layers_block_type`
|
||||
"RecurrentGemmaConfig": ["block_types"],
|
||||
# used as in the config to define `intermediate_size`
|
||||
"MambaConfig": ["expand"],
|
||||
# used as in the config to define `intermediate_size`
|
||||
"FalconMambaConfig": ["expand"],
|
||||
# used as `self.bert_model = BertModel(config, ...)`
|
||||
"DPRConfig": True,
|
||||
"FuyuConfig": True,
|
||||
# not used in modeling files, but it's an important information
|
||||
"FSMTConfig": ["langs"],
|
||||
# used internally in the configuration class file
|
||||
"GPTNeoConfig": ["attention_types"],
|
||||
# used internally in the configuration class file
|
||||
"EsmConfig": ["is_folding_model"],
|
||||
# used during training (despite we don't have training script for these models yet)
|
||||
"Mask2FormerConfig": ["ignore_value"],
|
||||
# `ignore_value` used during training (despite we don't have training script for these models yet)
|
||||
# `norm` used in conversion script (despite not using in the modeling file)
|
||||
"OneFormerConfig": ["ignore_value", "norm"],
|
||||
# used internally in the configuration class file
|
||||
"T5Config": ["feed_forward_proj"],
|
||||
# used internally in the configuration class file
|
||||
# `tokenizer_class` get default value `T5Tokenizer` intentionally
|
||||
"MT5Config": ["feed_forward_proj", "tokenizer_class"],
|
||||
"UMT5Config": ["feed_forward_proj", "tokenizer_class"],
|
||||
# used internally in the configuration class file
|
||||
"LongT5Config": ["feed_forward_proj"],
|
||||
# used internally in the configuration class file
|
||||
"Pop2PianoConfig": ["feed_forward_proj"],
|
||||
# used internally in the configuration class file
|
||||
"SwitchTransformersConfig": ["feed_forward_proj"],
|
||||
# having default values other than `1e-5` - we can't fix them without breaking
|
||||
"BioGptConfig": ["layer_norm_eps"],
|
||||
# having default values other than `1e-5` - we can't fix them without breaking
|
||||
"GLPNConfig": ["layer_norm_eps"],
|
||||
# having default values other than `1e-5` - we can't fix them without breaking
|
||||
"SegformerConfig": ["layer_norm_eps"],
|
||||
# having default values other than `1e-5` - we can't fix them without breaking
|
||||
"CvtConfig": ["layer_norm_eps"],
|
||||
# having default values other than `1e-5` - we can't fix them without breaking
|
||||
"PerceiverConfig": ["layer_norm_eps"],
|
||||
# used internally to calculate the feature size
|
||||
"InformerConfig": ["num_static_real_features", "num_time_features"],
|
||||
# used internally to calculate the feature size
|
||||
"TimeSeriesTransformerConfig": ["num_static_real_features", "num_time_features"],
|
||||
# used internally to calculate the feature size
|
||||
"AutoformerConfig": ["num_static_real_features", "num_time_features"],
|
||||
# used internally to calculate `mlp_dim`
|
||||
"SamVisionConfig": ["mlp_ratio"],
|
||||
# used internally to calculate `mlp_dim`
|
||||
"SamHQVisionConfig": ["mlp_ratio"],
|
||||
# For (head) training, but so far not implemented
|
||||
"ClapAudioConfig": ["num_classes"],
|
||||
# Not used, but providing useful information to users
|
||||
"SpeechT5HifiGanConfig": ["sampling_rate"],
|
||||
# used internally in the configuration class file
|
||||
"UdopConfig": ["feed_forward_proj"],
|
||||
# Actually used in the config or generation config, in that case necessary for the sub-components generation
|
||||
"SeamlessM4TConfig": [
|
||||
"max_new_tokens",
|
||||
"t2u_max_new_tokens",
|
||||
"t2u_decoder_attention_heads",
|
||||
"t2u_decoder_ffn_dim",
|
||||
"t2u_decoder_layers",
|
||||
"t2u_encoder_attention_heads",
|
||||
"t2u_encoder_ffn_dim",
|
||||
"t2u_encoder_layers",
|
||||
"t2u_max_position_embeddings",
|
||||
],
|
||||
# Actually used in the config or generation config, in that case necessary for the sub-components generation
|
||||
"SeamlessM4Tv2Config": [
|
||||
"max_new_tokens",
|
||||
"t2u_decoder_attention_heads",
|
||||
"t2u_decoder_ffn_dim",
|
||||
"t2u_decoder_layers",
|
||||
"t2u_encoder_attention_heads",
|
||||
"t2u_encoder_ffn_dim",
|
||||
"t2u_encoder_layers",
|
||||
"t2u_max_position_embeddings",
|
||||
"t2u_variance_pred_dropout",
|
||||
"t2u_variance_predictor_embed_dim",
|
||||
"t2u_variance_predictor_hidden_dim",
|
||||
"t2u_variance_predictor_kernel_size",
|
||||
],
|
||||
"ZambaConfig": [
|
||||
"tie_word_embeddings",
|
||||
"attn_layer_offset",
|
||||
"attn_layer_period",
|
||||
],
|
||||
"MllamaTextConfig": [
|
||||
"initializer_range",
|
||||
],
|
||||
"MllamaVisionConfig": [
|
||||
"initializer_range",
|
||||
"supported_aspect_ratios",
|
||||
],
|
||||
"ConditionalDetrConfig": [
|
||||
"bbox_cost",
|
||||
"bbox_loss_coefficient",
|
||||
"class_cost",
|
||||
"cls_loss_coefficient",
|
||||
"dice_loss_coefficient",
|
||||
"focal_alpha",
|
||||
"giou_cost",
|
||||
"giou_loss_coefficient",
|
||||
"mask_loss_coefficient",
|
||||
],
|
||||
"DabDetrConfig": [
|
||||
"dilation",
|
||||
"bbox_cost",
|
||||
"bbox_loss_coefficient",
|
||||
"class_cost",
|
||||
"cls_loss_coefficient",
|
||||
"focal_alpha",
|
||||
"giou_cost",
|
||||
"giou_loss_coefficient",
|
||||
],
|
||||
"DetrConfig": [
|
||||
"bbox_cost",
|
||||
"bbox_loss_coefficient",
|
||||
"class_cost",
|
||||
"dice_loss_coefficient",
|
||||
"eos_coefficient",
|
||||
"giou_cost",
|
||||
"giou_loss_coefficient",
|
||||
"mask_loss_coefficient",
|
||||
],
|
||||
"DFineConfig": [
|
||||
"eos_coefficient",
|
||||
"focal_loss_alpha",
|
||||
"focal_loss_gamma",
|
||||
"matcher_alpha",
|
||||
"matcher_bbox_cost",
|
||||
"matcher_class_cost",
|
||||
"matcher_gamma",
|
||||
"matcher_giou_cost",
|
||||
"use_focal_loss",
|
||||
"weight_loss_bbox",
|
||||
"weight_loss_giou",
|
||||
"weight_loss_vfl",
|
||||
"weight_loss_fgl",
|
||||
"weight_loss_ddf",
|
||||
],
|
||||
"GroundingDinoConfig": [
|
||||
"bbox_cost",
|
||||
"bbox_loss_coefficient",
|
||||
"class_cost",
|
||||
"focal_alpha",
|
||||
"giou_cost",
|
||||
"giou_loss_coefficient",
|
||||
],
|
||||
"MMGroundingDinoConfig": [
|
||||
"bbox_cost",
|
||||
"bbox_loss_coefficient",
|
||||
"class_cost",
|
||||
"focal_alpha",
|
||||
"giou_cost",
|
||||
"giou_loss_coefficient",
|
||||
],
|
||||
"RTDetrConfig": [
|
||||
"eos_coefficient",
|
||||
"focal_loss_alpha",
|
||||
"focal_loss_gamma",
|
||||
"matcher_alpha",
|
||||
"matcher_bbox_cost",
|
||||
"matcher_class_cost",
|
||||
"matcher_gamma",
|
||||
"matcher_giou_cost",
|
||||
"use_focal_loss",
|
||||
"weight_loss_bbox",
|
||||
"weight_loss_giou",
|
||||
"weight_loss_vfl",
|
||||
],
|
||||
"RTDetrV2Config": [
|
||||
"eos_coefficient",
|
||||
"focal_loss_alpha",
|
||||
"focal_loss_gamma",
|
||||
"matcher_alpha",
|
||||
"matcher_bbox_cost",
|
||||
"matcher_class_cost",
|
||||
"matcher_gamma",
|
||||
"matcher_giou_cost",
|
||||
"use_focal_loss",
|
||||
"weight_loss_bbox",
|
||||
"weight_loss_giou",
|
||||
"weight_loss_vfl",
|
||||
],
|
||||
"YolosConfig": [
|
||||
"bbox_cost",
|
||||
"bbox_loss_coefficient",
|
||||
"class_cost",
|
||||
"eos_coefficient",
|
||||
"giou_cost",
|
||||
"giou_loss_coefficient",
|
||||
],
|
||||
"GPTNeoXConfig": ["rotary_emb_base"],
|
||||
"Gemma3Config": ["boi_token_index", "eoi_token_index"],
|
||||
"Gemma3TextConfig": ["cache_implementation", "tie_word_embeddings"],
|
||||
"ShieldGemma2Config": [
|
||||
"boi_token_index",
|
||||
"eoi_token_index",
|
||||
"initializer_range",
|
||||
"mm_tokens_per_image",
|
||||
"text_config",
|
||||
"vision_config",
|
||||
],
|
||||
"Llama4Config": ["boi_token_index", "eoi_token_index"],
|
||||
"Llama4TextConfig": [
|
||||
"interleave_moe_layer_step",
|
||||
"no_rope_layer_interval",
|
||||
"no_rope_layers",
|
||||
"output_router_logits",
|
||||
"router_aux_loss_coef",
|
||||
"router_jitter_noise",
|
||||
"cache_implementation",
|
||||
"attention_chunk_size",
|
||||
],
|
||||
"Llama4VisionConfig": ["multi_modal_projector_bias", "norm_eps"],
|
||||
"ModernBertDecoderConfig": [
|
||||
"embedding_dropout",
|
||||
"hidden_activation",
|
||||
"initializer_cutoff_factor",
|
||||
"intermediate_size",
|
||||
"max_position_embeddings",
|
||||
"mlp_bias",
|
||||
"mlp_dropout",
|
||||
"classifier_activation",
|
||||
"global_attn_every_n_layers",
|
||||
"local_attention",
|
||||
"local_rope_theta",
|
||||
],
|
||||
"SmolLM3Config": ["no_rope_layer_interval"],
|
||||
"Gemma3nVisionConfig": ["architecture", "do_pooling", "model_args"], # this is for use in `timm`
|
||||
"VaultGemmaConfig": ["tie_word_embeddings"],
|
||||
"GemmaConfig": ["tie_word_embeddings"],
|
||||
}
|
||||
|
||||
|
||||
# TODO (ydshieh): Check the failing cases, try to fix them or move some cases to the above block once we are sure
|
||||
SPECIAL_CASES_TO_ALLOW.update(
|
||||
{
|
||||
"CLIPSegConfig": True,
|
||||
"DeformableDetrConfig": True,
|
||||
"DinatConfig": True,
|
||||
"DonutSwinConfig": True,
|
||||
"FastSpeech2ConformerConfig": True,
|
||||
"FSMTConfig": True,
|
||||
"LayoutLMv2Config": True,
|
||||
"MaskFormerSwinConfig": True,
|
||||
"MT5Config": True,
|
||||
# For backward compatibility with trust remote code models
|
||||
"MptConfig": True,
|
||||
"MptAttentionConfig": True,
|
||||
"OneFormerConfig": True,
|
||||
"PerceiverConfig": True,
|
||||
"RagConfig": True,
|
||||
"SpeechT5Config": True,
|
||||
"SwinConfig": True,
|
||||
"Swin2SRConfig": True,
|
||||
"Swinv2Config": True,
|
||||
"SwitchTransformersConfig": True,
|
||||
"TableTransformerConfig": True,
|
||||
"TapasConfig": True,
|
||||
"UniSpeechConfig": True,
|
||||
"UniSpeechSatConfig": True,
|
||||
"WavLMConfig": True,
|
||||
"WhisperConfig": True,
|
||||
# TODO: @Arthur (for `alignment_head` and `alignment_layer`)
|
||||
"JukeboxPriorConfig": True,
|
||||
# TODO: @Younes (for `is_decoder`)
|
||||
"Pix2StructTextConfig": True,
|
||||
"IdeficsConfig": True,
|
||||
"IdeficsVisionConfig": True,
|
||||
"IdeficsPerceiverConfig": True,
|
||||
# TODO: @Arthur/Joao (`hidden_act` unused)
|
||||
"GptOssConfig": True,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def check_attribute_being_used(config_class, attributes, default_value, source_strings):
    """Check if any name in `attributes` is used in one of the strings in `source_strings`

    Args:
        config_class (`type`):
            The configuration class for which the arguments in its `__init__` will be checked.
        attributes (`List[str]`):
            The name of an argument (or attribute) and its variant names if any.
        default_value (`Any`):
            A default value for the attribute in `attributes` assigned in the `__init__` of `config_class`.
        source_strings (`List[str]`):
            The python source code strings in the same modeling directory where `config_class` is defined. The file
            containing the definition of `config_class` should be excluded.
    """

    def _appears_in(attribute, modeling_source):
        # Direct usages: `config.xxx`, `getattr(config, "xxx", ...)` or `getattr(self.config, "xxx", ...)`.
        if (
            f"config.{attribute}" in modeling_source
            or f'getattr(config, "{attribute}"' in modeling_source
            or f'getattr(self.config, "{attribute}"' in modeling_source
        ):
            return True
        # Text sub-config attributes may be reached through `config.get_text_config().xxx`.
        if "TextConfig" in config_class.__name__ and f"config.get_text_config().{attribute}" in modeling_source:
            return True
        # Deal with multi-line `getattr(...)` calls.
        multiline_pattern = rf'getattr[ \t\v\n\r\f]*\([ \t\v\n\r\f]*(self\.)?config,[ \t\v\n\r\f]*"{attribute}"'
        return re.search(multiline_pattern, modeling_source) is not None

    attribute_used = any(
        _appears_in(variant, modeling_source) for variant in attributes for modeling_source in source_strings
    )

    # common and important attributes, even if they do not always appear in the modeling files
    attributes_to_allow = [
        "initializer_range",
        "bos_index",
        "eos_index",
        "pad_index",
        "unk_index",
        "mask_index",
        "image_token_id",  # for VLMs
        "video_token_id",
        "image_seq_length",
        "video_seq_length",
        "image_size",
        "text_config",  # may appear as `get_text_config()`
        "use_cache",
        "out_features",
        "out_indices",
        "sampling_rate",
        # backbone related arguments passed to load_backbone
        "use_pretrained_backbone",
        "backbone",
        "backbone_config",
        "use_timm_backbone",
        "backbone_kwargs",
        # rope attributes may not appear directly in the modeling but are used
        "rope_theta",
        "partial_rotary_factor",
        "pretraining_tp",
        "boi_token_id",
        "eoi_token_id",
    ]
    attributes_used_in_generation = ["encoder_no_repeat_ngram_size"]

    # Special cases to be allowed
    case_allowed = attribute_used
    if not case_allowed:
        for attribute in attributes:
            # Allow if the default value in the configuration class is different from the one in `PretrainedConfig`
            if attribute == "is_encoder_decoder" and default_value is True:
                case_allowed = True
            elif attribute == "tie_word_embeddings" and default_value is False:
                case_allowed = True
            # Allow cases without checking the default value in the configuration class
            elif attribute in attributes_to_allow + attributes_used_in_generation or attribute.endswith("_token_id"):
                case_allowed = True

    # configuration class specific cases
    if not case_allowed:
        allowed_cases = SPECIAL_CASES_TO_ALLOW.get(config_class.__name__, [])
        # NOTE: `attribute` below is the loop variable leaked from the loop above, i.e. the last variant name.
        case_allowed = allowed_cases is True or attribute in allowed_cases

    return attribute_used or case_allowed
|
||||
|
||||
|
||||
def check_config_attributes_being_used(config_class):
    """Check the arguments in `__init__` of `config_class` are used in the modeling files in the same directory

    Args:
        config_class (`type`):
            The configuration class for which the arguments in its `__init__` will be checked.
    """
    # Collect the `__init__` parameters (minus `self`/`kwargs`) together with their default values, if any.
    init_params = inspect.signature(config_class.__init__).parameters
    params_with_defaults = [(name, param.default) for name, param in init_params.items() if name not in ["self", "kwargs"]]

    # A non-empty `attribute_map` means an attribute can appear in the modeling files under a variant
    # name; finding either name counts as the attribute being used.
    reversed_attribute_map = {}
    if len(config_class.attribute_map) > 0:
        reversed_attribute_map = {v: k for k, v in config_class.attribute_map.items()}

    # Read every `modeling_*` source file sitting next to the file defining `config_class`.
    model_dir = os.path.dirname(inspect.getsourcefile(config_class))
    modeling_sources = []
    for fn in os.listdir(model_dir):
        if not fn.startswith("modeling_"):
            continue
        path = os.path.join(model_dir, fn)
        if os.path.isfile(path):
            with open(path, encoding="utf8") as fp:
                modeling_sources.append(fp.read())

    unused_attributes = []
    for config_param, default_value in params_with_defaults:
        # `attributes` holds all the variant names for `config_param`; any one appearing is enough.
        attributes = [config_param]
        if config_param in reversed_attribute_map:
            attributes.append(reversed_attribute_map[config_param])

        if not check_attribute_being_used(config_class, attributes, default_value, modeling_sources):
            unused_attributes.append(attributes[0])

    return sorted(unused_attributes)
|
||||
|
||||
|
||||
def check_config_attributes():
    """Check the arguments in `__init__` of all configuration classes are used in python files"""
    configs_with_unused_attributes = {}
    for _config_class in list(CONFIG_MAPPING.values()):
        # Skip deprecated models
        if "models.deprecated" in _config_class.__module__:
            continue

        # Also check config classes living in the same module but not registered in `CONFIG_MAPPING`
        # (e.g. `CLIPVisionConfig`, `Blip2VisionConfig`, etc.)
        module = inspect.getmodule(_config_class)

        def _is_sibling_config(x, _module=module):
            return inspect.isclass(x) and issubclass(x, PretrainedConfig) and inspect.getmodule(x) == _module

        for _, config_class in inspect.getmembers(module, _is_sibling_config):
            unused_attributes = check_config_attributes_being_used(config_class)
            if len(unused_attributes) > 0:
                configs_with_unused_attributes[config_class.__name__] = unused_attributes

    if len(configs_with_unused_attributes) > 0:
        details = "".join(f"{name}: {attributes}\n" for name, attributes in configs_with_unused_attributes.items())
        error = "The following configuration classes contain unused attributes in the corresponding modeling files:\n" + details
        raise ValueError(error)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run the check over every registered configuration class; raises `ValueError` listing offenders.
    check_config_attributes()
|
||||
103
transformers/utils/check_config_docstrings.py
Normal file
103
transformers/utils/check_config_docstrings.py
Normal file
@@ -0,0 +1,103 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import re
|
||||
|
||||
from transformers.utils import direct_transformers_import
|
||||
|
||||
|
||||
# All paths are set with the intent you should run this script from the root of the repo with the command
|
||||
# python utils/check_config_docstrings.py
|
||||
PATH_TO_TRANSFORMERS = "src/transformers"
|
||||
|
||||
|
||||
# This is to make sure the transformers module imported is the one in the repo.
|
||||
transformers = direct_transformers_import(PATH_TO_TRANSFORMERS)
|
||||
|
||||
CONFIG_MAPPING = transformers.models.auto.configuration_auto.CONFIG_MAPPING
|
||||
|
||||
# Regex pattern used to find the checkpoint mentioned in the docstring of `config_class`.
|
||||
# For example, `[google-bert/bert-base-uncased](https://huggingface.co/google-bert/bert-base-uncased)`
|
||||
_re_checkpoint = re.compile(r"\[(.+?)\]\((https://huggingface\.co/.+?)\)")
|
||||
|
||||
|
||||
CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK = {
|
||||
"DecisionTransformerConfig",
|
||||
"EncoderDecoderConfig",
|
||||
"MusicgenConfig",
|
||||
"RagConfig",
|
||||
"SpeechEncoderDecoderConfig",
|
||||
"TimmBackboneConfig",
|
||||
"TimmWrapperConfig",
|
||||
"VisionEncoderDecoderConfig",
|
||||
"VisionTextDualEncoderConfig",
|
||||
"LlamaConfig",
|
||||
"GraniteConfig",
|
||||
"GraniteMoeConfig",
|
||||
"GraniteMoeHybridConfig",
|
||||
"Qwen3MoeConfig",
|
||||
"GraniteSpeechConfig",
|
||||
}
|
||||
|
||||
|
||||
def get_checkpoint_from_config_class(config_class):
    """Return the first checkpoint name mentioned in the source of `config_class` whose markdown link
    actually points to that checkpoint on the Hub, or `None` if there is no such mention."""
    # Each match is a `(name, link)` pair, e.g.
    # `('google-bert/bert-base-uncased', 'https://huggingface.co/google-bert/bert-base-uncased')`.
    for ckpt_name, ckpt_link in _re_checkpoint.findall(inspect.getsource(config_class)):
        # allow the link to end with `/`
        if ckpt_link.endswith("/"):
            ckpt_link = ckpt_link[:-1]

        # verify the checkpoint name corresponds to the checkpoint link
        if ckpt_link == f"https://huggingface.co/{ckpt_name}":
            return ckpt_name

    return None
|
||||
|
||||
|
||||
def check_config_docstrings_have_checkpoints():
    """Raise a `ValueError` listing every non-deprecated, non-ignored config class whose docstring does
    not mention a valid checkpoint link."""
    configs_without_checkpoint = [
        config_class.__name__
        for config_class in list(CONFIG_MAPPING.values())
        # Skip deprecated models
        if "models.deprecated" not in config_class.__module__
        and get_checkpoint_from_config_class(config_class) is None
        and config_class.__name__ not in CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK
    ]

    if len(configs_without_checkpoint) > 0:
        message = "\n".join(sorted(configs_without_checkpoint))
        raise ValueError(
            f"The following configurations don't contain any valid checkpoint:\n{message}\n\n"
            "The requirement is to include a link pointing to one of the models of this architecture in the "
            "docstring of the config classes listed above. The link should have be a markdown format like "
            "[myorg/mymodel](https://huggingface.co/myorg/mymodel)."
        )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Entry point: raises `ValueError` if any config class docstring lacks a valid checkpoint link.
    check_config_docstrings_have_checkpoints()
|
||||
1043
transformers/utils/check_copies.py
Normal file
1043
transformers/utils/check_copies.py
Normal file
File diff suppressed because it is too large
Load Diff
133
transformers/utils/check_doc_toc.py
Normal file
133
transformers/utils/check_doc_toc.py
Normal file
@@ -0,0 +1,133 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script is responsible for cleaning the model section of the table of content by removing duplicates and sorting
|
||||
the entries in alphabetical order.
|
||||
|
||||
Usage (from the root of the repo):
|
||||
|
||||
Check that the table of content is properly sorted (used in `make quality`):
|
||||
|
||||
```bash
|
||||
python utils/check_doc_toc.py
|
||||
```
|
||||
|
||||
Auto-sort the table of content if it is not properly sorted (used in `make style`):
|
||||
|
||||
```bash
|
||||
python utils/check_doc_toc.py --fix_and_overwrite
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from collections import defaultdict
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
PATH_TO_TOC = "docs/source/en/_toctree.yml"
|
||||
|
||||
|
||||
def clean_model_doc_toc(model_doc: list[dict]) -> list[dict]:
    """
    Deduplicate and alphabetically sort one modality section of the model TOC.

    Args:
        model_doc (`List[dict]`):
            Entries (each with at least `local` and `title` keys) extracted from the
            `_toctree.yml` file for one specific modality.

    Returns:
        `List[dict]`: The same entries without duplicates, sorted case-insensitively by title.

    Raises:
        ValueError: if the same `local` document appears under several different titles.
    """
    occurrences = defaultdict(int)
    for entry in model_doc:
        occurrences[entry["local"]] += 1
    repeated = [local for local, count in occurrences.items() if count > 1]

    cleaned = []
    for local in repeated:
        distinct_titles = list({entry["title"] for entry in model_doc if entry["local"] == local})
        if len(distinct_titles) > 1:
            raise ValueError(
                f"{local} is present several times in the documentation table of content at "
                "`docs/source/en/_toctree.yml` with different *Title* values. Choose one of those and remove the "
                "others."
            )
        # Keep a single canonical entry for the duplicated document.
        cleaned.append({"local": local, "title": distinct_titles[0]})

    # Entries that appeared exactly once are kept untouched (extra keys preserved).
    cleaned.extend(entry for entry in model_doc if occurrences[entry["local"]] == 1)

    # Alphabetical order on the displayed title, ignoring case.
    return sorted(cleaned, key=lambda entry: entry["title"].lower())
def check_model_doc(overwrite: bool = False):
    """
    Check that the content of the table of content in `_toctree.yml` is clean (no duplicates and sorted for the model
    API doc) and potentially auto-cleans it.

    Args:
        overwrite (`bool`, *optional*, defaults to `False`):
            Whether to just check if the TOC is clean or to auto-clean it (when `overwrite=True`).

    Raises:
        ValueError: when `overwrite=False` and at least one modality section is not clean.
    """
    with open(PATH_TO_TOC, encoding="utf-8") as f:
        content = yaml.safe_load(f.read())

    # Get to the API doc
    # NOTE(review): assumes a section titled "API" exists; otherwise this walks past
    # the end of the list and raises IndexError — confirm that is acceptable for CI.
    api_idx = 0
    while content[api_idx]["title"] != "API":
        api_idx += 1
    api_doc = content[api_idx]["sections"]

    # Then to the model doc
    model_idx = 0
    while api_doc[model_idx]["title"] != "Models":
        model_idx += 1

    model_doc = api_doc[model_idx]["sections"]

    # Extract the modalities and clean them one by one.
    modalities_docs = [(idx, section) for idx, section in enumerate(model_doc) if "sections" in section]
    diff = False
    for idx, modality_doc in modalities_docs:
        old_modality_doc = modality_doc["sections"]
        new_modality_doc = clean_model_doc_toc(old_modality_doc)

        if old_modality_doc != new_modality_doc:
            diff = True
            if overwrite:
                # Splice the cleaned section back at its original position.
                model_doc[idx]["sections"] = new_modality_doc

    if diff:
        if overwrite:
            api_doc[model_idx]["sections"] = model_doc
            content[api_idx]["sections"] = api_doc
            # Rewrite the whole TOC file with the cleaned content.
            with open(PATH_TO_TOC, "w", encoding="utf-8") as f:
                f.write(yaml.dump(content, allow_unicode=True))
        else:
            raise ValueError(
                "The model doc part of the table of content is not properly sorted, run `make style` to fix this."
            )
# CLI entry point: check the model TOC; with --fix_and_overwrite, rewrite it in place.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
    args = parser.parse_args()

    check_model_doc(args.fix_and_overwrite)
1484
transformers/utils/check_docstrings.py
Normal file
1484
transformers/utils/check_docstrings.py
Normal file
File diff suppressed because it is too large
Load Diff
86
transformers/utils/check_doctest_list.py
Normal file
86
transformers/utils/check_doctest_list.py
Normal file
@@ -0,0 +1,86 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script is responsible for cleaning the list of doctests by making sure the entries all exist and are in
|
||||
alphabetical order.
|
||||
|
||||
Usage (from the root of the repo):
|
||||
|
||||
Check that the doctest list is properly sorted and all files exist (used in `make repo-consistency`):
|
||||
|
||||
```bash
|
||||
python utils/check_doctest_list.py
|
||||
```
|
||||
|
||||
Auto-sort the doctest list if it is not properly sorted (used in `make fix-copies`):
|
||||
|
||||
```bash
|
||||
python utils/check_doctest_list.py --fix_and_overwrite
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
|
||||
# All paths are set with the intent you should run this script from the root of the repo with the command
|
||||
# python utils/check_doctest_list.py
|
||||
REPO_PATH = "."
|
||||
DOCTEST_FILE_PATHS = ["not_doctested.txt", "slow_documentation_tests.txt"]
|
||||
|
||||
|
||||
def clean_doctest_list(doctest_file: str, overwrite: bool = False):
    """
    Clean the doctest list stored in a given file.

    Verifies every listed path exists (relative to `REPO_PATH`) and that the
    entries are in alphabetical order, optionally rewriting the file sorted.

    Args:
        doctest_file (`str`):
            The path to the doctest file to check or clean.
        overwrite (`bool`, *optional*, defaults to `False`):
            Whether or not to fix problems. If `False`, will error when the file is not clean.

    Raises:
        ValueError: if a listed path does not exist, or if entries are unsorted
            and `overwrite=False`.
    """
    missing = []
    entries = []
    with open(doctest_file, "r", encoding="utf-8") as f:
        for raw_line in f:
            # Only the first whitespace-separated token on each line is a path.
            # NOTE(review): a blank line yields an empty entry here — confirm the
            # tracked files never contain blank lines.
            entry = raw_line.strip().split(" ")[0]
            candidate = os.path.join(REPO_PATH, entry)
            if not (os.path.isfile(candidate) or os.path.isdir(candidate)):
                missing.append(entry)
            entries.append(entry)

    if missing:
        formatted = "\n".join([f"- {f}" for f in missing])
        raise ValueError(f"`{doctest_file}` contains non-existent paths:\n{formatted}")

    ordered = sorted(entries)
    if entries == ordered:
        return
    if not overwrite:
        raise ValueError(
            f"Files in `{doctest_file}` are not in alphabetical order, run `make fix-copies` to fix "
            "this automatically."
        )
    # Rewrite the file with sorted entries, one per line with a trailing newline.
    with open(doctest_file, "w", encoding="utf-8") as f:
        f.write("\n".join(ordered) + "\n")
# CLI entry point: clean each tracked doctest list file under utils/.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
    args = parser.parse_args()

    for doctest_file in DOCTEST_FILE_PATHS:
        doctest_file = os.path.join(REPO_PATH, "utils", doctest_file)
        clean_doctest_list(doctest_file, args.fix_and_overwrite)
256
transformers/utils/check_dummies.py
Normal file
256
transformers/utils/check_dummies.py
Normal file
@@ -0,0 +1,256 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script is responsible for making sure the dummies in utils/dummies_xxx.py are up to date with the main init.
|
||||
|
||||
Why dummies? This is to make sure that a user can always import all objects from `transformers`, even if they don't
|
||||
have the necessary extra libs installed. Those objects will then raise helpful error message whenever the user tries
|
||||
to access one of their methods.
|
||||
|
||||
Usage (from the root of the repo):
|
||||
|
||||
Check that the dummy files are up to date (used in `make repo-consistency`):
|
||||
|
||||
```bash
|
||||
python utils/check_dummies.py
|
||||
```
|
||||
|
||||
Update the dummy files if needed (used in `make fix-copies`):
|
||||
|
||||
```bash
|
||||
python utils/check_dummies.py --fix_and_overwrite
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
|
||||
# All paths are set with the intent you should run this script from the root of the repo with the command
|
||||
# python utils/check_dummies.py
|
||||
PATH_TO_TRANSFORMERS = "src/transformers"
|
||||
|
||||
# Matches is_xxx_available()
|
||||
_re_backend = re.compile(r"is\_([a-z_]*)_available()")
|
||||
# Matches from xxx import bla
|
||||
_re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
|
||||
# Matches if not is_xxx_available()
|
||||
_re_test_backend = re.compile(r"^\s+if\s+not\s+\(?is\_[a-z_]*\_available\(\)")
|
||||
|
||||
|
||||
# Template for the dummy objects.
|
||||
DUMMY_CONSTANT = """
|
||||
{0} = None
|
||||
"""
|
||||
|
||||
|
||||
DUMMY_CLASS = """
|
||||
class {0}(metaclass=DummyObject):
|
||||
_backends = {1}
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, {1})
|
||||
"""
|
||||
|
||||
|
||||
DUMMY_FUNCTION = """
|
||||
def {0}(*args, **kwargs):
|
||||
requires_backends({0}, {1})
|
||||
"""
|
||||
|
||||
|
||||
def find_backend(line: str) -> Optional[str]:
    """
    Find one (or multiple) backend in a code line of the init.

    Args:
        line (`str`): A code line in an init file.

    Returns:
        Optional[`str`]: If one (or several) backend is found, returns it. In the case of multiple backends (the line
        contains `if is_xxx_available() and `is_yyy_available()`) returns all backends joined on `_and_` (so
        `xxx_and_yyy` for instance).
    """
    is_backend_guard = _re_test_backend.search(line) is not None
    if not is_backend_guard:
        # Not an `if not is_xxx_available()` line: nothing to report.
        return None
    # Sort so the composite key is deterministic regardless of order in the line.
    found = sorted(match[0] for match in _re_backend.findall(line))
    return "_and_".join(found)
def read_init() -> dict[str, list[str]]:
    """
    Read the init and extract backend-specific objects.

    Parses `src/transformers/__init__.py` line by line, starting at the
    `if TYPE_CHECKING` block, and groups imported object names under the
    backend(s) guarding them.

    Returns:
        Dict[str, List[str]]: A dictionary mapping backend name to the list of object names requiring that backend.
    """
    with open(os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), "r", encoding="utf-8", newline="\n") as f:
        lines = f.readlines()

    # Get to the point we do the actual imports for type checking
    line_index = 0
    while not lines[line_index].startswith("if TYPE_CHECKING"):
        line_index += 1

    backend_specific_objects = {}
    # Go through the end of the file
    while line_index < len(lines):
        # If the line is an if is_backend_available, we grab all objects associated.
        backend = find_backend(lines[line_index])
        if backend is not None:
            # Skip past the try/except to the `else:` branch, where the real imports live.
            while not lines[line_index].startswith("    else:"):
                line_index += 1
            line_index += 1

            objects = []
            # Until we unindent, add backend objects to the list
            # (blank lines — length <= 1 — are tolerated inside the block).
            while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8):
                line = lines[line_index]
                single_line_import_search = _re_single_line_import.search(line)
                if single_line_import_search is not None:
                    # Single-line imports
                    objects.extend(single_line_import_search.groups()[0].split(", "))
                elif line.startswith(" " * 12):
                    # Multiple-line imports (with 3 indent level); strip indent and trailing ",\n".
                    objects.append(line[12:-2])
                line_index += 1

            backend_specific_objects[backend] = objects
        else:
            line_index += 1

    return backend_specific_objects
def create_dummy_object(name: str, backend_name: str) -> str:
    """
    Create the code for a dummy object.

    Args:
        name (`str`): The name of the object.
        backend_name (`str`): The name of the backend required for that object.

    Returns:
        `str`: The code of the dummy object.
    """
    # The naming convention encodes the object kind:
    # ALL_CAPS -> constant, all_lowercase -> function, anything else -> class.
    if name.isupper():
        return DUMMY_CONSTANT.format(name)
    if name.islower():
        return DUMMY_FUNCTION.format(name, backend_name)
    return DUMMY_CLASS.format(name, backend_name)
def create_dummy_files(backend_specific_objects: Optional[dict[str, list[str]]] = None) -> dict[str, str]:
    """
    Create the content of the dummy files.

    Args:
        backend_specific_objects (`Dict[str, List[str]]`, *optional*):
            The mapping backend name to list of backend-specific objects. If not passed, will be obtained by calling
            `read_init()`.

    Returns:
        `Dict[str, str]`: A dictionary mapping backend name to code of the corresponding backend file.
    """
    if backend_specific_objects is None:
        backend_specific_objects = read_init()

    files = {}
    for backend, object_names in backend_specific_objects.items():
        # Render the backend key (possibly `xxx_and_yyy`) as a Python list literal.
        quoted = ", ".join(f'"{b}"' for b in backend.split("_and_"))
        backend_list_repr = "[" + quoted + "]"
        header = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
        header += "from ..utils import DummyObject, requires_backends\n\n"
        body = "\n".join(create_dummy_object(o, backend_list_repr) for o in object_names)
        files[backend] = header + body

    return files
def check_dummies(overwrite: bool = False):
    """
    Check if the dummy files are up to date and maybe `overwrite` with the right content.

    Args:
        overwrite (`bool`, *optional*, default to `False`):
            Whether or not to overwrite the content of the dummy files. Will raise an error if they are not up to date
            when `overwrite=False`.

    Raises:
        ValueError: when `overwrite=False` and a dummy file differs from the
            expected autogenerated content.
    """
    dummy_files = create_dummy_files()
    # For special correspondence backend name to shortcut as used in utils/dummy_xxx_objects.py
    short_names = {"torch": "pt"}

    # Locate actual dummy modules and read their content.
    path = os.path.join(PATH_TO_TRANSFORMERS, "utils")
    dummy_file_paths = {
        backend: os.path.join(path, f"dummy_{short_names.get(backend, backend)}_objects.py") for backend in dummy_files
    }

    actual_dummies = {}
    for backend, file_path in dummy_file_paths.items():
        if os.path.isfile(file_path):
            with open(file_path, "r", encoding="utf-8", newline="\n") as f:
                actual_dummies[backend] = f.read()
        else:
            # Missing dummy file: treat as empty so the comparison below flags it.
            actual_dummies[backend] = ""

    # Compare actual with what they should be.
    for backend in dummy_files:
        if dummy_files[backend] != actual_dummies[backend]:
            if overwrite:
                print(
                    f"Updating transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py as the main "
                    "__init__ has new objects."
                )
                with open(dummy_file_paths[backend], "w", encoding="utf-8", newline="\n") as f:
                    f.write(dummy_files[backend])
            else:
                # Temporary fix to help people identify which objects introduced are not correctly protected.
                # NOTE(review): the excerpt below always diffs the *torch* dummies
                # even when the mismatching backend is a different one — confirm
                # this is intentional.
                found = False
                for _actual, _dummy in zip(
                    actual_dummies["torch"].split("class"), dummy_files["torch"].split("class")
                ):
                    if _actual != _dummy:
                        actual_broken = _actual
                        dummy_broken = _dummy
                        found = True
                        break

                if not found:
                    print("A transient error was found with the dummies, please investigate.")
                    continue

                raise ValueError(
                    "The main __init__ has objects that are not present in "
                    f"transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py.\n"
                    f" It is likely the following objects are responsible, see these excerpts: \n"
                    f"---------------------------------- Actual -------------------------------------\n"
                    f" \n {actual_broken} \n"
                    f"---------------------------------- Dummy -------------------------------------\n"
                    f" \n {dummy_broken} \n"
                    "Run `make fix-copies` to fix this."
                )
# CLI entry point: check the dummy files; with --fix_and_overwrite, regenerate them.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
    args = parser.parse_args()

    check_dummies(args.fix_and_overwrite)
354
transformers/utils/check_inits.py
Normal file
354
transformers/utils/check_inits.py
Normal file
@@ -0,0 +1,354 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Utility that checks the custom inits of Transformers are well-defined: Transformers uses init files that delay the
|
||||
import of an object to when it's actually needed. This is to avoid the main init importing all models, which would
|
||||
make the line `import transformers` very slow when the user has all optional dependencies installed. The inits with
|
||||
delayed imports have two halves: one defining a dictionary `_import_structure` which maps modules to the name of the
|
||||
objects in each module, and one in `TYPE_CHECKING` which looks like a normal init for type-checkers. The goal of this
|
||||
script is to check the objects defined in both halves are the same.
|
||||
|
||||
This also checks the main init properly references all submodules, even if it doesn't import anything from them: every
|
||||
submodule should be defined as a key of `_import_structure`, with an empty list as value potentially, or the submodule
|
||||
won't be importable.
|
||||
|
||||
Use from the root of the repo with:
|
||||
|
||||
```bash
|
||||
python utils/check_inits.py
|
||||
```
|
||||
|
||||
for a check that will error in case of inconsistencies (used by `make repo-consistency`).
|
||||
|
||||
There is no auto-fix possible here sadly :-(
|
||||
"""
|
||||
|
||||
import collections
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
# Path is set with the intent you should run this script from the root of the repo.
|
||||
PATH_TO_TRANSFORMERS = "src/transformers"
|
||||
|
||||
|
||||
# Matches is_xxx_available()
|
||||
_re_backend = re.compile(r"is\_([a-z_]*)_available()")
|
||||
# Catches a one-line _import_struct = {xxx}
|
||||
_re_one_line_import_struct = re.compile(r"^_import_structure\s+=\s+\{([^\}]+)\}")
|
||||
# Catches a line with a key-values pattern: "bla": ["foo", "bar"]
|
||||
_re_import_struct_key_value = re.compile(r'\s+"\S*":\s+\[([^\]]*)\]')
|
||||
# Catches a line if not is_foo_available
|
||||
_re_test_backend = re.compile(r"^\s*if\s+not\s+is\_[a-z_]*\_available\(\)")
|
||||
# Catches a line _import_struct["bla"].append("foo")
|
||||
_re_import_struct_add_one = re.compile(r'^\s*_import_structure\["\S*"\]\.append\("(\S*)"\)')
|
||||
# Catches a line _import_struct["bla"].extend(["foo", "bar"]) or _import_struct["bla"] = ["foo", "bar"]
|
||||
_re_import_struct_add_many = re.compile(r"^\s*_import_structure\[\S*\](?:\.extend\(|\s*=\s+)\[([^\]]*)\]")
|
||||
# Catches a line with an object between quotes and a comma: "MyModel",
|
||||
_re_quote_object = re.compile(r'^\s+"([^"]+)",')
|
||||
# Catches a line with objects between brackets only: ["foo", "bar"],
|
||||
_re_between_brackets = re.compile(r"^\s+\[([^\]]+)\]")
|
||||
# Catches a line with from foo import bar, bla, boo
|
||||
_re_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
|
||||
# Catches a line with try:
|
||||
_re_try = re.compile(r"^\s*try:")
|
||||
# Catches a line with else:
|
||||
_re_else = re.compile(r"^\s*else:")
|
||||
|
||||
|
||||
def find_backend(line: str) -> Optional[str]:
    """
    Find one (or multiple) backend in a code line of the init.

    Args:
        line (`str`): A code line of the main init.

    Returns:
        Optional[`str`]: If one (or several) backend is found, returns it. In the case of multiple backends (the line
        contains `if is_xxx_available() and `is_yyy_available()`) returns all backends joined on `_and_` (so
        `xxx_and_yyy` for instance).
    """
    if _re_test_backend.search(line) is None:
        # Line is not a backend-availability guard.
        return None
    names = [match[0] for match in _re_backend.findall(line)]
    names.sort()
    # Multiple backends on one line collapse into a single `xxx_and_yyy` key.
    return "_and_".join(names)
def parse_init(init_file: str) -> Optional[tuple[dict[str, list[str]], dict[str, list[str]]]]:
    """
    Read an init_file and parse (per backend) the `_import_structure` objects defined and the `TYPE_CHECKING` objects
    defined.

    Args:
        init_file (`str`): Path to the init file to inspect.

    Returns:
        `Optional[Tuple[Dict[str, List[str]], Dict[str, List[str]]]]`: A tuple of two dictionaries mapping backends to list of
        imported objects, one for the `_import_structure` part of the init and one for the `TYPE_CHECKING` part of the
        init. Returns `None` if the init is not a custom init.
    """
    with open(init_file, "r", encoding="utf-8", newline="\n") as f:
        lines = f.readlines()

    # Get the to `_import_structure` definition.
    line_index = 0
    while line_index < len(lines) and not lines[line_index].startswith("_import_structure = {"):
        line_index += 1

    # If this is a traditional init, just return.
    if line_index >= len(lines):
        return None

    # First grab the objects without a specific backend in _import_structure
    objects = []
    while not lines[line_index].startswith("if TYPE_CHECKING") and find_backend(lines[line_index]) is None:
        line = lines[line_index]
        # If we have everything on a single line, let's deal with it.
        if _re_one_line_import_struct.search(line):
            content = _re_one_line_import_struct.search(line).groups()[0]
            # Each bracketed group holds the quoted object names for one module key.
            imports = re.findall(r"\[([^\]]+)\]", content)
            for imp in imports:
                # obj[1:-1] strips the surrounding quotes.
                objects.extend([obj[1:-1] for obj in imp.split(", ")])
            line_index += 1
            continue
        single_line_import_search = _re_import_struct_key_value.search(line)
        if single_line_import_search is not None:
            imports = [obj[1:-1] for obj in single_line_import_search.groups()[0].split(", ") if len(obj) > 0]
            objects.extend(imports)
        elif line.startswith(" " * 8 + '"'):
            # One quoted object per line at 8-space indent; strip indent+quote and trailing '",\n'.
            objects.append(line[9:-3])
        line_index += 1

    # Those are stored with the key "none".
    import_dict_objects = {"none": objects}

    # Let's continue with backend-specific objects in _import_structure
    while not lines[line_index].startswith("if TYPE_CHECKING"):
        # If the line is an if not is_backend_available, we grab all objects associated.
        backend = find_backend(lines[line_index])
        # Check if the backend declaration is inside a try block:
        if _re_try.search(lines[line_index - 1]) is None:
            backend = None

        if backend is not None:
            line_index += 1

            # Scroll until we hit the else block of try-except-else
            while _re_else.search(lines[line_index]) is None:
                line_index += 1

            line_index += 1

            objects = []
            # Until we unindent, add backend objects to the list
            # (blank lines — length <= 1 — are tolerated inside the block).
            while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 4):
                line = lines[line_index]
                if _re_import_struct_add_one.search(line) is not None:
                    objects.append(_re_import_struct_add_one.search(line).groups()[0])
                elif _re_import_struct_add_many.search(line) is not None:
                    imports = _re_import_struct_add_many.search(line).groups()[0].split(", ")
                    imports = [obj[1:-1] for obj in imports if len(obj) > 0]
                    objects.extend(imports)
                elif _re_between_brackets.search(line) is not None:
                    imports = _re_between_brackets.search(line).groups()[0].split(", ")
                    imports = [obj[1:-1] for obj in imports if len(obj) > 0]
                    objects.extend(imports)
                elif _re_quote_object.search(line) is not None:
                    objects.append(_re_quote_object.search(line).groups()[0])
                elif line.startswith(" " * 8 + '"'):
                    objects.append(line[9:-3])
                elif line.startswith(" " * 12 + '"'):
                    objects.append(line[13:-3])
                line_index += 1

            import_dict_objects[backend] = objects
        else:
            line_index += 1

    # At this stage we are in the TYPE_CHECKING part, first grab the objects without a specific backend
    objects = []
    while (
        line_index < len(lines)
        and find_backend(lines[line_index]) is None
        and not lines[line_index].startswith("else")
    ):
        line = lines[line_index]
        single_line_import_search = _re_import.search(line)
        if single_line_import_search is not None:
            objects.extend(single_line_import_search.groups()[0].split(", "))
        elif line.startswith(" " * 8):
            objects.append(line[8:-2])
        line_index += 1

    type_hint_objects = {"none": objects}

    # Let's continue with backend-specific objects
    while line_index < len(lines):
        # If the line is an if is_backend_available, we grab all objects associated.
        backend = find_backend(lines[line_index])
        # Check if the backend declaration is inside a try block:
        if _re_try.search(lines[line_index - 1]) is None:
            backend = None

        if backend is not None:
            line_index += 1

            # Scroll until we hit the else block of try-except-else
            while _re_else.search(lines[line_index]) is None:
                line_index += 1

            line_index += 1

            objects = []
            # Until we unindent, add backend objects to the list
            while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8):
                line = lines[line_index]
                single_line_import_search = _re_import.search(line)
                if single_line_import_search is not None:
                    objects.extend(single_line_import_search.groups()[0].split(", "))
                elif line.startswith(" " * 12):
                    objects.append(line[12:-2])
                line_index += 1

            type_hint_objects[backend] = objects
        else:
            line_index += 1

    return import_dict_objects, type_hint_objects
def analyze_results(import_dict_objects: dict[str, list[str]], type_hint_objects: dict[str, list[str]]) -> list[str]:
    """
    Analyze the differences between _import_structure objects and TYPE_CHECKING objects found in an init.

    Args:
        import_dict_objects (`Dict[str, List[str]]`):
            A dictionary mapping backend names (`"none"` for the objects independent of any specific backend) to
            list of imported objects.
        type_hint_objects (`Dict[str, List[str]]`):
            A dictionary mapping backend names (`"none"` for the objects independent of any specific backend) to
            list of imported objects.

    Returns:
        `List[str]`: The list of errors corresponding to mismatches.
    """

    def _duplicates(names):
        # Names occurring more than once in the given list.
        return [name for name, count in collections.Counter(names).items() if count > 1]

    # If one backend is missing from the other part of the init, error early.
    if list(import_dict_objects.keys()) != list(type_hint_objects.keys()):
        return ["Both sides of the init do not have the same backends!"]

    errors = []
    for backend in import_dict_objects:
        imported = import_dict_objects[backend]
        type_hinted = type_hint_objects[backend]

        # Duplicate imports in either half.
        dup_imports = _duplicates(imported)
        if dup_imports:
            errors.append(f"Duplicate _import_structure definitions for: {dup_imports}")
        dup_hints = _duplicates(type_hinted)
        if dup_hints:
            errors.append(f"Duplicate TYPE_CHECKING objects for: {dup_hints}")

        # Objects present in one half but missing from the other.
        if sorted(set(imported)) != sorted(set(type_hinted)):
            name = "base imports" if backend == "none" else f"{backend} backend"
            errors.append(f"Differences for {name}:")
            errors.extend(
                f"  {obj} in TYPE_HINT but not in _import_structure." for obj in type_hinted if obj not in imported
            )
            errors.extend(
                f"  {obj} in _import_structure but not in TYPE_HINT." for obj in imported if obj not in type_hinted
            )
    return errors
def get_transformers_submodules() -> list[str]:
    """
    Returns the list of Transformers submodules.

    Walks `PATH_TO_TRANSFORMERS` and collects:
    - every non-private package folder containing at least one `.py` file, as a
      dotted module path (e.g. `models.bert`);
    - every top-level `.py` module except `__init__.py`.
    """
    submodules = []
    for path, directories, files in os.walk(PATH_TO_TRANSFORMERS):
        # Iterate over a snapshot: calling `directories.remove(...)` while
        # iterating `directories` directly would skip the entry that follows
        # each removed one. Mutating `directories` itself is still required so
        # that `os.walk` prunes the private subtree from the traversal.
        for folder in list(directories):
            # Ignore private modules
            if folder.startswith("_"):
                directories.remove(folder)
                continue
            # Ignore leftovers from branches (empty folders apart from pycache)
            if len(list((Path(path) / folder).glob("*.py"))) == 0:
                continue
            short_path = str((Path(path) / folder).relative_to(PATH_TO_TRANSFORMERS))
            submodule = short_path.replace(os.path.sep, ".")
            submodules.append(submodule)
        for fname in files:
            if fname == "__init__.py":
                continue
            short_path = str((Path(path) / fname).relative_to(PATH_TO_TRANSFORMERS))
            # `removesuffix` strips only a trailing ".py", unlike `replace`,
            # which would also mangle a ".py" occurring mid-name.
            submodule = short_path.removesuffix(".py").replace(os.path.sep, ".")
            if len(submodule.split(".")) == 1:
                submodules.append(submodule)
    return submodules
|
||||
|
||||
|
||||
# Submodules deliberately exempt from the main-init registration check
# performed by `check_submodules` below.
IGNORE_SUBMODULES = [
    "convert_pytorch_checkpoint_to_tf2",
    "models.esm.openfold_utils",
    "modeling_attn_mask_utils",
    "safetensors_conversion",
    "modeling_gguf_pytorch_utils",
    "kernels.falcon_mamba",
    "kernels",
]
|
||||
|
||||
|
||||
def check_submodules():
    """
    Check all submodules of Transformers are properly registered in the main init. Error otherwise.
    """
    # This is to make sure the transformers module imported is the one in the repo.
    from transformers.utils import direct_transformers_import

    transformers = direct_transformers_import(PATH_TO_TRANSFORMERS)

    import_structure_keys = set(transformers._import_structure.keys())
    # This contains all the base keys of the _import_structure object defined in the init, but if the user is missing
    # some optional dependencies, they may not have all of them. Thus we read the init to read all additions and
    # (potentially re-) add them.
    with open(os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), "r") as f:
        init_content = f.read()
    import_structure_keys.update(set(re.findall(r"import_structure\[\"([^\"]*)\"\]", init_content)))

    # A submodule is fine if it is either explicitly ignored or appears as a key
    # of `_import_structure` (directly or via a string addition in the init).
    module_not_registered = [
        module
        for module in get_transformers_submodules()
        if module not in IGNORE_SUBMODULES and module not in import_structure_keys
    ]

    if len(module_not_registered) > 0:
        list_of_modules = "\n".join(f"- {module}" for module in module_not_registered)
        raise ValueError(
            "The following submodules are not properly registered in the main init of Transformers:\n"
            f"{list_of_modules}\n"
            "Make sure they appear somewhere in the keys of `_import_structure` with an empty list as value."
        )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # This entire file needs an overhaul; intentionally a no-op for now.
    pass
|
||||
59
transformers/utils/check_model_tester.py
Normal file
59
transformers/utils/check_model_tester.py
Normal file
@@ -0,0 +1,59 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import glob
|
||||
import os
|
||||
|
||||
from get_test_info import get_tester_classes
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Scan every modeling test file and flag tester classes whose `get_config`
    # produces configs that are needlessly large for unit testing.
    failures = []

    pattern = os.path.join("tests", "models", "**", "test_modeling_*.py")
    test_files = glob.glob(pattern)

    for test_file in test_files:
        tester_classes = get_tester_classes(test_file)
        for tester_class in tester_classes:
            # A few tester classes don't have `parent` parameter in `__init__`.
            # TODO: deal with this better
            try:
                tester = tester_class(parent=None)
            except Exception:
                continue
            if hasattr(tester, "get_config"):
                config = tester.get_config()
                for k, v in config.to_dict().items():
                    if isinstance(v, int):
                        # Upper bound allowed for each size-related config key.
                        target = None
                        if k in ["vocab_size"]:
                            target = 100
                        elif k in ["max_position_embeddings"]:
                            target = 128
                        elif k in ["hidden_size", "d_model"]:
                            target = 40
                        # BUGFIX: this previously compared the key to the whole
                        # list (`k == [...]`), which is always False, so the
                        # layer-count checks never ran.
                        elif k in ["num_layers", "num_hidden_layers", "num_encoder_layers", "num_decoder_layers"]:
                            target = 5
                        if target is not None and v > target:
                            failures.append(
                                f"{tester_class.__name__} will produce a `config` of type `{config.__class__.__name__}`"
                                f' with config["{k}"] = {v} which is too large for testing! Set its value to be smaller'
                                f" than {target}."
                            )

    if len(failures) > 0:
        raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
|
||||
232
transformers/utils/check_modular_conversion.py
Normal file
232
transformers/utils/check_modular_conversion.py
Normal file
@@ -0,0 +1,232 @@
|
||||
import argparse
|
||||
import difflib
|
||||
import glob
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
from functools import partial
|
||||
from io import StringIO
|
||||
|
||||
from create_dependency_mapping import find_priority_list
|
||||
|
||||
# Console for rich printing
|
||||
from modular_model_converter import convert_modular_file
|
||||
from rich.console import Console
|
||||
from rich.syntax import Syntax
|
||||
|
||||
|
||||
logging.basicConfig()
|
||||
logging.getLogger().setLevel(logging.ERROR)
|
||||
console = Console()
|
||||
|
||||
BACKUP_EXT = ".modular_backup"
|
||||
|
||||
|
||||
def process_file(
    modular_file_path,
    generated_modeling_content,
    file_type="modeling_",
    show_diff=True,
):
    """Compare (and overwrite) one generated file against the version on disk.

    Args:
        modular_file_path: Path to the `modular_xxx.py` source file.
        generated_modeling_content: Mapping of file type -> generated source code.
        file_type: Which generated file to process; a `*` splits it into a
            filename prefix and suffix.
        show_diff: Whether to print a rich diff when the files differ.

    Returns:
        1 if the on-disk file differed from the generated content, 0 otherwise.
        When they differ, the on-disk file is overwritten with the generated
        content and a backup copy (`*` + BACKUP_EXT) is kept for restoration.
    """
    file_name_prefix = file_type.split("*")[0]
    file_name_suffix = file_type.split("*")[-1] if "*" in file_type else ""
    file_path = modular_file_path.replace("modular_", f"{file_name_prefix}_").replace(".py", f"{file_name_suffix}.py")
    # Read the actual modeling file
    with open(file_path, "r", encoding="utf-8") as modeling_file:
        content = modeling_file.read()
    # The generated content is already a string; no buffering round-trip needed
    # (previously went through a pointless intermediate `StringIO`).
    output_content = generated_modeling_content[file_type]
    diff = difflib.unified_diff(
        output_content.splitlines(),
        content.splitlines(),
        fromfile=f"{file_path}_generated",
        tofile=f"{file_path}",
        lineterm="",
    )
    diff_list = list(diff)
    # Check for differences
    if diff_list:
        # first save the copy of the original file, to be able to restore it later
        if os.path.exists(file_path):
            shutil.copy(file_path, file_path + BACKUP_EXT)
        # we always save the generated content, to be able to update dependent files
        with open(file_path, "w", encoding="utf-8", newline="\n") as modeling_file:
            modeling_file.write(generated_modeling_content[file_type])
        console.print(f"[bold blue]Overwritten {file_path} with the generated content.[/bold blue]")
        if show_diff:
            console.print(f"\n[bold red]Differences found between the generated code and {file_path}:[/bold red]\n")
            diff_text = "\n".join(diff_list)
            syntax = Syntax(diff_text, "diff", theme="ansi_dark", line_numbers=True)
            console.print(syntax)
        return 1
    else:
        console.print(f"[bold green]No differences found for {file_path}.[/bold green]")
        return 0
|
||||
|
||||
|
||||
def compare_files(modular_file_path, show_diff=True):
    """Convert one modular file and compare every generated output to disk.

    Returns the number of generated files that did not match their on-disk
    counterpart (each mismatch is also overwritten by `process_file`).
    """
    # Generate the expected modeling content from the modular source.
    generated = convert_modular_file(modular_file_path)
    mismatches = sum(
        process_file(modular_file_path, generated, file_type, show_diff) for file_type in generated
    )
    return mismatches
|
||||
|
||||
|
||||
def get_models_in_diff():
    """
    Finds all models that have been modified in the diff.

    Returns:
        A set containing the names of the models that have been modified (e.g. {'llama', 'whisper'}).
    """
    # Diff against the fork point with main so only this branch's changes count.
    fork_point_sha = subprocess.check_output(["git", "merge-base", "main", "HEAD"]).decode("utf-8").strip()
    diff_output = subprocess.check_output(
        ["git", "diff", "--diff-filter=d", "--name-only", fork_point_sha]
    ).decode("utf-8")

    # Matches both modeling files and tests: anything under a models/ folder.
    return {
        file_path.split("/")[-2]
        for file_path in diff_output.split()
        if "/models/" in file_path and file_path.endswith(".py")
    }
|
||||
|
||||
|
||||
def guaranteed_no_diff(modular_file_path, dependencies, models_in_diff):
    """
    Returns whether it is guaranteed to have no differences between the modular file and the modeling file.

    Model is in the diff -> not guaranteed to have no differences
    Dependency is in the diff -> not guaranteed to have no differences
    Otherwise -> guaranteed to have no differences

    Args:
        modular_file_path: The path to the modular file.
        dependencies: A dictionary containing the dependencies of each modular file.
        models_in_diff: A set containing the names of the models that have been modified.

    Returns:
        A boolean indicating whether the model (code and tests) is guaranteed to have no differences.
    """
    model_name = modular_file_path.rsplit("modular_", 1)[1].replace(".py", "")
    if model_name in models_in_diff:
        return False
    # two possible patterns: `transformers.models.model_name.(...)` or `model_name.(...)`
    return not any(
        dependency.split(".")[-2] in models_in_diff for dependency in dependencies[modular_file_path]
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Entry point: run modular -> modeling conversion checks, level by level of
    # the dependency graph, skipping models untouched by the current diff.
    parser = argparse.ArgumentParser(description="Compare modular_xxx.py files with modeling_xxx.py files.")
    parser.add_argument(
        "--files", default=["all"], type=str, nargs="+", help="List of modular_xxx.py files to compare."
    )
    parser.add_argument(
        "--fix_and_overwrite", action="store_true", help="Overwrite the modeling_xxx.py file if differences are found."
    )
    parser.add_argument("--check_all", action="store_true", help="Check all files, not just the ones in the diff.")
    parser.add_argument(
        "--num_workers",
        default=-1,
        type=int,
        help="The number of workers to run. Default is -1, which means the number of CPU cores.",
    )
    args = parser.parse_args()
    if args.files == ["all"]:
        args.files = glob.glob("src/transformers/models/**/modular_*.py", recursive=True)

    if args.num_workers == -1:
        args.num_workers = multiprocessing.cpu_count()

    # Assuming there is a topological sort on the dependency mapping: if the file being checked and its dependencies
    # are not in the diff, then there it is guaranteed to have no differences. If no models are in the diff, then this
    # script will do nothing.
    current_branch = subprocess.check_output(["git", "branch", "--show-current"], text=True).strip()
    if current_branch == "main":
        console.print(
            "[bold red]You are developing on the main branch. We cannot identify the list of changed files and will have to check all files. This may take a while.[/bold red]"
        )
        models_in_diff = {file_path.split("/")[-2] for file_path in args.files}
    else:
        models_in_diff = get_models_in_diff()
        if not models_in_diff and not args.check_all:
            console.print(
                "[bold green]No models files or model tests in the diff, skipping modular checks[/bold green]"
            )
            exit(0)

    skipped_models = set()
    non_matching_files = []
    ordered_files, dependencies = find_priority_list(args.files)
    flat_ordered_files = [item for sublist in ordered_files for item in sublist]

    # ordered_files is a *sorted* list of lists of filepaths
    # - files from the first list do NOT depend on other files
    # - files in the second list depend on files from the first list
    # - files in the third list depend on files from the second and (optionally) the first list
    # - ... and so on
    # files (models) within the same list are *independent* of each other;
    # we start applying modular conversion to each list in parallel, starting from the first list

    console.print(f"[bold yellow]Number of dependency levels: {len(ordered_files)}[/bold yellow]")
    console.print(f"[bold yellow]Files per level: {tuple([len(x) for x in ordered_files])}[/bold yellow]")

    try:
        for dependency_level_files in ordered_files:
            # Filter files guaranteed no diff
            files_to_check = []
            for file_path in dependency_level_files:
                if not args.check_all and guaranteed_no_diff(file_path, dependencies, models_in_diff):
                    skipped_models.add(file_path.split("/")[-2])  # save model folder name
                else:
                    files_to_check.append(file_path)

            if not files_to_check:
                continue

            # Process files with diff
            num_workers = min(args.num_workers, len(files_to_check))
            with multiprocessing.Pool(num_workers) as p:
                is_changed_flags = p.map(
                    partial(compare_files, show_diff=not args.fix_and_overwrite),
                    files_to_check,
                )

            # Collect changed files and their original paths
            for is_changed, file_path in zip(is_changed_flags, files_to_check):
                if is_changed:
                    non_matching_files.append(file_path)

                    # Update changed models, after each round of conversions
                    # (save model folder name)
                    models_in_diff.add(file_path.split("/")[-2])

    finally:
        # Restore overwritten files by modular (if needed)
        backup_files = glob.glob("**/*" + BACKUP_EXT, recursive=True)
        for backup_file_path in backup_files:
            overwritten_path = backup_file_path.replace(BACKUP_EXT, "")
            if not args.fix_and_overwrite and os.path.exists(overwritten_path):
                shutil.copy(backup_file_path, overwritten_path)
            os.remove(backup_file_path)

    if non_matching_files and not args.fix_and_overwrite:
        diff_models = set(file_path.split("/")[-2] for file_path in non_matching_files)  # noqa
        models_str = "\n - " + "\n - ".join(sorted(diff_models))
        raise ValueError(f"Some diff and their modeling code did not match. Models in diff:{models_str}")

    if skipped_models:
        console.print(
            f"[bold green]Skipped {len(skipped_models)} models and their dependencies that are not in the diff: "
            f"{', '.join(sorted(skipped_models))}[/bold green]"
        )
|
||||
93
transformers/utils/check_pipeline_typing.py
Normal file
93
transformers/utils/check_pipeline_typing.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import re
|
||||
|
||||
from transformers.pipelines import SUPPORTED_TASKS, Pipeline
|
||||
|
||||
|
||||
# Boilerplate inserted before the generated `pipeline` overloads.
HEADER = """
# fmt: off
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# The part of the file below was automatically generated from the code.
# Do NOT edit this part of the file manually as any edits will be overwritten by the generation
# of the file. If any change should be done, please apply the changes to the `pipeline` function
# below and run `python utils/check_pipeline_typing.py --fix_and_overwrite` to update the file.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨

from typing import Literal, overload


"""

# Boilerplate appended after the generated `pipeline` overloads.
FOOTER = """
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# The part of the file above was automatically generated from the code.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# fmt: on
"""

# Fragment of the real `pipeline` signature that gets specialized to a
# `Literal[...]` task type in each generated overload.
TASK_PATTERN = "task: Optional[str] = None"
|
||||
|
||||
|
||||
def main(pipeline_file_path: str, fix_and_overwrite: bool = False):
    """Regenerate the typed `@overload` stubs for `pipeline` and sync the file.

    Reads `pipeline_file_path`, rebuilds one overload per supported task from
    the real `pipeline` signature, and either rewrites the generated section in
    place (when `fix_and_overwrite` is True) or raises `ValueError` when the
    section is out of date.
    """
    with open(pipeline_file_path, "r") as file:
        content = file.read()

    # extract generated code in between <generated-code> and </generated-code>
    current_generated_code = re.search(r"# <generated-code>(.*)# </generated-code>", content, re.DOTALL).group(1)
    content_without_generated_code = content.replace(current_generated_code, "")

    # extract pipeline signature in between `def pipeline` and `-> Pipeline`
    pipeline_signature = re.search(r"def pipeline(.*) -> Pipeline:", content_without_generated_code, re.DOTALL).group(
        1
    )
    # Collapse the multi-line signature onto a single line.
    pipeline_signature = pipeline_signature.replace("(\n    ", "(")  # start of the signature
    pipeline_signature = pipeline_signature.replace(",\n    ", ", ")  # intermediate arguments
    pipeline_signature = pipeline_signature.replace(",\n)", ")")  # end of the signature

    # collect and sort available pipelines
    pipelines = [(f'"{task}"', task_info["impl"]) for task, task_info in SUPPORTED_TASKS.items()]
    pipelines = sorted(pipelines, key=lambda x: x[0])
    # `task=None` maps to the generic `Pipeline` return type; keep it first.
    pipelines.insert(0, (None, Pipeline))

    # generate new `pipeline` signatures
    new_generated_code = ""
    for task, pipeline_class in pipelines:
        if TASK_PATTERN not in pipeline_signature:
            raise ValueError(f"Can't find `{TASK_PATTERN}` in pipeline signature: {pipeline_signature}")
        pipeline_type = pipeline_class if isinstance(pipeline_class, str) else pipeline_class.__name__
        new_pipeline_signature = pipeline_signature.replace(TASK_PATTERN, f"task: Literal[{task}]")
        new_generated_code += f"@overload\ndef pipeline{new_pipeline_signature} -> {pipeline_type}: ...\n"

    new_generated_code = HEADER + new_generated_code + FOOTER
    new_generated_code = new_generated_code.rstrip("\n") + "\n"

    if new_generated_code != current_generated_code and fix_and_overwrite:
        print(f"Updating {pipeline_file_path}...")
        wrapped_current_generated_code = "# <generated-code>" + current_generated_code + "# </generated-code>"
        wrapped_new_generated_code = "# <generated-code>" + new_generated_code + "# </generated-code>"
        content = content.replace(wrapped_current_generated_code, wrapped_new_generated_code)

        # write content to file
        with open(pipeline_file_path, "w") as file:
            file.write(content)

    elif new_generated_code != current_generated_code and not fix_and_overwrite:
        message = (
            f"Found inconsistencies in {pipeline_file_path}. "
            "Run `python utils/check_pipeline_typing.py --fix_and_overwrite` to fix them."
        )
        raise ValueError(message)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import argparse

    # CLI: check (default) or regenerate the typed `pipeline` overloads.
    parser = argparse.ArgumentParser()
    parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
    parser.add_argument(
        "--pipeline_file_path",
        type=str,
        default="src/transformers/pipelines/__init__.py",
        help="Path to the pipeline file.",
    )
    args = parser.parse_args()
    main(args.pipeline_file_path, args.fix_and_overwrite)
|
||||
1191
transformers/utils/check_repo.py
Normal file
1191
transformers/utils/check_repo.py
Normal file
File diff suppressed because it is too large
Load Diff
57
transformers/utils/check_self_hosted_runner.py
Normal file
57
transformers/utils/check_self_hosted_runner.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import argparse
|
||||
import json
|
||||
import subprocess
|
||||
|
||||
|
||||
def get_runner_status(target_runners, token):
    """Query the GitHub API for self-hosted runner status and record offline ones.

    Args:
        target_runners: Runner names to check.
        token: GitHub token with `actions:read` permission.

    Writes the offline runners (JSON) to `offline_runners.txt` so they can be
    reported on Slack, then raises `ValueError` if any target runner is offline.
    """
    offline_runners = []

    cmd = [
        "curl",
        "-H",
        "Accept: application/vnd.github+json",
        "-H",
        f"Authorization: Bearer {token}",
        "https://api.github.com/repos/huggingface/transformers/actions/runners",
    ]

    # BUGFIX: `shell=True` combined with a list argument runs only `cmd[0]`
    # ("curl" with no arguments) on POSIX. Passing the list without a shell
    # forwards every argument to curl and avoids shell injection via `token`.
    output = subprocess.run(cmd, check=False, stdout=subprocess.PIPE)
    o = output.stdout.decode("utf-8")
    status = json.loads(o)

    runners = status["runners"]
    for runner in runners:
        if runner["name"] in target_runners:
            if runner["status"] == "offline":
                offline_runners.append(runner)

    # save the result so we can report them on Slack
    with open("offline_runners.txt", "w") as fp:
        fp.write(json.dumps(offline_runners))

    if len(offline_runners) > 0:
        failed = "\n".join([x["name"] for x in offline_runners])
        raise ValueError(f"The following runners are offline:\n{failed}")
|
||||
|
||||
|
||||
if __name__ == "__main__":

    def list_str(values):
        # argparse `type=`: parse a comma-separated string into a list of names.
        return values.split(",")

    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--target_runners",
        default=None,
        type=list_str,
        required=True,
        help="Comma-separated list of runners to check status.",
    )

    parser.add_argument(
        "--token", default=None, type=str, required=True, help="A token that has actions:read permission."
    )
    args = parser.parse_args()

    get_runner_status(args.target_runners, args.token)
|
||||
217
transformers/utils/collated_reports.py
Normal file
217
transformers/utils/collated_reports.py
Normal file
@@ -0,0 +1,217 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import subprocess
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
DEFAULT_GPU_NAMES = ["mi300", "mi325", "mi355", "h100", "a10"]
|
||||
|
||||
|
||||
def simplify_gpu_name(gpu_name: str, simplified_names: list[str]) -> str:
    """Map a verbose GPU name to a short alias when exactly one alias matches.

    If zero or several aliases are substrings of `gpu_name`, the original name
    is returned unchanged (ambiguous matches are not guessed at).
    """
    hits = [alias for alias in simplified_names if alias in gpu_name]
    return hits[0] if len(hits) == 1 else gpu_name
|
||||
|
||||
|
||||
def parse_short_summary_line(line: str) -> tuple[str | None, int]:
|
||||
if line.startswith("PASSED"):
|
||||
return "passed", 1
|
||||
if line.startswith("FAILED"):
|
||||
return "failed", 1
|
||||
if line.startswith("SKIPPED"):
|
||||
line = line.split("[", maxsplit=1)[1]
|
||||
line = line.split("]", maxsplit=1)[0]
|
||||
return "skipped", int(line)
|
||||
if line.startswith("ERROR"):
|
||||
return "error", 1
|
||||
return None, 0
|
||||
|
||||
|
||||
def validate_path(p: str) -> Path:
    """Convert `p` to a `Path`, ensuring it refers to an existing directory.

    Raises:
        NotADirectoryError: if `p` does not exist or is not a directory.
    """
    path = Path(p)
    # An `assert` would be stripped when running under `python -O`, so raise an
    # explicit exception instead to keep the validation unconditional.
    if not path.is_dir():
        raise NotADirectoryError(f"Path {path} is not a directory")
    return path
|
||||
|
||||
|
||||
def get_gpu_name(gpu_name: str | None) -> str:
    """Resolve the GPU name: normalize a provided one, or probe torch for it."""
    # Get GPU name if available
    if gpu_name is None:
        try:
            import torch

            gpu_name = torch.cuda.get_device_name()
        except Exception as e:
            # Best-effort: keep going with a placeholder when no GPU/torch.
            print(f"Failed to get GPU name with {e}")
            gpu_name = "unknown"
    else:
        # Normalize a user-supplied name, then shorten it to a known alias.
        gpu_name = gpu_name.replace(" ", "_").lower()
        gpu_name = simplify_gpu_name(gpu_name, DEFAULT_GPU_NAMES)

    return gpu_name
|
||||
|
||||
|
||||
def get_commit_hash(commit_hash: str | None) -> str:
|
||||
# Get commit hash if available
|
||||
if commit_hash is None:
|
||||
try:
|
||||
commit_hash = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("utf-8").strip()
|
||||
except Exception as e:
|
||||
print(f"Failed to get commit hash with {e}")
|
||||
commit_hash = "unknown"
|
||||
|
||||
return commit_hash[:7]
|
||||
|
||||
|
||||
@dataclass
class Args:
    """Validated command-line arguments for the collated-reports script."""

    # Directory containing the per-model test report folders
    path: Path
    # Either "single-gpu" or "multi-gpu"; used to filter report folders
    machine_type: str
    # Simplified GPU name (e.g. "h100"); auto-detected when not provided
    gpu_name: str
    # Short (7-char) commit hash the reports were produced from
    commit_hash: str
    # CI job name; only needed when uploading the collated report
    job: str | None
    # Hub dataset repo to upload to; only needed when uploading
    report_repo_id: str | None
|
||||
|
||||
def get_arguments(args: argparse.Namespace) -> Args:
    """Convert the raw argparse namespace into a validated `Args` record."""
    return Args(
        path=validate_path(args.path),
        machine_type=args.machine_type,
        gpu_name=get_gpu_name(args.gpu_name),
        commit_hash=get_commit_hash(args.commit_hash),
        job=args.job,
        report_repo_id=args.report_repo_id,
    )
|
||||
|
||||
|
||||
def upload_collated_report(job: str, report_repo_id: str, filename: str):
    """Upload the collated report JSON to the CI results dataset repo on the Hub.

    Args:
        job: CI job name; determines the `ci_results_{job}` target folder.
        report_repo_id: Dataset repository ID to upload to.
        filename: Local path of the collated report file to upload.
    """
    # Alternatively we can check for the existence of the collated_reports file and upload in notification_service.py
    import os

    from get_previous_daily_ci import get_last_daily_ci_run
    from huggingface_hub import HfApi

    api = HfApi()

    # if it is not a scheduled run, upload the reports to a subfolder under `report_repo_folder`
    report_repo_subfolder = ""
    if os.getenv("GITHUB_EVENT_NAME") != "schedule":
        report_repo_subfolder = f"{os.getenv('GITHUB_RUN_NUMBER')}-{os.getenv('GITHUB_RUN_ID')}"
        report_repo_subfolder = f"runs/{report_repo_subfolder}"

    workflow_run = get_last_daily_ci_run(
        token=os.environ["ACCESS_REPO_INFO_TOKEN"], workflow_run_id=os.getenv("GITHUB_RUN_ID")
    )
    workflow_run_created_time = workflow_run["created_at"]
    report_repo_folder = workflow_run_created_time.split("T")[0]

    if report_repo_subfolder:
        report_repo_folder = f"{report_repo_folder}/{report_repo_subfolder}"

    # BUGFIX: the `filename` argument was previously unused and a literal
    # placeholder was passed to `upload_file`; use the actual report file for
    # both the local source and its path inside the repo.
    api.upload_file(
        path_or_fileobj=filename,
        path_in_repo=f"{report_repo_folder}/ci_results_{job}/{filename}",
        repo_id=report_repo_id,
        repo_type="dataset",
        token=os.getenv("TRANSFORMERS_CI_RESULTS_UPLOAD_TOKEN"),
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Entry point: fold every per-model `summary_short.txt` into one JSON
    # collated report, then optionally upload it to the Hub.
    parser = argparse.ArgumentParser(description="Post process models test reports.")
    parser.add_argument("--path", "-p", help="Path to the reports folder")
    parser.add_argument(
        "--machine-type", "-m", help="Process single or multi GPU results", choices=["single-gpu", "multi-gpu"]
    )
    parser.add_argument("--gpu-name", "-g", help="GPU name", default=None)
    parser.add_argument("--commit-hash", "-c", help="Commit hash", default=None)
    parser.add_argument("--job", "-j", help="Optional job name required for uploading reports", default=None)
    parser.add_argument(
        "--report-repo-id", "-r", help="Optional report repository ID required for uploading reports", default=None
    )
    args = get_arguments(parser.parse_args())

    # Initialize accumulators for collated report
    # (the `None` key counts lines `parse_short_summary_line` does not recognize)
    total_status_count = {
        "passed": 0,
        "failed": 0,
        "skipped": 0,
        "error": 0,
        None: 0,
    }
    collated_report_buffer = []

    path = args.path
    machine_type = args.machine_type
    gpu_name = args.gpu_name
    commit_hash = args.commit_hash
    job = args.job
    report_repo_id = args.report_repo_id

    # Loop through model directories and create collated reports
    for model_dir in sorted(path.iterdir()):
        if not model_dir.name.startswith(machine_type):
            continue

        # Create a new entry for the model
        model_name = model_dir.name.split("models_")[-1].removesuffix("_test_reports")
        report = {"model": model_name, "results": []}
        results = []

        # Read short summary
        with open(model_dir / "summary_short.txt", "r") as f:
            short_summary_lines = f.readlines()

        # Parse short summary (first line is the section header, skip it)
        for line in short_summary_lines[1:]:
            status, count = parse_short_summary_line(line)
            total_status_count[status] += count
            if status:
                result = {
                    "status": status,
                    "test": line.split(status.upper(), maxsplit=1)[1].strip(),
                    "count": count,
                }
                results.append(result)

        # Add short summaries to report
        report["results"] = results

        collated_report_buffer.append(report)

    filename = f"collated_reports_{machine_type}_{commit_hash}.json"
    # Write collated report
    with open(filename, "w") as f:
        json.dump(
            {
                "gpu_name": gpu_name,
                "machine_type": machine_type,
                "commit_hash": commit_hash,
                "total_status_count": total_status_count,
                "results": collated_report_buffer,
            },
            f,
            indent=2,
        )

    # Upload collated report
    if job and report_repo_id:
        upload_collated_report(job, report_repo_id, filename)
|
||||
91
transformers/utils/compare_test_runs.py
Normal file
91
transformers/utils/compare_test_runs.py
Normal file
@@ -0,0 +1,91 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
|
||||
|
||||
def normalize_test_line(line):
    """Reduce a pytest summary line to a stable, comparable form.

    SKIPPED/XFAIL/XPASS/EXPECTEDFAIL lines keep only "<STATUS> path:lineno";
    ERROR/FAILED lines drop any trailing " - message"; anything else is
    returned stripped.
    """
    line = line.strip()

    # Status lines carrying an optional "[n]" count and a path:lineno location.
    located = re.match(r"^(SKIPPED|XFAIL|XPASS|EXPECTEDFAIL)\s+\[?\d*\]?\s*(\S+:\d+)", line)
    if located is not None:
        return " ".join(located.groups())

    # ERROR/FAILED lines may append an " - <message>" suffix; drop it.
    if line.startswith(("ERROR", "FAILED")):
        return re.split(r"\s+-\s+", line)[0].strip()

    return line
|
||||
|
||||
|
||||
def parse_summary_file(file_path):
    """Collect the normalized test lines from a pytest summary file.

    Lines between "===" delimiter lines form the summary section; each
    non-empty line inside it is normalized and added to the returned set.
    """
    tests = set()
    inside = False
    with open(file_path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            text = raw_line.strip()
            # "===" lines toggle whether we are inside the summary section.
            if text.startswith("==="):
                inside = not inside
            elif inside and text:
                tests.add(normalize_test_line(text))
    return tests
|
||||
|
||||
|
||||
def compare_job_sets(job_set1, job_set2):
    """Diff two mappings of job name -> summary-file path.

    For every job appearing in either mapping, the normalized test lines of the
    two summary files are compared; a job present in only one mapping is
    compared against an empty set. Returns a human-readable report listing
    disappeared and newly-appeared lines per job, or "No differences found."
    when the two runs agree.
    """
    report = []
    for job in sorted(set(job_set1) | set(job_set2)):
        path_prev = job_set1.get(job)
        path_curr = job_set2.get(job)

        prev_tests = parse_summary_file(path_prev) if path_prev else set()
        curr_tests = parse_summary_file(path_curr) if path_curr else set()

        gone = prev_tests - curr_tests
        new = curr_tests - prev_tests
        if not (gone or new):
            continue

        report.append(f"=== Diff for job: {job} ===")
        if gone:
            report.append("--- Absent in current run:")
            report.extend(f"  - {test}" for test in sorted(gone))
        if new:
            report.append("+++ Appeared in current run:")
            report.extend(f"  + {test}" for test in sorted(new))
        report.append("")  # blank line separating job sections

    return "\n".join(report) if report else "No differences found."
|
||||
|
||||
|
||||
# Example usage:
|
||||
# job_set_1 = {
|
||||
# "albert": "prev/multi-gpu_run_models_gpu_models/albert_test_reports/summary_short.txt",
|
||||
# "bloom": "prev/multi-gpu_run_models_gpu_models/bloom_test_reports/summary_short.txt",
|
||||
# }
|
||||
|
||||
# job_set_2 = {
|
||||
# "albert": "curr/multi-gpu_run_models_gpu_models/albert_test_reports/summary_short.txt",
|
||||
# "bloom": "curr/multi-gpu_run_models_gpu_models/bloom_test_reports/summary_short.txt",
|
||||
# }
|
||||
|
||||
# report = compare_job_sets(job_set_1, job_set_2)
|
||||
# print(report)
|
||||
113
transformers/utils/create_dependency_mapping.py
Normal file
113
transformers/utils/create_dependency_mapping.py
Normal file
@@ -0,0 +1,113 @@
|
||||
import ast
|
||||
import re
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
# Function to perform topological sorting
|
||||
def topological_sort(dependencies: dict) -> list[list[str]]:
    """Given the dependencies graph, construct a sorted list of list of modular files.

    Each inner list is one dependency "level": files at level 0 depend on no
    other modular model, files at level 1 only depend on level-0 models, etc.

    Examples:

    The returned list of lists might be:
    [
        ["../modular_mistral.py", "../modular_gemma.py"],  # level 0
        ["../modular_llama4.py", "../modular_gemma2.py"],  # level 1
        ["../modular_glm4.py"],  # level 2
    ]
    """
    # Node names come from the file paths: "path/modular_llama.py" -> "llama".
    # Only models being converted become graph nodes.
    model_names = {path.rsplit("modular_", 1)[1].replace(".py", "") for path in dependencies}

    # Graph from each model to the models it depends on (restricted to the ones
    # being converted), plus a mapping back to the original file path.
    graph = {}
    path_for_name = {}
    for path, imported_modules in dependencies.items():
        name = path.rsplit("modular_", 1)[1].replace(".py", "")
        # "transformers.models.llama.modeling_llama" -> "llama"
        imported_names = {module.split(".")[-2] for module in imported_modules}
        graph[name] = {dep for dep in imported_names if dep in model_names and dep != name}
        path_for_name[name] = path

    # Repeatedly peel off the nodes with no remaining dependencies (out-degree 0).
    levels = []
    while graph:
        ready = {name for name, deps in graph.items() if not deps}
        levels.append([path_for_name[name] for name in ready])
        # Drop the ready nodes from the graph and from every remaining dependency set.
        graph = {name: deps - ready for name, deps in graph.items() if name not in ready}

    return levels
|
||||
|
||||
|
||||
# All the model file types that may be imported in modular files.
# Used to build the regex alternation in `is_model_import`; a module name like
# "transformers.models.llama.modeling_llama" matches via the "modeling" entry.
ALL_FILE_TYPES = (
    "modeling",
    "configuration",
    "tokenization",
    "processing",
    "image_processing",
    "video_processing",
    "feature_extraction",
)
|
||||
|
||||
|
||||
def is_model_import(module: str) -> bool:
    """Check whether `module` is a model import or not.

    A model import looks like "<model>.<file_type>_<suffix>", e.g.
    "transformers.models.llama.modeling_llama". The model folder name must
    reappear in the file suffix, and "auto" modules do not count.
    """
    file_type_alternation = "|".join(ALL_FILE_TYPES)
    found = re.search(rf"(\w+)\.(?:{file_type_alternation})_(\w+)", module)
    if found is None:
        return False
    model_name, file_suffix = found.group(1), found.group(2)
    return model_name in file_suffix and model_name != "auto"
|
||||
|
||||
|
||||
def extract_model_imports_from_file(file_path):
    """From a python file `file_path`, extract the model-specific imports (the imports related to any model file in
    Transformers)."""
    with open(file_path, "r", encoding="utf-8") as source_file:
        module_tree = ast.parse(source_file.read(), filename=file_path)

    # Model imports in modular files are always "from X import Y" statements,
    # so only ast.ImportFrom nodes are relevant.
    return {
        node.module
        for node in ast.walk(module_tree)
        if isinstance(node, ast.ImportFrom) and is_model_import(node.module)
    }
|
||||
|
||||
|
||||
def find_priority_list(modular_files: list[str]) -> tuple[list[list[str]], dict[str, set]]:
    """
    Sort `modular_files` by the topological order of their inter-model dependencies. Modular models that DON'T
    depend on other modular models come first.

    Args:
        modular_files (`list[str]`):
            List of paths to the modular files.

    Returns:
        A tuple `(ordered_files, dependencies)`.

        `ordered_files` is a list of lists, one per dependency level: e.g.
        [
            ["../modular_mistral.py", "../modular_gemma.py"],  # level 0
            ["../modular_llama4.py", "../modular_gemma2.py"],  # level 1
            ["../modular_glm4.py"],  # level 2
        ]
        where level-0 files depend on no other modular model, level-1 files only on level-0 ones, and so on.

        `dependencies` maps each modular file to the set of model modules it imports in order to use inheritance.
    """
    file_dependencies = defaultdict(set)
    for modular_file in modular_files:
        file_dependencies[modular_file] |= extract_model_imports_from_file(modular_file)
    return topological_sort(file_dependencies), file_dependencies
|
||||
1490
transformers/utils/create_dummy_models.py
Normal file
1490
transformers/utils/create_dummy_models.py
Normal file
File diff suppressed because it is too large
Load Diff
330
transformers/utils/custom_init_isort.py
Normal file
330
transformers/utils/custom_init_isort.py
Normal file
@@ -0,0 +1,330 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Utility that sorts the imports in the custom inits of Transformers. Transformers uses init files that delay the
|
||||
import of an object to when it's actually needed. This is to avoid the main init importing all models, which would
|
||||
make the line `import transformers` very slow when the user has all optional dependencies installed. The inits with
|
||||
delayed imports have two halves: one defining a dictionary `_import_structure` which maps modules to the name of the
|
||||
objects in each module, and one in `TYPE_CHECKING` which looks like a normal init for type-checkers. `isort` or `ruff`
|
||||
properly sort the second half which looks like traditionl imports, the goal of this script is to sort the first half.
|
||||
|
||||
Use from the root of the repo with:
|
||||
|
||||
```bash
|
||||
python utils/custom_init_isort.py
|
||||
```
|
||||
|
||||
which will auto-sort the imports (used in `make style`).
|
||||
|
||||
For a check only (as used in `make quality`) run:
|
||||
|
||||
```bash
|
||||
python utils/custom_init_isort.py --check_only
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
|
||||
# Path is defined with the intent you should run this script from the root of the repo.
|
||||
PATH_TO_TRANSFORMERS = "src/transformers"
|
||||
|
||||
# Pattern that looks at the indentation in a line.
|
||||
_re_indent = re.compile(r"^(\s*)\S")
|
||||
# Pattern that matches `"key":" and puts `key` in group 0.
|
||||
_re_direct_key = re.compile(r'^\s*"([^"]+)":')
|
||||
# Pattern that matches `_import_structure["key"]` and puts `key` in group 0.
|
||||
_re_indirect_key = re.compile(r'^\s*_import_structure\["([^"]+)"\]')
|
||||
# Pattern that matches `"key",` and puts `key` in group 0.
|
||||
_re_strip_line = re.compile(r'^\s*"([^"]+)",\s*$')
|
||||
# Pattern that matches any `[stuff]` and puts `stuff` in group 0.
|
||||
_re_bracket_content = re.compile(r"\[([^\]]+)\]")
|
||||
|
||||
|
||||
def get_indent(line: str) -> str:
    """Returns the indent in given line (as string).

    Blank or whitespace-only lines have no measurable indent and yield "".
    """
    match = _re_indent.search(line)
    if match is None:
        return ""
    return match.groups()[0]
|
||||
|
||||
|
||||
def split_code_in_indented_blocks(
    code: str, indent_level: str = "", start_prompt: Optional[str] = None, end_prompt: Optional[str] = None
) -> list[str]:
    """
    Split some code into its indented blocks, starting at a given level.

    Args:
        code (`str`): The code to split.
        indent_level (`str`): The indent level (as string) to use for identifying the blocks to split.
        start_prompt (`str`, *optional*): If provided, only starts splitting at the line where this text is.
        end_prompt (`str`, *optional*): If provided, stops splitting at a line where this text is.

    Warning:
        The text before `start_prompt` or after `end_prompt` (if provided) is not ignored, just not split. The input `code`
        can thus be retrieved by joining the result.

    Returns:
        `List[str]`: The list of blocks.
    """
    # Let's split the code into lines and move to start_index.
    index = 0
    lines = code.split("\n")
    if start_prompt is not None:
        # NOTE(review): assumes `start_prompt` occurs in `code` — raises IndexError otherwise; confirm callers guarantee this.
        while not lines[index].startswith(start_prompt):
            index += 1
        # Everything before the start prompt becomes the first (unsplit) block.
        blocks = ["\n".join(lines[:index])]
    else:
        blocks = []

    # This variable contains the block treated at a given time.
    current_block = [lines[index]]
    index += 1
    # We split into blocks until we get to the `end_prompt` (or the end of the file).
    while index < len(lines) and (end_prompt is None or not lines[index].startswith(end_prompt)):
        # We have a non-empty line with the proper indent -> start of a new block
        if len(lines[index]) > 0 and get_indent(lines[index]) == indent_level:
            # Store the current block in the result and rest. There are two cases: the line is part of the block (like
            # a closing parenthesis) or not.
            if len(current_block) > 0 and get_indent(current_block[-1]).startswith(indent_level + " "):
                # Line is part of the current block (e.g. a closing bracket after deeper-indented content).
                current_block.append(lines[index])
                blocks.append("\n".join(current_block))
                if index < len(lines) - 1:
                    current_block = [lines[index + 1]]
                    index += 1
                else:
                    current_block = []
            else:
                # Line is not part of the current block: close it and start a new one on this line.
                blocks.append("\n".join(current_block))
                current_block = [lines[index]]
        else:
            # Just add the line to the current block
            current_block.append(lines[index])
        index += 1

    # Adds current block if it's nonempty.
    if len(current_block) > 0:
        blocks.append("\n".join(current_block))

    # Add final block after end_prompt if provided (kept unsplit, like the pre-start block).
    if end_prompt is not None and index < len(lines):
        blocks.append("\n".join(lines[index:]))

    return blocks
|
||||
|
||||
|
||||
def ignore_underscore_and_lowercase(key: Callable[[Any], str]) -> Callable[[Any], str]:
    """
    Wraps a key function (as used in a sort) to lowercase and ignore underscores.
    """

    def _normalized(obj):
        # Case- and underscore-insensitive comparison key.
        return key(obj).lower().replace("_", "")

    return _normalized
|
||||
|
||||
|
||||
def sort_objects(objects: list[Any], key: Optional[Callable[[Any], str]] = None) -> list[Any]:
    """
    Sort a list of objects following the rules of isort: all-uppercase names (constants) go first, capitalized names
    (classes) second, and lowercase names (functions) last — each group sorted alphabetically, ignoring case and
    underscores.

    Args:
        objects (`List[Any]`):
            The list of objects to sort.
        key (`Callable[[Any], str]`, *optional*):
            A function taking an object as input and returning a string, used to sort them by alphabetical order.
            If not provided, will default to noop (so a `key` must be provided if the `objects` are not of type string).

    Returns:
        `List[Any]`: The sorted list with the same elements as in the inputs
    """
    if key is None:
        # Objects are assumed to be strings already.
        def key(x):
            return x

    sort_key = ignore_underscore_and_lowercase(key)
    # Constants are all uppercase; classes start with a capital but are not all uppercase; functions start lowercase.
    constants = sorted((obj for obj in objects if key(obj).isupper()), key=sort_key)
    classes = sorted((obj for obj in objects if key(obj)[0].isupper() and not key(obj).isupper()), key=sort_key)
    functions = sorted((obj for obj in objects if not key(obj)[0].isupper()), key=sort_key)
    return constants + classes + functions
|
||||
|
||||
|
||||
def sort_objects_in_import(import_statement: str) -> str:
    """
    Sorts the imports in a single import statement.

    The statement is one entry of the `_import_structure` dict; three layouts are handled, distinguished by the number
    of lines: multi-line (one object per line), three-line (all objects on the middle line), and single-line.

    Args:
        import_statement (`str`): The import statement in which to sort the imports.

    Returns:
        `str`: The same as the input, but with objects properly sorted.
    """

    # This inner function sort imports between [ ].
    def _replace(match):
        imports = match.groups()[0]
        # If there is one import only, nothing to do.
        if "," not in imports:
            return f"[{imports}]"
        keys = [part.strip().replace('"', "") for part in imports.split(",")]
        # We will have a final empty element if the line finished with a comma.
        if len(keys[-1]) == 0:
            keys = keys[:-1]
        return "[" + ", ".join([f'"{k}"' for k in sort_objects(keys)]) + "]"

    lines = import_statement.split("\n")
    if len(lines) > 3:
        # Here we have to sort internal imports that are on several lines (one per name):
        # key: [
        #     "object1",
        #     "object2",
        #     ...
        # ]

        # We may have to ignore one or two lines on each side (the "[" may be on its own line or fused with the key).
        idx = 2 if lines[1].strip() == "[" else 1
        # Pair each inner line with the quoted name it carries, so lines can be reordered by name.
        keys_to_sort = [(i, _re_strip_line.search(line).groups()[0]) for i, line in enumerate(lines[idx:-idx])]
        # NOTE: despite the name, this holds the sorted (index, name) pairs, not bare indices.
        sorted_indices = sort_objects(keys_to_sort, key=lambda x: x[1])
        sorted_lines = [lines[x[0] + idx] for x in sorted_indices]
        return "\n".join(lines[:idx] + sorted_lines + lines[-idx:])
    elif len(lines) == 3:
        # Here we have to sort internal imports that are on one separate line:
        # key: [
        #     "object1", "object2", ...
        # ]
        if _re_bracket_content.search(lines[1]) is not None:
            lines[1] = _re_bracket_content.sub(_replace, lines[1])
        else:
            # No brackets on the middle line: it is a bare comma-separated list of quoted names.
            keys = [part.strip().replace('"', "") for part in lines[1].split(",")]
            # We will have a final empty element if the line finished with a comma.
            if len(keys[-1]) == 0:
                keys = keys[:-1]
            lines[1] = get_indent(lines[1]) + ", ".join([f'"{k}"' for k in sort_objects(keys)])
        return "\n".join(lines)
    else:
        # Finally we have to deal with imports fitting on one line
        import_statement = _re_bracket_content.sub(_replace, import_statement)
        return import_statement
|
||||
|
||||
|
||||
def sort_imports(file: str, check_only: bool = True):
    """
    Sort the imports defined in the `_import_structure` of a given init.

    Args:
        file (`str`): The path to the init to check/fix.
        check_only (`bool`, *optional*, defaults to `True`): Whether or not to just check (and not auto-fix) the init.

    Returns:
        `True` when `check_only` is set and the file would be rewritten; `None` otherwise (including when the file is
        not a custom init and nothing is done).
    """
    with open(file, encoding="utf-8") as f:
        code = f.read()

    # If the file is not a custom init, there is nothing to do.
    if "_import_structure" not in code or "define_import_structure" in code:
        return

    # Blocks of indent level 0
    main_blocks = split_code_in_indented_blocks(
        code, start_prompt="_import_structure = {", end_prompt="if TYPE_CHECKING:"
    )

    # We ignore block 0 (everything until start_prompt) and the last block (everything after end_prompt).
    for block_idx in range(1, len(main_blocks) - 1):
        # Check if the block contains some `_import_structure`s thingy to sort.
        block = main_blocks[block_idx]
        block_lines = block.split("\n")

        # Get to the start of the imports.
        line_idx = 0
        while line_idx < len(block_lines) and "_import_structure" not in block_lines[line_idx]:
            # Skip dummy import blocks (setting line_idx past the end makes the block be skipped below).
            if "import dummy" in block_lines[line_idx]:
                line_idx = len(block_lines)
            else:
                line_idx += 1
        if line_idx >= len(block_lines):
            continue

        # Ignore beginning and last line: they don't contain anything.
        internal_block_code = "\n".join(block_lines[line_idx:-1])
        indent = get_indent(block_lines[1])
        # Split the internal block into blocks of indent level 1.
        internal_blocks = split_code_in_indented_blocks(internal_block_code, indent_level=indent)
        # We have two categories of import key: list or _import_structure[key].append/extend
        pattern = _re_direct_key if "_import_structure = {" in block_lines[0] else _re_indirect_key
        # Grab the keys, but there is a trap: some lines are empty or just comments.
        keys = [(pattern.search(b).groups()[0] if pattern.search(b) is not None else None) for b in internal_blocks]
        # We only sort the lines with a key.
        keys_to_sort = [(i, key) for i, key in enumerate(keys) if key is not None]
        sorted_indices = [x[0] for x in sorted(keys_to_sort, key=lambda x: x[1])]

        # We reorder the blocks by leaving empty lines/comments as they were and reorder the rest.
        count = 0
        reorderded_blocks = []
        for i in range(len(internal_blocks)):
            if keys[i] is None:
                # Keyless blocks (comments, blank lines) keep their original position.
                reorderded_blocks.append(internal_blocks[i])
            else:
                block = sort_objects_in_import(internal_blocks[sorted_indices[count]])
                reorderded_blocks.append(block)
                count += 1

        # And we put our main block back together with its first and last line.
        main_blocks[block_idx] = "\n".join(block_lines[:line_idx] + reorderded_blocks + [block_lines[-1]])

    if code != "\n".join(main_blocks):
        if check_only:
            return True
        else:
            print(f"Overwriting {file}.")
            with open(file, "w", encoding="utf-8") as f:
                f.write("\n".join(main_blocks))
|
||||
|
||||
|
||||
def sort_imports_in_all_inits(check_only=True):
    """
    Sort the imports defined in the `_import_structure` of all inits in the repo.

    Args:
        check_only (`bool`, *optional*, defaults to `True`): Whether or not to just check (and not auto-fix) the init.

    Raises:
        ValueError: in `check_only` mode, when at least one init would be rewritten.
    """
    failures = []
    for root, _, files in os.walk(PATH_TO_TRANSFORMERS):
        if "__init__.py" in files:
            result = sort_imports(os.path.join(root, "__init__.py"), check_only=check_only)
            if result:
                # Bug fix: accumulate every failing init. The previous code did
                # `failures = [...]`, overwriting the list on each hit so the
                # reported count was always at most 1.
                failures.append(os.path.join(root, "__init__.py"))
    if len(failures) > 0:
        raise ValueError(f"Would overwrite {len(failures)} files, run `make style`.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI entry point: `python utils/custom_init_isort.py` fixes the inits in place
    # (used by `make style`); `--check_only` only reports (used by `make quality`).
    parser = argparse.ArgumentParser()
    parser.add_argument("--check_only", action="store_true", help="Whether to only check or fix style.")
    args = parser.parse_args()

    sort_imports_in_all_inits(check_only=args.check_only)
|
||||
378
transformers/utils/deprecate_models.py
Normal file
378
transformers/utils/deprecate_models.py
Normal file
@@ -0,0 +1,378 @@
|
||||
"""
|
||||
Script which deprecates a list of given models
|
||||
|
||||
Example usage:
|
||||
python utils/deprecate_models.py --models bert distilbert
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
from custom_init_isort import sort_imports_in_all_inits
|
||||
from git import Repo
|
||||
from packaging import version
|
||||
|
||||
from transformers import CONFIG_MAPPING, logging
|
||||
from transformers import __version__ as current_version
|
||||
|
||||
|
||||
REPO_PATH = Path(os.path.abspath(os.path.dirname(os.path.dirname(__file__))))
|
||||
repo = Repo(REPO_PATH)
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def get_last_stable_minor_release():
    """Return the most recent stable transformers release on PyPI whose minor
    version is one below the current (in-repo) version.

    Raises:
        requests.RequestException: if PyPI cannot be reached.
        IndexError: if no release matches the previous minor version.
    """
    # Get the last stable release of transformers from PyPI.
    url = "https://pypi.org/pypi/transformers/json"
    release_data = requests.get(url).json()

    # Find the last stable release of transformers (version below the current one).
    # The in-repo version is usually "X.Y.Z.dev0" (4 parts), but also tolerate a
    # plain "X.Y.Z": the old 4-way unpacking raised ValueError on 3-part versions.
    major_version, minor_version = current_version.split(".")[:2]
    last_major_minor = f"{major_version}.{int(minor_version) - 1}"
    last_stable_minor_releases = [
        release for release in release_data["releases"] if release.startswith(last_major_minor)
    ]
    # Pick the highest patch release by semantic version (lexicographic sort would be wrong past .9).
    return sorted(last_stable_minor_releases, key=version.parse)[-1]
|
||||
|
||||
|
||||
def build_tip_message(last_stable_release):
    """Build the maintenance-mode <Tip> banner inserted into a deprecated model's
    doc page, pointing users at `last_stable_release` (e.g. "4.40.2")."""
    return (
        "\n<Tip warning={true}>\n"
        "\n"
        "This model is in maintenance mode only, we don't accept any new PRs changing its code.\n"
        f"If you run into any issues running this model, please reinstall the last version that supported this model: v{last_stable_release}.\n"
        f"You can do so by running the following command: `pip install -U transformers=={last_stable_release}`.\n"
        "\n"
        "</Tip>"
    )
|
||||
|
||||
|
||||
def insert_tip_to_model_doc(model_doc_path, tip_message):
    """Insert `tip_message` into the doc page at `model_doc_path`, directly
    underneath its title.

    Bug fix: only the FIRST line starting with "# " triggers the insertion. The
    previous implementation inserted the tip after *every* such line, which also
    matched Python comments ("# load the model") inside example code blocks.
    """
    tip_message_lines = tip_message.split("\n")

    with open(model_doc_path, "r") as f:
        model_doc = f.read()

    new_model_lines = []
    tip_inserted = False
    for line in model_doc.split("\n"):
        new_model_lines.append(line)
        if not tip_inserted and line.startswith("# "):
            # Add the tip message directly underneath the (first) title.
            new_model_lines.extend(tip_message_lines)
            tip_inserted = True

    with open(model_doc_path, "w") as f:
        f.write("\n".join(new_model_lines))
|
||||
|
||||
|
||||
def get_model_doc_path(model: str) -> tuple[Optional[str], Optional[str]]:
    """Locate the doc page for `model` under docs/source/en/model_doc.

    The doc file may use the model name verbatim, with dashes instead of
    underscores, or with underscores removed. Returns `(doc_path, doc_name)`
    for the first variant whose file exists, or `(None, None)` if none does.
    """
    # Possible variants of the model name in the model doc path.
    candidate_names = [model, model.replace("_", "-"), model.replace("_", "")]

    for candidate in candidate_names:
        candidate_path = REPO_PATH / f"docs/source/en/model_doc/{candidate}.md"
        if os.path.exists(candidate_path):
            return candidate_path, candidate

    return None, None
|
||||
|
||||
|
||||
def extract_model_info(model):
    """Gather the doc-page and source-directory information for `model`.

    Returns a dict with keys "model_doc_path", "model_doc_name" and
    "model_path", or `None` (after printing a message) when either the doc
    page or the model source directory cannot be found.
    """
    doc_path, doc_name = get_model_doc_path(model)
    if doc_path is None:
        print(f"Model doc path does not exist for {model}")
        return None

    source_path = REPO_PATH / f"src/transformers/models/{model}"
    if not os.path.exists(source_path):
        print(f"Model path does not exist for {model}")
        return None

    return {
        "model_doc_path": doc_path,
        "model_doc_name": doc_name,
        "model_path": source_path,
    }
|
||||
|
||||
|
||||
def update_relative_imports(filename, model):
    """Rewrite one-level-up relative imports ("from ..") into two-levels-up
    ("from ...") in `filename`, since deprecated models live one directory
    deeper than regular models. `model` is unused but kept for interface
    compatibility with the call sites.
    """
    with open(filename, "r") as f:
        contents = f.read()

    rewritten = [
        line.replace("from ..", "from ...") if line.startswith("from ..") else line
        for line in contents.split("\n")
    ]

    with open(filename, "w") as f:
        f.write("\n".join(rewritten))
|
||||
|
||||
|
||||
def remove_copied_from_statements(model):
    """Strip every line containing a "# Copied from" marker from the source
    files of `model`. Deprecated models no longer track their upstream
    implementations, so the copy-consistency markers must be dropped.
    """
    model_path = REPO_PATH / f"src/transformers/models/{model}"
    for file in os.listdir(model_path):
        if file == "__pycache__":
            continue
        file_path = model_path / file
        with open(file_path, "r") as f:
            contents = f.read()

        kept_lines = [line for line in contents.split("\n") if "# Copied from" not in line]

        with open(file_path, "w") as f:
            f.write("\n".join(kept_lines))
|
||||
|
||||
|
||||
def move_model_files_to_deprecated(model):
    """Git-move the source files of `model` from src/transformers/models/<model>
    to src/transformers/models/deprecated/<model>, fixing up the relative
    imports of each moved file (which now sit one directory deeper)."""
    model_path = REPO_PATH / f"src/transformers/models/{model}"
    deprecated_model_path = REPO_PATH / f"src/transformers/models/deprecated/{model}"

    os.makedirs(deprecated_model_path, exist_ok=True)

    for file in os.listdir(model_path):
        if file == "__pycache__":
            continue
        repo.git.mv(f"{model_path}/{file}", f"{deprecated_model_path}/{file}")

        # For deprecated files, we then need to update the relative imports.
        update_relative_imports(f"{deprecated_model_path}/{file}", model)
|
||||
|
||||
|
||||
def delete_model_tests(model):
    """Git-remove the test directory of `model`, if it exists."""
    tests_path = REPO_PATH / f"tests/models/{model}"
    if not os.path.exists(tests_path):
        return
    repo.git.rm("-r", tests_path)
|
||||
|
||||
|
||||
def get_line_indent(s):
    """Return the number of leading whitespace characters in `s`."""
    indent = 0
    # Count characters until the first non-whitespace one (equivalent to
    # len(s) - len(s.lstrip()), since lstrip removes str.isspace() characters).
    for ch in s:
        if not ch.isspace():
            break
        indent += 1
    return indent
|
||||
|
||||
|
||||
def update_main_init_file(models):
    """
    Replace all instances of model.model_name with model.deprecated.model_name in the __init__.py file

    Args:
        models (List[str]): The models to mark as deprecated
    """
    filename = REPO_PATH / "src/transformers/__init__.py"
    with open(filename, "r") as f:
        contents = f.read()

    # 1. Redirect both the quoted module names and the import statements to the deprecated subpackage.
    for model in models:
        contents = contents.replace(f'models.{model}"', f'models.deprecated.{model}"')
        contents = contents.replace(f"models.{model} import", f"models.deprecated.{model} import")

    with open(filename, "w") as f:
        f.write(contents)

    # 2. Resort the imports
    sort_imports_in_all_inits(check_only=False)
|
||||
|
||||
|
||||
def remove_model_references_from_file(filename, models, condition):
    """
    Remove all references to the given models from the given file

    Args:
        filename (str): The file to remove the references from (relative to the repo root)
        models (List[str]): The models to remove
        condition (Callable): A function that takes the line and model and returns True if the line should be removed
    """
    filename = REPO_PATH / filename
    with open(filename, "r") as f:
        contents = f.read()

    # Keep only the lines no model's condition matches. (The previous version
    # used `enumerate` but never used the index.)
    kept_lines = [
        line for line in contents.split("\n") if not any(condition(line, model) for model in models)
    ]

    with open(filename, "w") as f:
        f.write("\n".join(kept_lines))
|
||||
|
||||
|
||||
def remove_model_config_classes_from_config_check(model_config_classes):
    """
    Remove the deprecated model config classes from the check_config_attributes.py file

    Args:
        model_config_classes (List[str]): The model config classes to remove e.g. ["BertConfig", "DistilBertConfig"]
    """
    filename = REPO_PATH / "utils/check_config_attributes.py"
    with open(filename, "r") as f:
        check_config_attributes = f.read()

    # Keep track as we have to delete comment above too
    # State machine flags: `in_special_cases_to_allow` marks being inside the
    # SPECIAL_CASES_TO_ALLOW dict (or its .update(...) call); `in_indent` marks
    # being inside a multi-line list value currently being deleted.
    in_special_cases_to_allow = False
    in_indent = False
    new_file_lines = []

    for line in check_config_attributes.split("\n"):
        indent = get_line_indent(line)
        if (line.strip() == "SPECIAL_CASES_TO_ALLOW = {") or (line.strip() == "SPECIAL_CASES_TO_ALLOW.update("):
            in_special_cases_to_allow = True

        elif in_special_cases_to_allow and indent == 0 and line.strip() in ("}", ")"):
            # A top-level closing brace/paren ends the SPECIAL_CASES_TO_ALLOW region.
            in_special_cases_to_allow = False

        if in_indent:
            # We are inside a multi-line list belonging to a removed entry: drop
            # lines until the closing bracket (inclusive).
            if line.strip().endswith(("]", "],")):
                in_indent = False
            continue

        if in_special_cases_to_allow and any(
            model_config_class in line for model_config_class in model_config_classes
        ):
            # Remove comments above the model config class to remove
            while new_file_lines[-1].strip().startswith("#"):
                new_file_lines.pop()

            # The entry's value spans several lines; keep deleting until "]".
            if line.strip().endswith("["):
                in_indent = True

            continue

        elif any(model_config_class in line for model_config_class in model_config_classes):
            # Outside SPECIAL_CASES_TO_ALLOW, drop any other line mentioning a removed config class.
            continue

        new_file_lines.append(line)

    with open(filename, "w") as f:
        f.write("\n".join(new_file_lines))
|
||||
|
||||
|
||||
def add_models_to_deprecated_models_in_config_auto(models):
    """
    Add the models to the DEPRECATED_MODELS list in configuration_auto.py and sorts the list
    to be in alphabetical order.
    """
    filepath = REPO_PATH / "src/transformers/models/auto/configuration_auto.py"
    with open(filepath, "r") as f:
        config_auto = f.read()

    new_file_lines = []
    # Existing entry lines collected while inside the DEPRECATED_MODELS list;
    # the new models are appended to it before re-sorting and re-emitting.
    deprecated_models_list = []
    in_deprecated_models = False
    for line in config_auto.split("\n"):
        if line.strip() == "DEPRECATED_MODELS = [":
            in_deprecated_models = True
            new_file_lines.append(line)
        elif in_deprecated_models and line.strip() == "]":
            in_deprecated_models = False
            # Add the new models to deprecated models list
            # NOTE(review): the leading whitespace inside this literal must match the
            # existing entries' indentation in configuration_auto.py for the sorted
            # output to be well-formed — confirm it wasn't altered in transit.
            deprecated_models_list.extend([f' "{model}", ' for model in models])
            # Sort so they're in alphabetical order in the file
            deprecated_models_list = sorted(deprecated_models_list)
            new_file_lines.extend(deprecated_models_list)
            # Make sure we still have the closing bracket
            new_file_lines.append(line)
        elif in_deprecated_models:
            deprecated_models_list.append(line)
        else:
            new_file_lines.append(line)

    with open(filepath, "w") as f:
        f.write("\n".join(new_file_lines))
|
||||
|
||||
|
||||
def deprecate_models(models):
    """Deprecate a list of models in the repository.

    For each resolvable model this: tags its doc page with a deprecation tip, strips
    `# Copied from` statements, moves its source under `models/deprecated`, deletes its
    tests, then (once for all models) updates the init files, reference lists, and the
    DEPRECATED_MODELS registry in configuration_auto.py.

    Models whose doc page, path, or config class cannot be found are skipped.
    """
    # Get model info
    skipped_models = []
    models_info = defaultdict(dict)
    for model in models:
        single_model_info = extract_model_info(model)
        if single_model_info is None:
            skipped_models.append(model)
        else:
            models_info[model] = single_model_info

    # Resolve each model's config class, trying the model name first and then the
    # doc-page name as the key into CONFIG_MAPPING.
    model_config_classes = []
    for model, model_info in models_info.items():
        if model in CONFIG_MAPPING:
            model_config_classes.append(CONFIG_MAPPING[model].__name__)
        elif model_info["model_doc_name"] in CONFIG_MAPPING:
            model_config_classes.append(CONFIG_MAPPING[model_info["model_doc_name"]].__name__)
        else:
            skipped_models.append(model)
            print(f"Model config class not found for model: {model}")

    # Filter out skipped models
    # NOTE(review): models skipped in the loop above remain in `models_info`, so the
    # per-model loop below still processes them — confirm this is intended.
    models = [model for model in models if model not in skipped_models]

    if skipped_models:
        print(f"Skipped models: {skipped_models} as the model doc or model path could not be found.")
    print(f"Models to deprecate: {models}")

    # Remove model config classes from config check
    print("Removing model config classes from config checks")
    remove_model_config_classes_from_config_check(model_config_classes)

    tip_message = build_tip_message(get_last_stable_minor_release())

    for model, model_info in models_info.items():
        print(f"Processing model: {model}")
        # Add the tip message to the model doc page directly underneath the title
        print("Adding tip message to model doc page")
        insert_tip_to_model_doc(model_info["model_doc_path"], tip_message)

        # Remove #Copied from statements from model's files
        print("Removing #Copied from statements from model's files")
        remove_copied_from_statements(model)

        # Move the model file to deprecated: src/transformers/models/model -> src/transformers/models/deprecated/model
        print("Moving model files to deprecated for model")
        move_model_files_to_deprecated(model)

        # Delete the model tests: tests/models/model
        print("Deleting model tests")
        delete_model_tests(model)

    # # We do the following with all models passed at once to avoid having to re-write the file multiple times
    print("Updating __init__.py file to point to the deprecated models")
    update_main_init_file(models)

    # Remove model references from other files
    print("Removing model references from other files")
    remove_model_references_from_file(
        "src/transformers/models/__init__.py", models, lambda line, model: model == line.strip().strip(",")
    )
    remove_model_references_from_file(
        "utils/slow_documentation_tests.txt", models, lambda line, model: "/" + model + "/" in line
    )
    remove_model_references_from_file("utils/not_doctested.txt", models, lambda line, model: "/" + model + "/" in line)

    # Add models to DEPRECATED_MODELS in the configuration_auto.py
    print("Adding models to DEPRECATED_MODELS in configuration_auto.py")
    add_models_to_deprecated_models_in_config_auto(models)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI entry point: `python deprecate_models.py --models model_a model_b ...`
    parser = argparse.ArgumentParser()
    # `required=True` so argparse reports a clear error instead of passing None
    # into deprecate_models (which would raise a TypeError when iterated).
    parser.add_argument("--models", nargs="+", required=True, help="List of models to deprecate")
    args = parser.parse_args()
    deprecate_models(args.models)
|
||||
160
transformers/utils/download_glue_data.py
Normal file
160
transformers/utils/download_glue_data.py
Normal file
@@ -0,0 +1,160 @@
|
||||
"""Script for downloading all GLUE data.
|
||||
Original source: https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e
|
||||
|
||||
Note: for legal reasons, we are unable to host MRPC.
|
||||
You can either use the version hosted by the SentEval team, which is already tokenized,
|
||||
or you can download the original data from (https://download.microsoft.com/download/D/4/6/D46FF87A-F6B9-4252-AA8B-3604ED519838/MSRParaphraseCorpus.msi) and extract the data from it manually.
|
||||
For Windows users, you can run the .msi file. For Mac and Linux users, consider an external library such as 'cabextract' (see below for an example).
|
||||
You should then rename and place specific files in a folder (see below for an example).
|
||||
|
||||
mkdir MRPC
|
||||
cabextract MSRParaphraseCorpus.msi -d MRPC
|
||||
cat MRPC/_2DEC3DBE877E4DB192D17C0256E90F1D | tr -d $'\r' > MRPC/msr_paraphrase_train.txt
|
||||
cat MRPC/_D7B391F9EAFF4B1B8BCE8F21B20B1B61 | tr -d $'\r' > MRPC/msr_paraphrase_test.txt
|
||||
rm MRPC/_*
|
||||
rm MSRParaphraseCorpus.msi
|
||||
|
||||
1/30/19: It looks like SentEval is no longer hosting their extracted and tokenized MRPC data, so you'll need to download the data from the original source for now.
|
||||
2/11/19: It looks like SentEval actually *is* hosting the extracted data. Hooray!
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import urllib.request
|
||||
import zipfile
|
||||
|
||||
|
||||
TASKS = ["CoLA", "SST", "MRPC", "QQP", "STS", "MNLI", "SNLI", "QNLI", "RTE", "WNLI", "diagnostic"]
|
||||
TASK2PATH = {
|
||||
"CoLA": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FCoLA.zip?alt=media&token=46d5e637-3411-4188-bc44-5809b5bfb5f4",
|
||||
"SST": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8",
|
||||
"MRPC": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc",
|
||||
"QQP": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQQP.zip?alt=media&token=700c6acf-160d-4d89-81d1-de4191d02cb5",
|
||||
"STS": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSTS-B.zip?alt=media&token=bddb94a7-8706-4e0d-a694-1109e12273b5",
|
||||
"MNLI": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce",
|
||||
"SNLI": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSNLI.zip?alt=media&token=4afcfbb2-ff0c-4b2d-a09a-dbf07926f4df",
|
||||
"QNLI": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQNLIv2.zip?alt=media&token=6fdcf570-0fc5-4631-8456-9505272d1601",
|
||||
"RTE": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FRTE.zip?alt=media&token=5efa7e85-a0bb-4f19-8ea2-9e1840f077fb",
|
||||
"WNLI": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FWNLI.zip?alt=media&token=068ad0a0-ded7-4bd7-99a5-5e00222e0faf",
|
||||
"diagnostic": "https://storage.googleapis.com/mtl-sentence-representations.appspot.com/tsvsWithoutLabels%2FAX.tsv?GoogleAccessId=firebase-adminsdk-0khhl@mtl-sentence-representations.iam.gserviceaccount.com&Expires=2498860800&Signature=DuQ2CSPt2Yfre0C%2BiISrVYrIFaZH1Lc7hBVZDD4ZyR7fZYOMNOUGpi8QxBmTNOrNPjR3z1cggo7WXFfrgECP6FBJSsURv8Ybrue8Ypt%2FTPxbuJ0Xc2FhDi%2BarnecCBFO77RSbfuz%2Bs95hRrYhTnByqu3U%2FYZPaj3tZt5QdfpH2IUROY8LiBXoXS46LE%2FgOQc%2FKN%2BA9SoscRDYsnxHfG0IjXGwHN%2Bf88q6hOmAxeNPx6moDulUF6XMUAaXCSFU%2BnRO2RDL9CapWxj%2BDl7syNyHhB7987hZ80B%2FwFkQ3MEs8auvt5XW1%2Bd4aCU7ytgM69r8JDCwibfhZxpaa4gd50QXQ%3D%3D",
|
||||
}
|
||||
|
||||
MRPC_TRAIN = "https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt"
|
||||
MRPC_TEST = "https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt"
|
||||
|
||||
|
||||
def download_and_extract(task, data_dir):
    """Fetch the zip archive for a single GLUE task and unpack it into `data_dir`.

    The archive is downloaded to the current directory, extracted, then deleted.
    """
    print(f"Downloading and extracting {task}...")
    archive_name = f"{task}.zip"
    urllib.request.urlretrieve(TASK2PATH[task], archive_name)
    with zipfile.ZipFile(archive_name) as archive:
        archive.extractall(data_dir)
    os.remove(archive_name)
    print("\tCompleted!")
|
||||
|
||||
|
||||
def format_mrpc(data_dir, path_to_data):
    """Build MRPC train/dev/test TSV files under `<data_dir>/MRPC`.

    Args:
        data_dir: root output directory; an `MRPC` subfolder is created in it.
        path_to_data: optional local directory already containing the raw
            `msr_paraphrase_{train,test}.txt` files; when empty the raw files are
            downloaded from the SentEval mirror.

    Raises:
        ValueError: if the raw train/test files are missing after resolution.
    """
    print("Processing MRPC...")
    mrpc_dir = os.path.join(data_dir, "MRPC")
    if not os.path.isdir(mrpc_dir):
        os.mkdir(mrpc_dir)
    if path_to_data:
        # Use caller-supplied raw files.
        mrpc_train_file = os.path.join(path_to_data, "msr_paraphrase_train.txt")
        mrpc_test_file = os.path.join(path_to_data, "msr_paraphrase_test.txt")
    else:
        print("Local MRPC data not specified, downloading data from %s" % MRPC_TRAIN)
        mrpc_train_file = os.path.join(mrpc_dir, "msr_paraphrase_train.txt")
        mrpc_test_file = os.path.join(mrpc_dir, "msr_paraphrase_test.txt")
        urllib.request.urlretrieve(MRPC_TRAIN, mrpc_train_file)
        urllib.request.urlretrieve(MRPC_TEST, mrpc_test_file)
    if not os.path.isfile(mrpc_train_file):
        raise ValueError(f"Train data not found at {mrpc_train_file}")
    if not os.path.isfile(mrpc_test_file):
        raise ValueError(f"Test data not found at {mrpc_test_file}")
    # dev_ids.tsv lists the (id1, id2) pairs that belong to the dev split.
    urllib.request.urlretrieve(TASK2PATH["MRPC"], os.path.join(mrpc_dir, "dev_ids.tsv"))

    dev_ids = []
    with open(os.path.join(mrpc_dir, "dev_ids.tsv"), encoding="utf8") as ids_fh:
        for row in ids_fh:
            dev_ids.append(row.strip().split("\t"))

    # Split the raw train file into train.tsv / dev.tsv based on dev_ids membership.
    with (
        open(mrpc_train_file, encoding="utf8") as data_fh,
        open(os.path.join(mrpc_dir, "train.tsv"), "w", encoding="utf8") as train_fh,
        open(os.path.join(mrpc_dir, "dev.tsv"), "w", encoding="utf8") as dev_fh,
    ):
        header = data_fh.readline()
        train_fh.write(header)
        dev_fh.write(header)
        for row in data_fh:
            label, id1, id2, s1, s2 = row.strip().split("\t")
            if [id1, id2] in dev_ids:
                dev_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))
            else:
                train_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))

    # Rewrite the raw test file with an index column and no labels (GLUE format).
    with (
        open(mrpc_test_file, encoding="utf8") as data_fh,
        open(os.path.join(mrpc_dir, "test.tsv"), "w", encoding="utf8") as test_fh,
    ):
        header = data_fh.readline()
        test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n")
        for idx, row in enumerate(data_fh):
            label, id1, id2, s1, s2 = row.strip().split("\t")
            test_fh.write("%d\t%s\t%s\t%s\t%s\n" % (idx, id1, id2, s1, s2))
    print("\tCompleted!")
|
||||
|
||||
|
||||
def download_diagnostic(data_dir):
    """Download the GLUE diagnostic TSV into `<data_dir>/diagnostic/diagnostic.tsv`."""
    print("Downloading and extracting diagnostic...")
    target_dir = os.path.join(data_dir, "diagnostic")
    if not os.path.isdir(target_dir):
        os.mkdir(target_dir)
    urllib.request.urlretrieve(TASK2PATH["diagnostic"], os.path.join(target_dir, "diagnostic.tsv"))
    print("\tCompleted!")
|
||||
|
||||
|
||||
def get_tasks(task_names, valid_tasks=None):
    """Parse a comma-separated task string into a list of task names.

    Args:
        task_names: comma-separated task names, e.g. "CoLA,SST"; the special
            value "all" (anywhere in the list) selects every valid task.
        valid_tasks: optional allow-list of task names; defaults to the
            module-level TASKS. Parameterized so callers/tests can supply
            their own task set (backward compatible: omitting it keeps the
            original behavior).

    Returns:
        List of validated task names (a fresh list, in request order).

    Raises:
        ValueError: if any requested task is not in `valid_tasks`.
    """
    if valid_tasks is None:
        valid_tasks = TASKS
    requested = task_names.split(",")
    if "all" in requested:
        # Return a copy so callers cannot mutate the shared task list.
        return list(valid_tasks)
    tasks = []
    for task_name in requested:
        if task_name not in valid_tasks:
            raise ValueError(f"Task {task_name} not found!")
        tasks.append(task_name)
    return tasks
|
||||
|
||||
|
||||
def main(arguments):
    """Download the requested GLUE tasks into `--data_dir`.

    Args:
        arguments: list of CLI arguments (typically `sys.argv[1:]`).

    MRPC and the diagnostic set need special handling; every other task is a
    plain zip download-and-extract.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", help="directory to save data to", type=str, default="glue_data")
    parser.add_argument(
        "--tasks", help="tasks to download data for as a comma separated string", type=str, default="all"
    )
    parser.add_argument(
        "--path_to_mrpc",
        # Fix: help text said "msr_paraphrase_text.txt"; the actual file is "_test.txt".
        help="path to directory containing extracted MRPC data, msr_paraphrase_train.txt and msr_paraphrase_test.txt",
        type=str,
        default="",
    )
    args = parser.parse_args(arguments)

    # makedirs handles nested paths (os.mkdir would fail on e.g. "out/glue_data")
    # and is a no-op when the directory already exists.
    os.makedirs(args.data_dir, exist_ok=True)
    tasks = get_tasks(args.tasks)

    for task in tasks:
        if task == "MRPC":
            format_mrpc(args.data_dir, args.path_to_mrpc)
        elif task == "diagnostic":
            download_diagnostic(args.data_dir)
        else:
            download_and_extract(task, args.data_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Forward CLI args (minus the program name) and propagate main's return
    # value as the process exit status.
    sys.exit(main(sys.argv[1:]))
|
||||
31
transformers/utils/extract_pr_number_from_circleci.py
Normal file
31
transformers/utils/extract_pr_number_from_circleci.py
Normal file
@@ -0,0 +1,31 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Used by `.github/workflows/trigger_circleci.yml` to get the pull request number in CircleCI job runs."""
|
||||
|
||||
import os
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Prefer the PR URL exposed by CircleCI; its last path segment is the PR number.
    pr_number = ""
    pull_request_url = os.environ.get("CIRCLE_PULL_REQUEST", "")
    if pull_request_url:
        pr_number = pull_request_url.rsplit("/", 1)[-1]
    # Fall back to branch names of the form "pull/<number>".
    if not pr_number:
        branch = os.environ.get("CIRCLE_BRANCH", "")
        if branch.startswith("pull/"):
            segments = branch.split("/")
            pr_number = segments[1] if len(segments) > 1 else ""

    # The workflow captures stdout, so the number (possibly empty) is printed.
    print(pr_number)
|
||||
134
transformers/utils/extract_warnings.py
Normal file
134
transformers/utils/extract_warnings.py
Normal file
@@ -0,0 +1,134 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import zipfile
|
||||
|
||||
from get_ci_error_statistics import download_artifact, get_artifacts_links
|
||||
|
||||
from transformers import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def extract_warnings_from_single_artifact(artifact_path, targets):
    """Extract warnings from a downloaded artifact (in .zip format).

    Args:
        artifact_path: path to a zip artifact, or (when `from_gh` is truthy) a
            directory containing a `warnings.txt` file.
        targets: warning class names (e.g. "DeprecationWarning") to keep.

    Returns:
        Set of multi-line warning strings whose text matches one of `targets`.

    NOTE(review): this reads the module-level global `from_gh`, which is only
    assigned in the __main__ block — calling this from another module without
    defining it raises NameError. Consider passing it as a parameter.
    """
    selected_warnings = set()
    # Accumulates the indented body lines of the warning currently being read.
    buffer = []

    def parse_line(fp):
        # pytest's warnings summary: a non-indented header line starts each
        # warning; the indented lines that follow are its body.
        for line in fp:
            if isinstance(line, bytes):
                line = line.decode("UTF-8")
            if "warnings summary (final)" in line:
                continue
            # This means we are outside the body of a warning
            elif not line.startswith(" "):
                # process a single warning and move it to `selected_warnings`.
                if len(buffer) > 0:
                    warning = "\n".join(buffer)
                    # Only keep the warnings specified in `targets`
                    if any(f": {x}: " in warning for x in targets):
                        selected_warnings.add(warning)
                    buffer.clear()
                continue
            else:
                line = line.strip()
                buffer.append(line)

    if from_gh:
        # Artifact was unpacked by actions/download-artifact: plain directory.
        for filename in os.listdir(artifact_path):
            file_path = os.path.join(artifact_path, filename)
            if not os.path.isdir(file_path):
                # read the file
                if filename != "warnings.txt":
                    continue
                with open(file_path) as fp:
                    parse_line(fp)
    else:
        # Artifact downloaded through the API: a zip containing warnings.txt.
        try:
            with zipfile.ZipFile(artifact_path) as z:
                for filename in z.namelist():
                    if not os.path.isdir(filename):
                        # read the file
                        if filename != "warnings.txt":
                            continue
                        with z.open(filename) as fp:
                            parse_line(fp)
        except Exception:
            logger.warning(
                f"{artifact_path} is either an invalid zip file or something else wrong. This file is skipped."
            )

    return selected_warnings
|
||||
|
||||
|
||||
def extract_warnings(artifact_dir, targets):
    """Collect the selected warnings from every artifact under `artifact_dir`."""
    collected = set()
    for entry in os.listdir(artifact_dir):
        # Zip artifacts from the API, or any entry when running from a GitHub
        # workflow (`from_gh` is a module-level flag set in the __main__ block).
        if entry.endswith(".zip") or from_gh:
            artifact_path = os.path.join(artifact_dir, entry)
            collected.update(extract_warnings_from_single_artifact(artifact_path, targets))
    return collected
|
||||
|
||||
|
||||
if __name__ == "__main__":

    def list_str(values):
        # argparse `type=` helper: "A,B" -> ["A", "B"].
        return values.split(",")

    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument("--workflow_run_id", type=str, required=True, help="A GitHub Actions workflow run id.")
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Where to store the downloaded artifacts and other result files.",
    )
    parser.add_argument("--token", default=None, type=str, help="A token that has actions:read permission.")
    # optional parameters
    parser.add_argument(
        "--targets",
        default="DeprecationWarning,UserWarning,FutureWarning",
        type=list_str,
        help="Comma-separated list of target warning(s) which we want to extract.",
    )
    parser.add_argument(
        "--from_gh",
        action="store_true",
        help="If running from a GitHub action workflow and collecting warnings from its artifacts.",
    )

    args = parser.parse_args()

    # Module-level global read by extract_warnings / extract_warnings_from_single_artifact.
    from_gh = args.from_gh
    if from_gh:
        # The artifacts have to be downloaded using `actions/download-artifact@v4`
        pass
    else:
        os.makedirs(args.output_dir, exist_ok=True)

        # get download links
        artifacts = get_artifacts_links(args.workflow_run_id, token=args.token)
        with open(os.path.join(args.output_dir, "artifacts.json"), "w", encoding="UTF-8") as fp:
            json.dump(artifacts, fp, ensure_ascii=False, indent=4)

        # download artifacts
        for idx, (name, url) in enumerate(artifacts.items()):
            print(name)
            print(url)
            print("=" * 80)
            download_artifact(name, url, args.output_dir, args.token)
            # Be gentle to GitHub
            time.sleep(1)

    # extract warnings from artifacts
    selected_warnings = extract_warnings(args.output_dir, args.targets)
    selected_warnings = sorted(selected_warnings)
    with open(os.path.join(args.output_dir, "selected_warnings.json"), "w", encoding="UTF-8") as fp:
        json.dump(selected_warnings, fp, ensure_ascii=False, indent=4)
|
||||
217
transformers/utils/fetch_hub_objects_for_ci.py
Normal file
217
transformers/utils/fetch_hub_objects_for_ci.py
Normal file
@@ -0,0 +1,217 @@
|
||||
import os
|
||||
|
||||
import requests
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
|
||||
from transformers.testing_utils import _run_pipeline_tests, _run_staging
|
||||
from transformers.utils.import_utils import is_mistral_common_available
|
||||
|
||||
|
||||
URLS_FOR_TESTING_DATA = [
|
||||
"http://images.cocodataset.org/val2017/000000000139.jpg",
|
||||
"http://images.cocodataset.org/val2017/000000000285.jpg",
|
||||
"http://images.cocodataset.org/val2017/000000000632.jpg",
|
||||
"http://images.cocodataset.org/val2017/000000000724.jpg",
|
||||
"http://images.cocodataset.org/val2017/000000000776.jpg",
|
||||
"http://images.cocodataset.org/val2017/000000000785.jpg",
|
||||
"http://images.cocodataset.org/val2017/000000000802.jpg",
|
||||
"http://images.cocodataset.org/val2017/000000000872.jpg",
|
||||
"http://images.cocodataset.org/val2017/000000039769.jpg",
|
||||
"https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg",
|
||||
"https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
|
||||
"https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/bcn_weather.mp3",
|
||||
"https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/bus.png",
|
||||
"https://huggingface.co/datasets/hf-internal-testing/fixtures_videos/resolve/main/tennis.mp4",
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/coco_sample.png",
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/australia.jpg",
|
||||
"https://huggingface.co/datasets/raushan-testing-hf/audio-test/resolve/main/f2641_0_throatclearing.wav",
|
||||
"https://huggingface.co/datasets/raushan-testing-hf/audio-test/resolve/main/glass-breaking-151256.mp3",
|
||||
"https://huggingface.co/datasets/raushan-testing-hf/images_test/resolve/main/picsum_237_200x300.jpg",
|
||||
"https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/Big_Buck_Bunny_720_10s_10MB.mp4",
|
||||
"https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/sample_demo_1.mp4",
|
||||
"https://huggingface.co/kirp/kosmos2_5/resolve/main/receipt_00008.png",
|
||||
"https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/two_dogs.jpg",
|
||||
"https://llava-vl.github.io/static/images/view.jpg",
|
||||
"https://huggingface.co/datasets/hf-internal-testing/fixtures_videos/resolve/main/tennis.mp4",
|
||||
"https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/tiny_video.mp4",
|
||||
"https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg",
|
||||
"https://huggingface.co/datasets/raushan-testing-hf/videos-test/resolve/main/tiny_video.mp4",
|
||||
]
|
||||
|
||||
|
||||
def url_to_local_path(url, return_url_if_not_found=True):
    """Map a test-data URL to the local basename it would be downloaded to.

    Returns the basename if that file exists in the current directory; otherwise
    returns the URL itself (or the basename when `return_url_if_not_found` is False).
    """
    local_name = url.rsplit("/", 1)[-1]
    if os.path.exists(local_name):
        return local_name
    return url if return_url_if_not_found else local_name
|
||||
|
||||
|
||||
if __name__ == "__main__":
    if _run_pipeline_tests:
        import datasets

        # Warm the datasets cache used by the pipeline tests.
        _ = datasets.load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        _ = datasets.load_dataset("hf-internal-testing/fixtures_image_utils", split="test", revision="refs/pr/1")
        _ = hf_hub_download(repo_id="nateraw/video-demo", filename="archery.mp4", repo_type="dataset")

        # Hub fixture files to prefetch, as (repo_id, filename, repo_type, revision).
        # Replaces ~40 copy-pasted hf_hub_download calls with one data-driven loop;
        # one exact duplicate (the second sample_audio.flac fetch, a cache no-op)
        # and a dead `repo_id = ...` assignment were dropped.
        hub_files = [
            ("Narsil/asr_dummy", "hindi.ogg", "dataset", None),
            ("hf-internal-testing/bool-masked-pos", "bool_masked_pos.pt", None, None),
            ("hf-internal-testing/fixtures_docvqa", "nougat_pdf.png", "dataset", "ec57bf8c8b1653a209c13f6e9ee66b12df0fc2db"),
            ("hf-internal-testing/image-matting-fixtures", "image.png", "dataset", None),
            ("hf-internal-testing/image-matting-fixtures", "trimap.png", "dataset", None),
            ("hf-internal-testing/spaghetti-video", "eating_spaghetti.npy", "dataset", None),
            ("hf-internal-testing/spaghetti-video", "eating_spaghetti_32_frames.npy", "dataset", None),
            ("hf-internal-testing/spaghetti-video", "eating_spaghetti_8_frames.npy", "dataset", None),
            ("hf-internal-testing/tourism-monthly-batch", "train-batch.pt", "dataset", None),
            ("huggyllama/llama-7b", "tokenizer.model", None, None),
            ("nielsr/audio-spectogram-transformer-checkpoint", "sample_audio.flac", "dataset", None),
            ("nielsr/example-pdf", "example_pdf.png", "dataset", None),
            ("nielsr/test-image", "llava_1_6_input_ids.pt", "dataset", None),
            ("nielsr/test-image", "llava_1_6_pixel_values.pt", "dataset", None),
            ("nielsr/textvqa-sample", "bus.png", "dataset", None),
            ("raushan-testing-hf/images_test", "emu3_image.npy", "dataset", None),
            ("raushan-testing-hf/images_test", "llava_v1_5_radar.jpg", "dataset", None),
            ("raushan-testing-hf/videos-test", "sample_demo_1.mp4", "dataset", None),
            ("raushan-testing-hf/videos-test", "video_demo.npy", "dataset", None),
            ("raushan-testing-hf/videos-test", "video_demo_2.npy", "dataset", None),
            ("shumingh/perception_lm_test_images", "14496_0.PNG", "dataset", None),
            ("shumingh/perception_lm_test_videos", "GUWR5TyiY-M_000012_000022.mp4", "dataset", None),
            ("nielsr/image-segmentation-toy-data", "instance_segmentation_image_1.png", "dataset", None),
            ("nielsr/image-segmentation-toy-data", "instance_segmentation_image_2.png", "dataset", None),
            ("nielsr/image-segmentation-toy-data", "instance_segmentation_annotation_1.png", "dataset", None),
            ("nielsr/image-segmentation-toy-data", "instance_segmentation_annotation_2.png", "dataset", None),
            ("nielsr/image-segmentation-toy-data", "semantic_segmentation_annotation_1.png", "dataset", None),
            ("nielsr/image-segmentation-toy-data", "semantic_segmentation_annotation_2.png", "dataset", None),
            ("nielsr/image-segmentation-toy-data", "semantic_segmentation_image_1.png", "dataset", None),
            ("nielsr/image-segmentation-toy-data", "semantic_segmentation_image_2.png", "dataset", None),
            ("shi-labs/oneformer_demo", "ade20k_panoptic.json", "dataset", None),
        ]
        for repo, fname, repo_type, revision in hub_files:
            hf_hub_download(repo_id=repo, filename=fname, repo_type=repo_type, revision=revision)

    # Need to specify the username on the endpoint `hub-ci`, otherwise we get
    # `fatal: could not read Username for 'https://hub-ci.huggingface.co': Success`
    # But this repo. is never used in a test decorated by `is_staging_test`.
    if not _run_staging:
        if not os.path.isdir("tiny-random-custom-architecture"):
            snapshot_download(
                "hf-internal-testing/tiny-random-custom-architecture",
                local_dir="tiny-random-custom-architecture",
            )

    # For `tests/test_tokenization_mistral_common.py:TestMistralCommonTokenizer`, which eventually calls
    # `mistral_common.tokens.tokenizers.utils.download_tokenizer_from_hf_hub` which (probably) doesn't have the cache.
    if is_mistral_common_available():
        from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

        from transformers import AutoTokenizer
        from transformers.tokenization_mistral_common import MistralCommonTokenizer

        repo_id = "hf-internal-testing/namespace-mistralai-repo_name-Mistral-Small-3.1-24B-Instruct-2503"
        AutoTokenizer.from_pretrained(repo_id, tokenizer_type="mistral")
        MistralCommonTokenizer.from_pretrained(repo_id)
        MistralTokenizer.from_hf_hub(repo_id)

        repo_id = "mistralai/Voxtral-Mini-3B-2507"
        AutoTokenizer.from_pretrained(repo_id)
        MistralTokenizer.from_hf_hub(repo_id)

    # Download files from URLs to local directory
    for url in URLS_FOR_TESTING_DATA:
        filename = url_to_local_path(url, return_url_if_not_found=False)

        # Skip if file already exists (also makes duplicate URLs in the list no-ops).
        if os.path.exists(filename):
            # Fix: these four messages previously printed a literal "(unknown)"
            # placeholder instead of interpolating the filename/URL.
            print(f"File already exists: {filename}")
            continue

        print(f"Downloading {url}...")
        try:
            response = requests.get(url, stream=True)
            response.raise_for_status()

            with open(filename, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
            print(f"Successfully downloaded: {filename}")
        except requests.exceptions.RequestException as e:
            print(f"Error downloading {url}: {e}")
|
||||
305
transformers/utils/get_ci_error_statistics.py
Normal file
305
transformers/utils/get_ci_error_statistics.py
Normal file
@@ -0,0 +1,305 @@
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
import zipfile
|
||||
from collections import Counter
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
def get_jobs(workflow_run_id, token=None):
    """Return every job of a GitHub Actions workflow run, following pagination.

    On any error an empty list is returned and the traceback is printed.
    """
    headers = None
    if token is not None:
        headers = {"Accept": "application/vnd.github+json", "Authorization": f"Bearer {token}"}

    base_url = f"https://api.github.com/repos/huggingface/transformers/actions/runs/{workflow_run_id}/jobs?per_page=100"
    first_page = requests.get(base_url, headers=headers).json()

    try:
        all_jobs = list(first_page["jobs"])
        # The first request returned up to 100 jobs; fetch the remaining pages.
        remaining_pages = math.ceil((first_page["total_count"] - 100) / 100)
        for page_index in range(remaining_pages):
            page = requests.get(base_url + f"&page={page_index + 2}", headers=headers).json()
            all_jobs.extend(page["jobs"])
        return all_jobs
    except Exception:
        print(f"Unknown error, could not fetch links:\n{traceback.format_exc()}")

    return []
|
||||
|
||||
|
||||
def get_job_links(workflow_run_id, token=None):
    """Extract job names and their job links in a GitHub Actions workflow run.

    Args:
        workflow_run_id: id of the workflow run to inspect.
        token: optional GitHub token with `actions:read` permission.

    Returns:
        Dict mapping job name -> job HTML URL; empty dict on any error
        (the traceback is printed).
    """
    headers = None
    if token is not None:
        headers = {"Accept": "application/vnd.github+json", "Authorization": f"Bearer {token}"}

    url = f"https://api.github.com/repos/huggingface/transformers/actions/runs/{workflow_run_id}/jobs?per_page=100"
    result = requests.get(url, headers=headers).json()
    job_links = {}

    try:
        job_links.update({job["name"]: job["html_url"] for job in result["jobs"]})
        # First request returned up to 100 jobs; fetch the remaining pages.
        pages_to_iterate_over = math.ceil((result["total_count"] - 100) / 100)

        for i in range(pages_to_iterate_over):
            result = requests.get(url + f"&page={i + 2}", headers=headers).json()
            job_links.update({job["name"]: job["html_url"] for job in result["jobs"]})

        return job_links
    except Exception:
        print(f"Unknown error, could not fetch links:\n{traceback.format_exc()}")

    return {}
|
||||
|
||||
|
||||
def get_artifacts_links(workflow_run_id, token=None):
    """Get all artifact links from a workflow run.

    Args:
        workflow_run_id: id of the workflow run to inspect.
        token: optional GitHub token with `actions:read` permission.

    Returns:
        Dict mapping artifact name -> archive download URL; empty dict on any
        error (the traceback is printed).
    """
    headers = None
    if token is not None:
        headers = {"Accept": "application/vnd.github+json", "Authorization": f"Bearer {token}"}

    url = (
        f"https://api.github.com/repos/huggingface/transformers/actions/runs/{workflow_run_id}/artifacts?per_page=100"
    )
    result = requests.get(url, headers=headers).json()
    artifacts = {}

    try:
        artifacts.update({artifact["name"]: artifact["archive_download_url"] for artifact in result["artifacts"]})
        # First request returned up to 100 artifacts; fetch the remaining pages.
        pages_to_iterate_over = math.ceil((result["total_count"] - 100) / 100)

        for i in range(pages_to_iterate_over):
            result = requests.get(url + f"&page={i + 2}", headers=headers).json()
            artifacts.update({artifact["name"]: artifact["archive_download_url"] for artifact in result["artifacts"]})

        return artifacts
    except Exception:
        print(f"Unknown error, could not fetch links:\n{traceback.format_exc()}")

    return {}
|
||||
|
||||
|
||||
def download_artifact(artifact_name, artifact_url, output_dir, token):
    """Download a GitHub Actions artifact and save it as `{artifact_name}.zip` in `output_dir`.

    The API endpoint `https://api.github.com/repos/huggingface/transformers/actions/artifacts/{ARTIFACT_ID}/zip`
    does not serve the file directly: it answers with a redirect whose `Location` header
    holds the real download URL.
    See https://docs.github.com/en/rest/actions/artifacts#download-an-artifact
    """
    headers = None
    if token is not None:
        headers = {"Accept": "application/vnd.github+json", "Authorization": f"Bearer {token}"}

    # First request: only capture the redirect target, don't follow it.
    redirect = requests.get(artifact_url, headers=headers, allow_redirects=False)
    real_url = redirect.headers["Location"]
    archive = requests.get(real_url, allow_redirects=True)

    with open(os.path.join(output_dir, f"{artifact_name}.zip"), "wb") as fp:
        fp.write(archive.content)
|
||||
|
||||
|
||||
def get_errors_from_single_artifact(artifact_zip_path, job_links=None):
    """Extract errors from a downloaded artifact (in .zip format).

    The zip is expected to contain `failures_line.txt` (one `location: error` line per
    failure), `summary_short.txt` (a pytest short summary with `FAILED <test>` lines)
    and optionally `job_name.txt` (a single job name).

    Args:
        artifact_zip_path: Path to the artifact `.zip` file.
        job_links: Optional mapping from job name to job URL, used to attach a link.

    Returns:
        list: Entries of the form `[error_line, error, failed_test, job_link]`.

    Raises:
        ValueError: If the number of parsed errors and failed tests differ, i.e. the
            report files are inconsistent.
    """
    errors = []
    failed_tests = []
    job_name = None

    with zipfile.ZipFile(artifact_zip_path) as z:
        for filename in z.namelist():
            if not os.path.isdir(filename):
                # read the file
                if filename in ["failures_line.txt", "summary_short.txt", "job_name.txt"]:
                    with z.open(filename) as f:
                        for line in f:
                            line = line.decode("UTF-8").strip()
                            if filename == "failures_line.txt":
                                try:
                                    # `error_line` is the place where `error` occurs
                                    error_line = line[: line.index(": ")]
                                    error = line[line.index(": ") + len(": ") :]
                                    errors.append([error_line, error])
                                except Exception:
                                    # skip un-related lines (`line.index` raises when no `": "` separator is present)
                                    pass
                            elif filename == "summary_short.txt" and line.startswith("FAILED "):
                                # `test` is the test method that failed
                                test = line[len("FAILED ") :]
                                failed_tests.append(test)
                            elif filename == "job_name.txt":
                                job_name = line

    # Errors and failed tests are paired positionally below — presumably the two report
    # files list failures in the same order (TODO confirm against the report writer).
    if len(errors) != len(failed_tests):
        raise ValueError(
            f"`errors` and `failed_tests` should have the same number of elements. Got {len(errors)} for `errors` "
            f"and {len(failed_tests)} for `failed_tests` instead. The test reports in {artifact_zip_path} have some"
            " problem."
        )

    job_link = None
    if job_name and job_links:
        job_link = job_links.get(job_name, None)

    # A list with elements of the form (line of error, error, failed test)
    result = [x + [y] + [job_link] for x, y in zip(errors, failed_tests)]

    return result
|
||||
|
||||
|
||||
def get_all_errors(artifact_dir, job_links=None):
    """Collect the errors from every `.zip` artifact found directly under `artifact_dir`.

    Returns a flat list of `[error_line, error, failed_test, job_link]` entries.
    """
    zip_paths = [os.path.join(artifact_dir, name) for name in os.listdir(artifact_dir) if name.endswith(".zip")]

    all_errors = []
    for zip_path in zip_paths:
        all_errors.extend(get_errors_from_single_artifact(zip_path, job_links=job_links))

    return all_errors
|
||||
|
||||
|
||||
def reduce_by_error(logs, error_filter=None):
    """Aggregate log entries by error message.

    Args:
        logs: Entries of the form `[error_line, error, failed_test, ...]`.
        error_filter: Optional collection of error messages to exclude.

    Returns:
        dict: error -> `{"count": int, "failed_tests": [(failed_test, error_line), ...]}`,
        sorted by descending count.
    """
    tallies = Counter(entry[1] for entry in logs).most_common()

    reduced = {}
    for error, count in tallies:
        if error_filter is not None and error in error_filter:
            continue
        reduced[error] = {
            "count": count,
            "failed_tests": [(entry[2], entry[0]) for entry in logs if entry[1] == error],
        }

    return dict(sorted(reduced.items(), key=lambda item: item[1]["count"], reverse=True))
|
||||
|
||||
|
||||
def get_model(test):
    """Return the model name (e.g. `bert`) from a test id like `tests/models/bert/...::...`.

    Returns None for tests that are not under `tests/models/`.
    """
    test_path = test.split("::")[0]
    if not test_path.startswith("tests/models/"):
        return None
    return test_path.split("/")[2]
|
||||
|
||||
|
||||
def reduce_by_model(logs, error_filter=None):
    """Aggregate errors per model.

    Entries of `logs` are `[error_line, error, failed_test]`; only entries whose failed
    test lives under `tests/models/` are kept (see `get_model`). The result maps each
    model name to its total error count and a per-error breakdown, sorted by descending
    total count.
    """
    # Replace the full test id with the model name and drop non-model tests.
    model_logs = [(entry[0], entry[1], get_model(entry[2])) for entry in logs]
    model_logs = [entry for entry in model_logs if entry[2] is not None]
    models = {entry[2] for entry in model_logs}

    reduced = {}
    for model in models:
        # count by errors for this model
        per_error = Counter(entry[1] for entry in model_logs if entry[2] == model).most_common()
        error_counts = {
            error: count for error, count in per_error if error_filter is None or error not in error_filter
        }
        total = sum(error_counts.values())
        if total > 0:
            reduced[model] = {"count": total, "errors": error_counts}

    return dict(sorted(reduced.items(), key=lambda item: item[1]["count"], reverse=True))
|
||||
|
||||
|
||||
def make_github_table(reduced_by_error):
    """Render the output of `reduce_by_error` as a GitHub-flavored Markdown table."""
    rows = ["| no. | error | status |", "|-:|:-|:-|"]
    for error, info in reduced_by_error.items():
        # Truncate very long error messages so the table stays readable.
        rows.append(f"| {info['count']} | {error[:100]} | |")
    return "\n".join(rows)
|
||||
|
||||
|
||||
def make_github_table_per_model(reduced_by_model):
    """Render the output of `reduce_by_model` as a GitHub-flavored Markdown table."""
    rows = ["| model | no. of errors | major error | count |", "|-:|-:|-:|-:|"]
    for model, info in reduced_by_model.items():
        # `errors` is ordered by count, so the first entry is the most frequent one.
        major_error, major_count = next(iter(info["errors"].items()))
        rows.append(f"| {model} | {info['count']} | {major_error[:60]} | {major_count} |")
    return "\n".join(rows)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument("--workflow_run_id", type=str, required=True, help="A GitHub Actions workflow run id.")
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Where to store the downloaded artifacts and other result files.",
|
||||
)
|
||||
parser.add_argument("--token", default=None, type=str, help="A token that has actions:read permission.")
|
||||
args = parser.parse_args()
|
||||
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
_job_links = get_job_links(args.workflow_run_id, token=args.token)
|
||||
job_links = {}
|
||||
# To deal with `workflow_call` event, where a job name is the combination of the job names in the caller and callee.
|
||||
# For example, `PyTorch 1.11 / Model tests (models/albert, single-gpu)`.
|
||||
if _job_links:
|
||||
for k, v in _job_links.items():
|
||||
# This is how GitHub actions combine job names.
|
||||
if " / " in k:
|
||||
index = k.find(" / ")
|
||||
k = k[index + len(" / ") :]
|
||||
job_links[k] = v
|
||||
with open(os.path.join(args.output_dir, "job_links.json"), "w", encoding="UTF-8") as fp:
|
||||
json.dump(job_links, fp, ensure_ascii=False, indent=4)
|
||||
|
||||
artifacts = get_artifacts_links(args.workflow_run_id, token=args.token)
|
||||
with open(os.path.join(args.output_dir, "artifacts.json"), "w", encoding="UTF-8") as fp:
|
||||
json.dump(artifacts, fp, ensure_ascii=False, indent=4)
|
||||
|
||||
for idx, (name, url) in enumerate(artifacts.items()):
|
||||
download_artifact(name, url, args.output_dir, args.token)
|
||||
# Be gentle to GitHub
|
||||
time.sleep(1)
|
||||
|
||||
errors = get_all_errors(args.output_dir, job_links=job_links)
|
||||
|
||||
# `e[1]` is the error
|
||||
counter = Counter()
|
||||
counter.update([e[1] for e in errors])
|
||||
|
||||
# print the top 30 most common test errors
|
||||
most_common = counter.most_common(30)
|
||||
for item in most_common:
|
||||
print(item)
|
||||
|
||||
with open(os.path.join(args.output_dir, "errors.json"), "w", encoding="UTF-8") as fp:
|
||||
json.dump(errors, fp, ensure_ascii=False, indent=4)
|
||||
|
||||
reduced_by_error = reduce_by_error(errors)
|
||||
reduced_by_model = reduce_by_model(errors)
|
||||
|
||||
s1 = make_github_table(reduced_by_error)
|
||||
s2 = make_github_table_per_model(reduced_by_model)
|
||||
|
||||
with open(os.path.join(args.output_dir, "reduced_by_error.txt"), "w", encoding="UTF-8") as fp:
|
||||
fp.write(s1)
|
||||
with open(os.path.join(args.output_dir, "reduced_by_model.txt"), "w", encoding="UTF-8") as fp:
|
||||
fp.write(s2)
|
||||
71
transformers/utils/get_github_job_time.py
Normal file
71
transformers/utils/get_github_job_time.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import argparse
|
||||
import math
|
||||
import traceback
|
||||
|
||||
import dateutil.parser as date_parser
|
||||
import requests
|
||||
|
||||
|
||||
def extract_time_from_single_job(job):
    """Return timing info for one workflow job.

    Args:
        job: A job entry from the GitHub Actions jobs API, with `started_at` and
            `completed_at` timestamp strings.

    Returns:
        dict: `{"started_at", "completed_at", "duration"}`, where `duration` is the
        wall-clock length in minutes, rounded to the nearest integer.
    """
    started = job["started_at"]
    completed = job["completed_at"]

    elapsed = date_parser.parse(completed) - date_parser.parse(started)
    minutes = round(elapsed.total_seconds() / 60.0)

    return {"started_at": started, "completed_at": completed, "duration": minutes}
|
||||
|
||||
|
||||
def get_job_time(workflow_run_id, token=None):
    """Return `{job name: timing info}` for every job of a GitHub Actions workflow run.

    Paginates through the jobs API (100 jobs per page); on any failure an empty dict is
    returned after printing the traceback.
    """
    headers = None
    if token is not None:
        headers = {"Accept": "application/vnd.github+json", "Authorization": f"Bearer {token}"}

    url = f"https://api.github.com/repos/huggingface/transformers/actions/runs/{workflow_run_id}/jobs?per_page=100"
    result = requests.get(url, headers=headers).json()
    job_time = {}

    try:
        for job in result["jobs"]:
            job_time[job["name"]] = extract_time_from_single_job(job)
        # The first request already covered 100 jobs; fetch any remaining pages.
        remaining_pages = math.ceil((result["total_count"] - 100) / 100)
        for page_index in range(remaining_pages):
            result = requests.get(f"{url}&page={page_index + 2}", headers=headers).json()
            for job in result["jobs"]:
                job_time[job["name"]] = extract_time_from_single_job(job)

        return job_time
    except Exception:
        print(f"Unknown error, could not fetch links:\n{traceback.format_exc()}")

    return {}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
r"""
|
||||
Example:
|
||||
|
||||
python get_github_job_time.py --workflow_run_id 2945609517
|
||||
"""
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
# Required parameters
|
||||
parser.add_argument("--workflow_run_id", type=str, required=True, help="A GitHub Actions workflow run id.")
|
||||
args = parser.parse_args()
|
||||
|
||||
job_time = get_job_time(args.workflow_run_id)
|
||||
job_time = dict(sorted(job_time.items(), key=lambda item: item[1]["duration"], reverse=True))
|
||||
|
||||
for k, v in job_time.items():
|
||||
print(f"{k}: {v['duration']}")
|
||||
36
transformers/utils/get_modified_files.py
Normal file
36
transformers/utils/get_modified_files.py
Normal file
@@ -0,0 +1,36 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# this script reports modified .py files under the desired list of top-level sub-dirs passed as a list of arguments, e.g.:
|
||||
# python ./utils/get_modified_files.py utils src tests examples
|
||||
#
|
||||
# it uses git to find the forking point and which files were modified - i.e. files not under git won't be considered
|
||||
# since the output of this script is fed into Makefile commands it doesn't print a newline after the results
|
||||
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
|
||||
# Find the commit where the current branch forked off `main`; only files changed since
# that merge base are considered "modified".
fork_point_sha = subprocess.check_output(["git", "merge-base", "main", "HEAD"]).decode("utf-8")
# `--diff-filter=d` drops deleted files (nothing left to check there).
modified_files = (
    subprocess.check_output(["git", "diff", "--diff-filter=d", "--name-only", fork_point_sha.strip()])
    .decode("utf-8")
    .split()
)

# Keep only `.py` files living under one of the top-level directories given as CLI arguments.
dir_pattern = re.compile(rf"^({'|'.join(sys.argv[1:])}).*?\.py$")
relevant_modified_files = [path for path in modified_files if dir_pattern.match(path)]

# No trailing newline: the output is consumed directly by Makefile commands.
print(" ".join(relevant_modified_files), end="")
|
||||
133
transformers/utils/get_pr_run_slow_jobs.py
Normal file
133
transformers/utils/get_pr_run_slow_jobs.py
Normal file
@@ -0,0 +1,133 @@
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
import string
|
||||
|
||||
|
||||
MAX_NUM_JOBS_TO_SUGGEST = 16
|
||||
|
||||
|
||||
def get_jobs_to_run():
    """Map the files changed in a pull request to the test directories to run.

    Reads `pr_files.txt` (prepared by the caller through the GitHub API) and, for each
    added or modified file, extracts the corresponding test directory such as
    `models/<name>` or `quantization/<name>`.

    NOTE(review): this relies on the module-level global `repo_content`, which is only
    populated in the `__main__` block of this script — calling this function from
    another module without that setup would fail.

    Returns:
        list[str]: The sorted, de-duplicated test directories.
    """
    # The file `pr_files.txt` contains the information about the files changed in a pull request, and it is prepared by
    # the caller (using GitHub api).
    # We can also use the following api to get the information if we don't have them before calling this script.
    # url = f"https://api.github.com/repos/huggingface/transformers/pulls/PULL_NUMBER/files?ref={pr_sha}"
    with open("pr_files.txt") as fp:
        pr_files = json.load(fp)
        pr_files = [{k: v for k, v in item.items() if k in ["filename", "status"]} for item in pr_files]
        pr_files = [item["filename"] for item in pr_files if item["status"] in ["added", "modified"]]

    # models or quantizers
    re_1 = re.compile(r"src/transformers/(models/.*)/modeling_.*\.py")
    re_2 = re.compile(r"src/transformers/(quantizers/quantizer_.*)\.py")

    # tests for models or quantizers
    re_3 = re.compile(r"tests/(models/.*)/test_.*\.py")
    re_4 = re.compile(r"tests/(quantization/.*)/test_.*\.py")

    # files in a model directory but not necessary a modeling file
    re_5 = re.compile(r"src/transformers/(models/.*)/.*\.py")

    regexes = [re_1, re_2, re_3, re_4, re_5]

    jobs_to_run = []
    for pr_file in pr_files:
        for regex in regexes:
            matched = regex.findall(pr_file)
            if len(matched) > 0:
                item = matched[0]
                item = item.replace("quantizers/quantizer_", "quantization/")
                # TODO: for files in `quantizers`, the processed item above may not exist. Try using a fuzzy matching
                if item in repo_content:
                    jobs_to_run.append(item)
                    # The first matching regex wins for this file.
                    break
    jobs_to_run = sorted(set(jobs_to_run))

    return jobs_to_run
|
||||
|
||||
|
||||
def parse_message(message: str) -> str:
    """
    Parses a GitHub pull request's comment to find the models specified in it to run slow CI.

    Args:
        message (`str`): The body of a GitHub pull request's comment.

    Returns:
        `str`: The substring in `message` after `run-slow`, `run_slow` or `run slow` (and any leading `:`
        separators). If no such prefix is found, the empty string is returned.
    """
    if message is None:
        return ""

    normalized = message.strip().lower()

    # Expected form: `run-slow: model_1, model_2, quantization_1, quantization_2`
    if not normalized.startswith(("run-slow", "run_slow", "run slow")):
        return ""
    # All three accepted prefixes have the same length.
    remainder = normalized[len("run slow") :]
    # remove leading `:` separators
    while remainder.strip().startswith(":"):
        remainder = remainder.strip()[1:]

    return remainder
|
||||
|
||||
|
||||
def get_jobs(message: str):
    """Return the list of job names found in `message` (comma- and/or whitespace-separated)."""
    specified = parse_message(message)
    # Turning commas into spaces lets a single `split()` handle both separators.
    return specified.replace(",", " ").split()
|
||||
|
||||
|
||||
def check_name(model_name: str):
    """Return True if `model_name` uses only [A-Za-z0-9_] and has no leading/trailing underscore."""
    if model_name.startswith("_") or model_name.endswith("_"):
        return False
    allowed_chars = set(string.ascii_letters + string.digits + "_")
    return all(char in allowed_chars for char in model_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--message", type=str, default="", help="The content of a comment.")
|
||||
parser.add_argument("--quantization", action="store_true", help="If we collect quantization tests")
|
||||
args = parser.parse_args()
|
||||
|
||||
# The files are prepared by the caller (using GitHub api).
|
||||
# We can also use the following api to get the information if we don't have them before calling this script.
|
||||
# url = f"https://api.github.com/repos/OWNER/REPO/contents/PATH?ref={pr_sha}"
|
||||
# (we avoid to checkout the repository using `actions/checkout` to reduce the run time, but mostly to avoid the potential security issue as much as possible)
|
||||
repo_content = []
|
||||
for filename in ["tests_dir.txt", "tests_models_dir.txt", "tests_quantization_dir.txt"]:
|
||||
with open(filename) as fp:
|
||||
data = json.load(fp)
|
||||
data = [item["path"][len("tests/") :] for item in data if item["type"] == "dir"]
|
||||
repo_content.extend(data)
|
||||
|
||||
# These don't have the prefix `models/` or `quantization/`, so we need to add them.
|
||||
if args.message:
|
||||
specified_jobs = get_jobs(args.message)
|
||||
specified_jobs = [job for job in specified_jobs if check_name(job)]
|
||||
|
||||
# Add prefix (`models/` or `quantization`)
|
||||
jobs_to_run = []
|
||||
for job in specified_jobs:
|
||||
if not args.quantization:
|
||||
if f"models/{job}" in repo_content:
|
||||
jobs_to_run.append(f"models/{job}")
|
||||
elif job in repo_content and job != "quantization":
|
||||
jobs_to_run.append(job)
|
||||
elif f"quantization/{job}" in repo_content:
|
||||
jobs_to_run.append(f"quantization/{job}")
|
||||
|
||||
print(sorted(set(jobs_to_run)))
|
||||
|
||||
else:
|
||||
# Compute (from the added/modified files) the directories under `tests/`, `tests/models/` and `tests/quantization`to run tests.
|
||||
# These are already with the prefix `models/` or `quantization/`, so we don't need to add them.
|
||||
jobs_to_run = get_jobs_to_run()
|
||||
jobs_to_run = [x.replace("models/", "").replace("quantization/", "") for x in jobs_to_run]
|
||||
jobs_to_run = [job for job in jobs_to_run if check_name(job)]
|
||||
|
||||
if len(jobs_to_run) > MAX_NUM_JOBS_TO_SUGGEST:
|
||||
jobs_to_run = jobs_to_run[:MAX_NUM_JOBS_TO_SUGGEST]
|
||||
|
||||
suggestion = f"{', '.join(jobs_to_run)}"
|
||||
|
||||
print(suggestion)
|
||||
159
transformers/utils/get_previous_daily_ci.py
Normal file
159
transformers/utils/get_previous_daily_ci.py
Normal file
@@ -0,0 +1,159 @@
|
||||
import os
|
||||
import zipfile
|
||||
|
||||
import requests
|
||||
from get_ci_error_statistics import download_artifact, get_artifacts_links
|
||||
|
||||
|
||||
def get_daily_ci_runs(token, num_runs=7, workflow_id=None):
    """Get the workflow runs of the scheduled (daily) CI.

    Only runs on the `main` branch are returned. Runs triggered by the `schedule` event
    are preferred; if none exist, runs triggered by `workflow_run` are used instead.
    """
    headers = None
    if token is not None:
        headers = {"Accept": "application/vnd.github+json", "Authorization": f"Bearer {token}"}

    # We need a *workflow* id (not a workflow-run id). When it isn't provided, look up
    # the workflow of the run currently executing this script: fetch
    # https://api.github.com/repos/huggingface/transformers/actions/runs/{workflow_run_id}
    # and read its `workflow_id` key (`GITHUB_RUN_ID` is set by GitHub Actions).
    if not workflow_id:
        current_run_id = os.environ["GITHUB_RUN_ID"]
        current_run = requests.get(
            f"https://api.github.com/repos/huggingface/transformers/actions/runs/{current_run_id}", headers=headers
        ).json()
        workflow_id = current_run["workflow_id"]

    # On `main` branch + not returning PRs + only `num_runs` results
    url = (
        f"https://api.github.com/repos/huggingface/transformers/actions/workflows/{workflow_id}/runs"
        f"?branch=main&exclude_pull_requests=true&per_page={num_runs}"
    )

    runs = requests.get(f"{url}&event=schedule", headers=headers).json()["workflow_runs"]
    if len(runs) == 0:
        # Fall back to runs (re-)triggered via the `workflow_run` event.
        runs = requests.get(f"{url}&event=workflow_run", headers=headers).json()["workflow_runs"]

    return runs
|
||||
|
||||
|
||||
def get_last_daily_ci_run(token, workflow_run_id=None, workflow_id=None, commit_sha=None):
    """Return the last completed workflow run of the scheduled (daily) CI.

    When `workflow_run_id` is given, that run is fetched and returned directly.
    Otherwise, the most recent *completed* daily-CI run is returned, optionally
    restricted (via `commit_sha`) to runs of a specific commit. Returns None when no
    matching run is found.
    """
    headers = None
    if token is not None:
        headers = {"Accept": "application/vnd.github+json", "Authorization": f"Bearer {token}"}

    # An explicitly specified run id short-circuits the search.
    if workflow_run_id is not None and workflow_run_id != "":
        return requests.get(
            f"https://api.github.com/repos/huggingface/transformers/actions/runs/{workflow_run_id}", headers=headers
        ).json()

    for run in get_daily_ci_runs(token, workflow_id=workflow_id):
        if run["status"] != "completed":
            continue
        # Without a target commit, the newest completed run wins; with one, the run's
        # `head_sha` must match.
        if commit_sha in [None, ""] or run["head_sha"] == commit_sha:
            return run

    return None
|
||||
|
||||
|
||||
def get_last_daily_ci_workflow_run_id(token, workflow_run_id=None, workflow_id=None, commit_sha=None):
    """Return the id of the last completed scheduled (daily) CI workflow run, or None."""
    # An explicitly provided run id wins.
    if workflow_run_id is not None and workflow_run_id != "":
        return workflow_run_id

    run = get_last_daily_ci_run(token, workflow_id=workflow_id, commit_sha=commit_sha)
    return None if run is None else run["id"]
|
||||
|
||||
|
||||
def get_last_daily_ci_run_commit(token, workflow_run_id=None, workflow_id=None, commit_sha=None):
    """Return the commit sha of the last completed scheduled (daily) CI workflow run, or None."""
    run = get_last_daily_ci_run(
        token, workflow_run_id=workflow_run_id, workflow_id=workflow_id, commit_sha=commit_sha
    )
    return None if run is None else run["head_sha"]
|
||||
|
||||
|
||||
def get_last_daily_ci_artifacts(
    output_dir,
    token,
    workflow_run_id=None,
    workflow_id=None,
    commit_sha=None,
    artifact_names=None,
):
    """Download the artifacts of the last completed scheduled (daily) CI workflow run.

    Args:
        output_dir: Directory where the `.zip` artifacts are stored.
        token: GitHub token with `actions:read` permission.
        workflow_run_id: Optional explicit run id (skips the lookup).
        workflow_id: Optional workflow id used to find the latest run.
        commit_sha: Optional commit sha the run must match.
        artifact_names: Optional subset of artifact names to download; defaults to all
            artifacts the run produced.

    Returns:
        list[str]: The names of the artifacts that were actually downloaded. Requested
        names that don't exist in the run are silently skipped; the list is empty when
        no matching run is found.
    """
    workflow_run_id = get_last_daily_ci_workflow_run_id(
        token, workflow_run_id=workflow_run_id, workflow_id=workflow_id, commit_sha=commit_sha
    )

    # Fix: bind the result list before the `if`, so a missing run returns an empty list
    # instead of raising `NameError` at the `return` below (the name was previously only
    # bound inside the `if` block).
    downloaded_artifact_names = []
    if workflow_run_id is not None:
        artifacts_links = get_artifacts_links(workflow_run_id=workflow_run_id, token=token)

        if artifact_names is None:
            artifact_names = artifacts_links.keys()

        for artifact_name in artifact_names:
            if artifact_name in artifacts_links:
                artifact_url = artifacts_links[artifact_name]
                download_artifact(
                    artifact_name=artifact_name, artifact_url=artifact_url, output_dir=output_dir, token=token
                )
                downloaded_artifact_names.append(artifact_name)

    return downloaded_artifact_names
|
||||
|
||||
|
||||
def get_last_daily_ci_reports(
    output_dir,
    token,
    workflow_run_id=None,
    workflow_id=None,
    commit_sha=None,
    artifact_names=None,
):
    """Get the artifacts' content of the last completed scheduled (daily) CI workflow run.

    Each downloaded artifact zip is extracted into `output_dir/<artifact_name>` and the
    content of every regular file directly inside is read as text.

    Returns:
        dict: artifact name -> {file name: file content}.
    """
    downloaded_artifact_names = get_last_daily_ci_artifacts(
        output_dir,
        token,
        workflow_run_id=workflow_run_id,
        workflow_id=workflow_id,
        commit_sha=commit_sha,
        artifact_names=artifact_names,
    )

    results = {}
    for artifact_name in downloaded_artifact_names:
        artifact_zip_path = os.path.join(output_dir, f"{artifact_name}.zip")
        if not os.path.isfile(artifact_zip_path):
            continue
        target_dir = os.path.join(output_dir, artifact_name)
        with zipfile.ZipFile(artifact_zip_path) as z:
            z.extractall(target_dir)

        results[artifact_name] = {}
        # Fix: the original shadowed the list of names with the loop variable
        # (`filename = os.listdir(...); for filename in filename:`), which was confusing.
        for filename in os.listdir(target_dir):
            file_path = os.path.join(target_dir, filename)
            if os.path.isdir(file_path):
                continue
            # read the file
            with open(file_path) as fp:
                results[artifact_name][filename] = fp.read()

    return results
|
||||
197
transformers/utils/get_test_info.py
Normal file
197
transformers/utils/get_test_info.py
Normal file
@@ -0,0 +1,197 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
# This is required to make the module import works (when the python process is running from the root of the repo)
|
||||
sys.path.append(".")
|
||||
|
||||
|
||||
r"""
|
||||
The argument `test_file` in this file refers to a model test file. This should be a string of the from
|
||||
`tests/models/*/test_modeling_*.py`.
|
||||
"""
|
||||
|
||||
|
||||
def get_module_path(test_file):
    """Return the dotted module path of a model test file.

    Args:
        test_file: A path of the form `tests/models/*/test_modeling_*.py`, using the
            OS-specific path separator.

    Returns:
        str: The dotted module path, e.g. `tests.models.bert.test_modeling_bert`.

    Raises:
        ValueError: If `test_file` is not under `tests/models/`, is not a `.py` file,
            or its file name does not start with `test_modeling_`.
    """
    components = test_file.split(os.path.sep)
    if components[0:2] != ["tests", "models"]:
        raise ValueError(
            "`test_file` should start with `tests/models/` (with `/` being the OS specific path separator). Got "
            f"{test_file} instead."
        )
    test_fn = components[-1]
    # Fix: check the full `.py` suffix — `endswith("py")` would also accept names that
    # merely end in the letters "py" (e.g. `test_modeling_numpy` with no extension).
    if not test_fn.endswith(".py"):
        raise ValueError(f"`test_file` should be a python file. Got {test_fn} instead.")
    if not test_fn.startswith("test_modeling_"):
        raise ValueError(
            f"`test_file` should point to a file name of the form `test_modeling_*.py`. Got {test_fn} instead."
        )

    # Strip only the trailing `.py` (a plain `str.replace(".py", "")` would also remove
    # any `.py` occurring in the middle of the name).
    components = components[:-1] + [test_fn[: -len(".py")]]
    test_module_path = ".".join(components)

    return test_module_path
|
||||
|
||||
|
||||
def get_test_module(test_file):
    """Import and return the module object of a model test file."""
    module_path = get_module_path(test_file)
    try:
        return importlib.import_module(module_path)
    except AttributeError as exc:
        # This happens e.g. when another installed package also ships a top-level
        # `tests` folder, so `tests.models...` resolves against the wrong package.
        raise ValueError(
            f"Could not import module {module_path}. Confirm that you don't have a package with the same root "
            "name installed or in your environment's `site-packages`."
        ) from exc
|
||||
|
||||
|
||||
def get_tester_classes(test_file):
    """Get all classes in a model test file whose names end with `ModelTester`."""
    test_module = get_test_module(test_file)
    tester_classes = [getattr(test_module, attr) for attr in dir(test_module) if attr.endswith("ModelTester")]

    # sort with class names
    return sorted(tester_classes, key=lambda x: x.__name__)
|
||||
|
||||
|
||||
def get_test_classes(test_file):
    """Get all [test] classes in a model test file with attribute `all_model_classes` that are non-empty.

    These are usually the (model) test classes containing the (non-slow) tests to run and are subclasses of
    `ModelTesterMixin`, as well as a subclass of `unittest.TestCase`. Exceptions include `RagTestMixin` (and its
    subclasses).
    """
    test_module = get_test_module(test_file)
    test_classes = []
    for attr in dir(test_module):
        candidate = getattr(test_module, attr)
        # `ModelTesterMixin` itself also shows up as a module attribute; requiring a
        # non-empty `all_model_classes` filters it (and other special classes) out.
        if len(getattr(candidate, "all_model_classes", [])) > 0:
            test_classes.append(candidate)

    # sort with class names
    return sorted(test_classes, key=lambda x: x.__name__)
|
||||
|
||||
|
||||
def get_model_classes(test_file):
    """Get all model classes that appear in `all_model_classes` attributes in a model test file."""
    model_classes = set()
    for test_class in get_test_classes(test_file):
        model_classes.update(test_class.all_model_classes)

    # sort with class names
    return sorted(model_classes, key=lambda x: x.__name__)
|
||||
|
||||
|
||||
def get_model_tester_from_test_class(test_class):
    """Get the model tester class of a model test class, or None if it has no tester."""
    instance = test_class()
    if hasattr(instance, "setUp"):
        instance.setUp()

    # `ModelTesterMixin` defaults `model_tester` to `None`; only a real tester counts.
    tester = getattr(instance, "model_tester", None)
    return None if tester is None else tester.__class__
|
||||
|
||||
|
||||
def get_test_classes_for_model(test_file, model_class):
    """Get all [test] classes in `test_file` that have `model_class` in their `all_model_classes`."""
    target_test_classes = [
        test_class for test_class in get_test_classes(test_file) if model_class in test_class.all_model_classes
    ]

    # sort with class names
    return sorted(target_test_classes, key=lambda x: x.__name__)
|
||||
|
||||
|
||||
def get_tester_classes_for_model(test_file, model_class):
    """Get all model tester classes in `test_file` that are associated to `model_class`."""
    testers = [
        get_model_tester_from_test_class(test_class)
        for test_class in get_test_classes_for_model(test_file, model_class)
    ]

    # sort with class names (test classes without a tester yield None and are dropped)
    return sorted((tester for tester in testers if tester is not None), key=lambda x: x.__name__)
|
||||
|
||||
|
||||
def get_test_to_tester_mapping(test_file):
    """Get a mapping from [test] classes to model tester classes in `test_file`.

    This uses `get_test_classes` which may return classes that are NOT subclasses of `unittest.TestCase`.
    """
    return {
        test_class: get_model_tester_from_test_class(test_class) for test_class in get_test_classes(test_file)
    }
|
||||
|
||||
|
||||
def get_model_to_test_mapping(test_file):
    """Map each model class found in `test_file` to the (sorted) list of test classes that cover it."""
    return {cls: get_test_classes_for_model(test_file, cls) for cls in get_model_classes(test_file)}
|
||||
|
||||
|
||||
def get_model_to_tester_mapping(test_file):
    """Map each model class found in `test_file` to the (sorted) list of model tester classes that cover it."""
    return {cls: get_tester_classes_for_model(test_file, cls) for cls in get_model_classes(test_file)}
|
||||
|
||||
|
||||
def to_json(o):
    """Recursively replace classes with their bare names to keep the output succinct.

    Avoids full class representations such as
    `<class 'transformers.models.bert.modeling_bert.BertForMaskedLM'>` in favor of the readable
    class name (`BertForMaskedLM`). Lists/tuples are converted element-wise (tuples become lists),
    dict keys and values are converted recursively, and all other values pass through unchanged.
    """
    if isinstance(o, type):
        return o.__name__
    if isinstance(o, (list, tuple)):
        return [to_json(item) for item in o]
    if isinstance(o, dict):
        return {to_json(key): to_json(value) for key, value in o.items()}
    # Strings and any other scalar are returned as-is.
    return o
|
||||
272
transformers/utils/get_test_reports.py
Normal file
272
transformers/utils/get_test_reports.py
Normal file
@@ -0,0 +1,272 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This util provides a way to manually run the tests of the transformers repo as they would be run by the CI.
|
||||
It was mainly used for models tests, so if you find features missing for another suite, do not hesitate to open a PR.
|
||||
|
||||
Functionnalities:
|
||||
- Running specific test suite (models, tokenizers, etc.)
|
||||
- Parallel execution across multiple processes (each has to be launched separately with different `--processes` argument)
|
||||
- GPU/CPU test filtering and slow tests filter
|
||||
- Temporary cache management for isolated test runs
|
||||
- Resume functionality for interrupted test runs
|
||||
- Important models subset testing
|
||||
|
||||
Example usages are below.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from .important_files import IMPORTANT_MODELS
|
||||
|
||||
|
||||
def is_valid_test_dir(path: Path) -> bool:
    """Return True when `path` is a directory that is neither hidden ('.') nor dunder-style ('__')."""
    if not path.is_dir():
        return False
    return not path.name.startswith(("__", "."))
|
||||
|
||||
|
||||
def run_pytest(
    suite: str, subdir: Path, root_test_dir: Path, machine_type: str, dry_run: bool, tmp_cache: str, cpu_tests: bool
) -> None:
    """
    Invoke pytest once on a single test directory.

    Args:
        - suite (str): name of the test suite being run, used in the report name
        - subdir (Path): the directory whose tests are run
        - root_test_dir (Path): root of all tests; `subdir` is reported relative to it
        - machine_type (str): environment label (e.g. 'cpu', 'single-gpu', 'multi-gpu') used in the report name
        - dry_run (bool): if True, only print the command without executing it
        - tmp_cache (str): prefix for a throwaway HUGGINGFACE_HUB_CACHE directory; empty disables it
        - cpu_tests (bool): if True, include CPU-only tests; if False, exclude non-device tests
    """
    relative_path = subdir.relative_to(root_test_dir)
    report_name = f"{machine_type}_{suite}_{relative_path}_test_reports"
    print(f"Suite: {suite} | Running on: {relative_path}")

    pytest_cmd = ["python3", "-m", "pytest", "-rsfE", "-v", f"--make-reports={report_name}", str(subdir)]
    if not cpu_tests:
        # Restrict the run to tests that actually exercise the device.
        pytest_cmd += ["-m", "not not_device_test"]

    # Optionally run inside a throwaway hub cache so downloads don't pollute the real one.
    ctx_manager = tempfile.TemporaryDirectory(prefix=tmp_cache) if tmp_cache else contextlib.nullcontext()
    with ctx_manager as tmp_dir:
        env = os.environ.copy()
        if tmp_cache:
            env["HUGGINGFACE_HUB_CACHE"] = tmp_dir
            print(f"Using temporary cache located at {tmp_dir = }")

        print("Command:", " ".join(pytest_cmd))
        if not dry_run:
            subprocess.run(pytest_cmd, check=False, env=env)
|
||||
|
||||
|
||||
def handle_suite(
    suite: str,
    test_root: Path,
    machine_type: str,
    dry_run: bool,
    tmp_cache: str = "",
    resume_at: Optional[str] = None,
    only_in: Optional[list[str]] = None,
    cpu_tests: bool = False,
    process_id: int = 1,
    total_processes: int = 1,
) -> None:
    """
    Run a complete test suite, one pytest invocation per sub-directory, with filtering and
    distribution of the work across several CI processes.

    Args:
        - suite (str): Name of the test suite to run (a directory under `test_root`).
        - test_root (Path): Root directory containing all test suites.
        - machine_type (str): Machine/environment type, used for report naming.
        - dry_run (bool): If True, only print commands without executing them.
        - tmp_cache (str, optional): Prefix for temporary cache directories; empty disables the temp cache.
        - resume_at (str, optional): Skip sub-directories sorting before this name, to restart an
          interrupted run. Defaults to None (run from the beginning).
        - only_in (list[str], optional): Only run tests in these sub-directories
          (e.g. the IMPORTANT_MODELS subset). Defaults to None (run everything).
        - cpu_tests (bool, optional): Whether to also include CPU-only tests. Defaults to False.
        - process_id (int, optional): This process's id for parallel execution. As used by the CLI,
          ids are 0-indexed (`--processes 0 2` runs the even-indexed subdirs). The default of 1 is
          only meaningful together with the default `total_processes=1`, where no slicing happens.
        - total_processes (int, optional): Total number of parallel processes. Defaults to 1.
    """
    suite_dir = test_root / suite
    if not suite_dir.exists():
        print(f"Test folder does not exist: {suite_dir}")
        return

    # Candidate sub-directories, alphabetically sorted, minus hidden/dunder entries.
    targets = [entry for entry in sorted(suite_dir.iterdir()) if is_valid_test_dir(entry)]
    if resume_at is not None:
        targets = [entry for entry in targets if entry.name >= resume_at]
    if only_in is not None:
        targets = [entry for entry in targets if entry.name in only_in]
    if targets and total_processes > 1:
        # Interleave the subdirs across processes. With targets = [A, B, C, D, E] and 2 processes:
        # - the script launched with `--processes 0 2` runs A, C, E
        # - the script launched with `--processes 1 2` runs B, D
        targets = targets[process_id::total_processes]

    if targets:
        # Run pytest once per surviving sub-directory.
        for entry in targets:
            run_pytest(suite, entry, test_root, machine_type, dry_run, tmp_cache, cpu_tests)
    else:
        # No sub-directory survived filtering (or the suite is flat): run pytest on the suite itself.
        run_pytest(suite, suite_dir, test_root, machine_type, dry_run, tmp_cache, cpu_tests)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    """Command-line interface for running test suite with comprehensive reporting. Check handle_suite for more details.

    Command-line Arguments:
        folder: Path to the root test directory (required)
        --suite: Test suite name to run (default: "models")
        --cpu-tests: Include CPU-only tests in addition to device tests
        --run-slow: Execute slow tests instead of skipping them
        --resume-at: Resume execution from a specific subdirectory
        --only-in: Run tests only in specified subdirectories (supports IMPORTANT_MODELS)
        --processes: Process distribution as "process_id total_processes"
        --dry-run: Print commands without executing them
        --tmp-cache: Use temporary cache directories for isolated runs
        --machine-type: Override automatic machine type detection

    Machine Type Detection:
        - 'cpu': No CUDA available
        - 'single-gpu': CUDA available with 1 GPU
        - 'multi-gpu': CUDA available with multiple GPUs

    Process Distribution:
        Use --processes to split work across multiple parallel processes:
        --processes 0 4  # This is process 0 of 4 total processes
        --processes 1 4  # This is process 1 of 4 total processes
        ...

    Usage Examples:
        # Basic model testing
        python3 -m utils.get_test_reports tests/ --suite models

        # Run slow tests for important models only
        python3 -m utils.get_test_reports tests/ --suite models --run-slow --only-in IMPORTANT_MODELS

        # Parallel execution across 4 processes, second process to launch (processes are 0-indexed)
        python3 -m utils.get_test_reports tests/ --suite models --processes 1 4

        # Resume interrupted run from 'bert' subdirectory with a tmp cache
        python3 -m utils.get_test_reports tests/ --suite models --resume-at bert --tmp-cache /tmp/

        # Run specific models with CPU tests
        python3 -m utils.get_test_reports tests/ --suite models --only-in bert gpt2 --cpu-tests

        # Run slow tests for only important models with a tmp cache
        python3 -m utils.get_test_reports tests/ --suite models --run-slow --only-in IMPORTANT_MODELS --tmp-cache /tmp/
    """

    parser = argparse.ArgumentParser()
    parser.add_argument("folder", help="Path to test root folder (e.g., ./tests)")

    # Choose which tests to run (broad picture)
    # FIX: help string typo "Test suit" -> "Test suite"
    parser.add_argument("--suite", type=str, default="models", help="Test suite to run")
    parser.add_argument("--cpu-tests", action="store_true", help="Also runs non-device tests")
    parser.add_argument("--run-slow", action="store_true", help="Run slow tests instead of skipping them")
    parser.add_argument("--collect-outputs", action="store_true", help="Collect outputs of the tests")

    # Fine-grain control over the tests to run
    parser.add_argument("--resume-at", type=str, default=None, help="Resume at a specific subdir / file in the suite")
    parser.add_argument(
        "--only-in",
        type=str,
        nargs="+",
        help="Only run tests in the given subdirs / file. Use IMPORTANT_MODELS to run only the important models tests.",
    )

    # How to run the test suite: is the work divided among processes, do a try run, use temp cache?
    parser.add_argument(
        "--processes",
        type=int,
        nargs="+",
        help="Inform each CI process as to the work to do: format as `process_id total_processes`. "
        "In order to run with multiple (eg. 3) processes, you need to run the script multiple times (eg. 3 times).",
    )
    parser.add_argument("--dry-run", action="store_true", help="Only print commands without running them")
    # default="" matches `handle_suite`'s `tmp_cache: str` annotation (both None and "" disable the cache)
    parser.add_argument(
        "--tmp-cache", type=str, default="", help="Change HUGGINGFACE_HUB_CACHE to a tmp dir for each test"
    )

    # This is a purely decorative argument, but it can be useful to distinguish between runs
    parser.add_argument(
        "--machine-type", type=str, default="", help="Machine type, automatically inferred if not provided"
    )
    args = parser.parse_args()

    # Handle run slow: the RUN_SLOW env var is what the test suite's skip decorators check
    if args.run_slow:
        os.environ["RUN_SLOW"] = "yes"
        print("[WARNING] Running slow tests.")
    else:
        print("[WARNING] Skipping slow tests.")

    # Handle multiple CI processes: default (1, 1) means "single process, no interleaving"
    if args.processes is None:
        process_id, total_processes = 1, 1
    elif len(args.processes) == 2:
        process_id, total_processes = args.processes
    else:
        raise ValueError(f"Invalid processes argument: {args.processes}")

    # Assert test root exists
    test_root = Path(args.folder).resolve()
    if not test_root.exists():
        print(f"Root test folder not found: {test_root}")
        exit(1)

    # Handle collection of outputs (consumed by the testing utilities via env vars)
    if args.collect_outputs:
        os.environ["PATCH_TESTING_METHODS_TO_COLLECT_OUTPUTS"] = "yes"
        reports_dir = test_root.parent / "reports"
        os.environ["_PATCHED_TESTING_METHODS_OUTPUT_DIR"] = str(reports_dir)

    # Infer machine type if not provided
    if args.machine_type == "":
        if not torch.cuda.is_available():
            machine_type = "cpu"
        else:
            machine_type = "multi-gpu" if torch.cuda.device_count() > 1 else "single-gpu"
    else:
        machine_type = args.machine_type

    # Reduce the scope for models if necessary
    only_in = args.only_in if args.only_in else None
    if only_in == ["IMPORTANT_MODELS"]:
        only_in = IMPORTANT_MODELS

    # Launch suite
    handle_suite(
        suite=args.suite,
        test_root=test_root,
        machine_type=machine_type,
        dry_run=args.dry_run,
        tmp_cache=args.tmp_cache,
        resume_at=args.resume_at,
        only_in=only_in,
        cpu_tests=args.cpu_tests,
        process_id=process_id,
        total_processes=total_processes,
    )
|
||||
28
transformers/utils/important_files.py
Normal file
28
transformers/utils/important_files.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# List here the models to always test.
IMPORTANT_MODELS = [
    "auto",
    "bert",
    "gpt2",
    "t5",
    "modernbert",
    # FIX: "vit,clip" was a single malformed entry; directory-name matching is exact,
    # so neither model could ever be selected. Split into two entries.
    "vit",
    "clip",
    "detr",
    "table_transformer",
    "got_ocr2",
    "whisper",
    "wav2vec2",
    "qwen2_audio",
    "speech_t5",
    "csm",
    "llama",
    "gemma3",
    "qwen2",
    "mistral3",
    "qwen2_5_vl",
    "llava",
    "smolvlm",
    "internvl",
    "gemma3n",
    "gpt_oss",
    "qwen2_5_omni",
]
|
||||
4
transformers/utils/important_models.txt
Normal file
4
transformers/utils/important_models.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
models/llama
|
||||
models/mistral
|
||||
models/mixtral
|
||||
models/gemma
|
||||
196
transformers/utils/models_to_deprecate.py
Normal file
196
transformers/utils/models_to_deprecate.py
Normal file
@@ -0,0 +1,196 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Script to find a candidate list of models to deprecate based on the number of downloads and the date of the last commit.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
from git import Repo
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
|
||||
# Hub client used to list models and read their download counts.
api = HfApi()

# Repository root (two levels up from this file); used for git-history lookups below.
PATH_TO_REPO = Path(__file__).parent.parent.resolve()
repo = Repo(PATH_TO_REPO)
|
||||
|
||||
|
||||
class HubModelLister:
    """
    Iterates over hub models matching `tags`. Any error raised while paging through the hub
    listing is printed and iteration stops, so a single failed hub call does not crash the script.
    """

    def __init__(self, tags):
        self.tags = tags
        # Lazy paginated listing from the hub; errors surface during iteration, not here.
        self.model_list = api.list_models(tags=tags)

    def __iter__(self):
        try:
            for hub_model in self.model_list:
                yield hub_model
        except Exception as e:
            print(f"Error: {e}")
            return
|
||||
|
||||
|
||||
def _extract_commit_hash(commits):
|
||||
for commit in commits:
|
||||
if commit.startswith("commit "):
|
||||
return commit.split(" ")[1]
|
||||
return ""
|
||||
|
||||
|
||||
def get_list_of_repo_model_paths(models_dir):
    """Return the `modeling_*.py` paths under `models_dir`, excluding deprecated models
    (and any symlinked copies of them) as well as the `auto` machinery."""
    # All modeling files, one directory level below models_dir.
    candidate_paths = glob.glob(os.path.join(models_dir, "*/modeling_*.py"))

    # For each deprecated model, drop any path whose directory component matches its name
    # (this also catches symlinked copies living outside the `deprecated` folder).
    for deprecated_path in glob.glob(os.path.join(models_dir, "deprecated", "*")):
        dir_component = "/" + deprecated_path.split("/")[-1] + "/"
        candidate_paths = [p for p in candidate_paths if dir_component not in p]

    # Finally drop anything still under `deprecated` and the `auto` mapping machinery.
    return [p for p in candidate_paths if "/deprecated" not in p and "/auto/" not in p]
|
||||
|
||||
|
||||
def get_list_of_models_to_deprecate(
    thresh_num_downloads=5_000,
    thresh_date=None,
    use_cache=False,
    save_model_info=False,
    max_num_models=-1,
):
    """
    Build (and print) the list of library models that are candidates for deprecation.

    A model is a candidate when it was first committed before `thresh_date` and its hub
    checkpoints total fewer than `thresh_num_downloads` downloads.

    Args:
        thresh_num_downloads (`int`): Download count under which a model is a candidate. Defaults to 5,000.
        thresh_date (`str`, *optional*): "YYYY-MM-DD" cutoff for the model's first commit.
            Defaults to one year before now (UTC).
        use_cache (`bool`): Read model info from `models_info.json` instead of querying git.
        save_model_info (`bool`): Write the collected model info to `models_info.json`.
        max_num_models (`int`): Approximate cap on the number of hub models to scan (-1 = no cap).

    Returns:
        `dict`: Mapping of model name to its info dict for every deprecation candidate.
    """
    if thresh_date is None:
        thresh_date = datetime.now(timezone.utc).replace(year=datetime.now(timezone.utc).year - 1)
    else:
        thresh_date = datetime.strptime(thresh_date, "%Y-%m-%d").replace(tzinfo=timezone.utc)

    models_dir = PATH_TO_REPO / "src/transformers/models"
    model_paths = get_list_of_repo_model_paths(models_dir=models_dir)

    if use_cache and os.path.exists("models_info.json"):
        with open("models_info.json", "r") as f:
            models_info = json.load(f)
        # Convert datetimes back to datetime objects
        for info in models_info.values():
            info["first_commit_datetime"] = datetime.fromisoformat(info["first_commit_datetime"])
        # NOTE(review): when loading from cache, the hub loop below still runs and adds
        # downloads on top of the cached counts — possibly double-counting. Verify intent.

    else:
        # Build a dictionary of model info: first commit datetime, commit hash, model path
        models_info = defaultdict(dict)
        for model_path in model_paths:
            model = model_path.split("/")[-2]
            if model in models_info:
                continue
            # First commit that added this modeling file gives us the model's introduction date.
            commits = repo.git.log("--diff-filter=A", "--", model_path).split("\n")
            commit_hash = _extract_commit_hash(commits)
            commit_obj = repo.commit(commit_hash)
            committed_datetime = commit_obj.committed_datetime
            models_info[model]["commit_hash"] = commit_hash
            models_info[model]["first_commit_datetime"] = committed_datetime
            models_info[model]["model_path"] = model_path
            models_info[model]["downloads"] = 0

            # Some tags on the hub are formatted differently than in the library
            tags = [model]
            if "_" in model:
                tags.append(model.replace("_", "-"))
            models_info[model]["tags"] = tags

    # Filter out models which were added less than a year ago
    models_info = {
        model: info for model, info in models_info.items() if info["first_commit_datetime"] < thresh_date
    }

    # We make successive calls to the hub, filtering based on the model tags
    n_seen = 0
    for model, model_info in models_info.items():
        for model_tag in model_info["tags"]:
            model_list = HubModelLister(tags=model_tag)
            for i, hub_model in enumerate(model_list):
                n_seen += 1
                if i % 100 == 0:
                    print(f"Processing model {i} for tag {model_tag}")
                # FIX: the original condition `i > n_seen` could never be true (n_seen is
                # incremented just above, so it is always > i), meaning `max_num_models`
                # had no effect. Cap on the total number of hub models seen instead.
                if max_num_models != -1 and n_seen >= max_num_models:
                    break
                if hub_model.private:
                    continue
                model_info["downloads"] += hub_model.downloads

    if save_model_info and not (use_cache and os.path.exists("models_info.json")):
        # Make datetimes serializable
        for info in models_info.values():
            info["first_commit_datetime"] = info["first_commit_datetime"].isoformat()
        with open("models_info.json", "w") as f:
            json.dump(models_info, f, indent=4)

    print("\nFinding models to deprecate:")
    n_models_to_deprecate = 0
    models_to_deprecate = {}
    for model, info in models_info.items():
        n_downloads = info["downloads"]
        if n_downloads < thresh_num_downloads:
            n_models_to_deprecate += 1
            models_to_deprecate[model] = info
            print(f"\nModel: {model}")
            print(f"Downloads: {n_downloads}")
            print(f"Date: {info['first_commit_datetime']}")
    print("\nModels to deprecate: ", "\n" + "\n".join(models_to_deprecate.keys()))
    print(f"\nNumber of models to deprecate: {n_models_to_deprecate}")
    print("Before deprecating make sure to verify the models, including if they're used as a module in other models.")

    # FIX: return the candidates — __main__ assigns this call's result, but the original
    # function returned None.
    return models_to_deprecate
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI entry point: each flag mirrors a keyword argument of `get_list_of_models_to_deprecate`.
    parser = argparse.ArgumentParser()
    parser.add_argument("--save_model_info", action="store_true", help="Save the retrieved model info to a json file.")
    parser.add_argument(
        "--use_cache", action="store_true", help="Use the cached model info instead of calling the hub."
    )
    parser.add_argument(
        "--thresh_num_downloads",
        type=int,
        default=5_000,
        help="Threshold number of downloads below which a model should be deprecated. Default is 5,000.",
    )
    parser.add_argument(
        "--thresh_date",
        type=str,
        default=None,
        help="Date to consider the first commit from. Format: YYYY-MM-DD. If unset, defaults to one year ago from today.",
    )
    parser.add_argument(
        "--max_num_models",
        type=int,
        default=-1,
        help="Maximum number of models to consider from the hub. -1 means all models. Useful for testing.",
    )
    args = parser.parse_args()

    # Collect candidates and print the report (the function prints as it goes).
    models_to_deprecate = get_list_of_models_to_deprecate(
        thresh_num_downloads=args.thresh_num_downloads,
        thresh_date=args.thresh_date,
        use_cache=args.use_cache,
        save_model_info=args.save_model_info,
        max_num_models=args.max_num_models,
    )
|
||||
1788
transformers/utils/modular_model_converter.py
Normal file
1788
transformers/utils/modular_model_converter.py
Normal file
File diff suppressed because it is too large
Load Diff
819
transformers/utils/not_doctested.txt
Normal file
819
transformers/utils/not_doctested.txt
Normal file
@@ -0,0 +1,819 @@
|
||||
docs/source/en/_config.py
|
||||
docs/source/en/accelerate.md
|
||||
docs/source/en/add_new_model.md
|
||||
docs/source/en/add_new_pipeline.md
|
||||
docs/source/en/community.md
|
||||
docs/source/en/contributing.md
|
||||
docs/source/en/custom_models.md
|
||||
docs/source/en/debugging.md
|
||||
docs/source/en/fast_tokenizers.md
|
||||
docs/source/en/glossary.md
|
||||
docs/source/en/hpo_train.md
|
||||
docs/source/en/index.md
|
||||
docs/source/en/installation.md
|
||||
docs/source/en/internal/audio_utils.md
|
||||
docs/source/en/internal/file_utils.md
|
||||
docs/source/en/internal/image_processing_utils.md
|
||||
docs/source/en/internal/modeling_utils.md
|
||||
docs/source/en/internal/pipelines_utils.md
|
||||
docs/source/en/internal/time_series_utils.md
|
||||
docs/source/en/internal/tokenization_utils.md
|
||||
docs/source/en/internal/trainer_utils.md
|
||||
docs/source/en/llm_tutorial.md
|
||||
docs/source/en/main_classes/callback.md
|
||||
docs/source/en/main_classes/configuration.md
|
||||
docs/source/en/main_classes/data_collator.md
|
||||
docs/source/en/main_classes/deepspeed.md
|
||||
docs/source/en/main_classes/feature_extractor.md
|
||||
docs/source/en/main_classes/image_processor.md
|
||||
docs/source/en/main_classes/logging.md
|
||||
docs/source/en/main_classes/model.md
|
||||
docs/source/en/main_classes/onnx.md
|
||||
docs/source/en/main_classes/optimizer_schedules.md
|
||||
docs/source/en/main_classes/output.md
|
||||
docs/source/en/main_classes/pipelines.md
|
||||
docs/source/en/main_classes/processors.md
|
||||
docs/source/en/main_classes/quantization.md
|
||||
docs/source/en/main_classes/tokenizer.md
|
||||
docs/source/en/main_classes/trainer.md
|
||||
docs/source/en/model_doc/albert.md
|
||||
docs/source/en/model_doc/align.md
|
||||
docs/source/en/model_doc/altclip.md
|
||||
docs/source/en/model_doc/audio-spectrogram-transformer.md
|
||||
docs/source/en/model_doc/auto.md
|
||||
docs/source/en/model_doc/autoformer.md
|
||||
docs/source/en/model_doc/bark.md
|
||||
docs/source/en/model_doc/bart.md
|
||||
docs/source/en/model_doc/barthez.md
|
||||
docs/source/en/model_doc/bartpho.md
|
||||
docs/source/en/model_doc/beit.md
|
||||
docs/source/en/model_doc/bert-generation.md
|
||||
docs/source/en/model_doc/bert-japanese.md
|
||||
docs/source/en/model_doc/bert.md
|
||||
docs/source/en/model_doc/bertweet.md
|
||||
docs/source/en/model_doc/big_bird.md
|
||||
docs/source/en/model_doc/bigbird_pegasus.md
|
||||
docs/source/en/model_doc/biogpt.md
|
||||
docs/source/en/model_doc/bit.md
|
||||
docs/source/en/model_doc/blenderbot-small.md
|
||||
docs/source/en/model_doc/blenderbot.md
|
||||
docs/source/en/model_doc/blip-2.md
|
||||
docs/source/en/model_doc/blip.md
|
||||
docs/source/en/model_doc/bloom.md
|
||||
docs/source/en/model_doc/bort.md
|
||||
docs/source/en/model_doc/bridgetower.md
|
||||
docs/source/en/model_doc/camembert.md
|
||||
docs/source/en/model_doc/canine.md
|
||||
docs/source/en/model_doc/chinese_clip.md
|
||||
docs/source/en/model_doc/clap.md
|
||||
docs/source/en/model_doc/clip.md
|
||||
docs/source/en/model_doc/clipseg.md
|
||||
docs/source/en/model_doc/codegen.md
|
||||
docs/source/en/model_doc/conditional_detr.md
|
||||
docs/source/en/model_doc/convbert.md
|
||||
docs/source/en/model_doc/convnext.md
|
||||
docs/source/en/model_doc/convnextv2.md
|
||||
docs/source/en/model_doc/cpm.md
|
||||
docs/source/en/model_doc/cpmant.md
|
||||
docs/source/en/model_doc/ctrl.md
|
||||
docs/source/en/model_doc/cvt.md
|
||||
docs/source/en/model_doc/data2vec.md
|
||||
docs/source/en/model_doc/deberta-v2.md
|
||||
docs/source/en/model_doc/deberta.md
|
||||
docs/source/en/model_doc/decision_transformer.md
|
||||
docs/source/en/model_doc/deformable_detr.md
|
||||
docs/source/en/model_doc/deit.md
|
||||
docs/source/en/model_doc/deplot.md
|
||||
docs/source/en/model_doc/deta.md
|
||||
docs/source/en/model_doc/detr.md
|
||||
docs/source/en/model_doc/dialogpt.md
|
||||
docs/source/en/model_doc/dinat.md
|
||||
docs/source/en/model_doc/dinov2.md
|
||||
docs/source/en/model_doc/distilbert.md
|
||||
docs/source/en/model_doc/dit.md
|
||||
docs/source/en/model_doc/dpr.md
|
||||
docs/source/en/model_doc/dpt.md
|
||||
docs/source/en/model_doc/efficientformer.md
|
||||
docs/source/en/model_doc/efficientnet.md
|
||||
docs/source/en/model_doc/electra.md
|
||||
docs/source/en/model_doc/encodec.md
|
||||
docs/source/en/model_doc/ernie.md
|
||||
docs/source/en/model_doc/ernie_m.md
|
||||
docs/source/en/model_doc/esm.md
|
||||
docs/source/en/model_doc/flan-t5.md
|
||||
docs/source/en/model_doc/flan-ul2.md
|
||||
docs/source/en/model_doc/flaubert.md
|
||||
docs/source/en/model_doc/flava.md
|
||||
docs/source/en/model_doc/fnet.md
|
||||
docs/source/en/model_doc/focalnet.md
|
||||
docs/source/en/model_doc/fsmt.md
|
||||
docs/source/en/model_doc/funnel.md
|
||||
docs/source/en/model_doc/git.md
|
||||
docs/source/en/model_doc/glpn.md
|
||||
docs/source/en/model_doc/gpt-sw3.md
|
||||
docs/source/en/model_doc/gpt2.md
|
||||
docs/source/en/model_doc/gpt_bigcode.md
|
||||
docs/source/en/model_doc/gpt_neo.md
|
||||
docs/source/en/model_doc/gpt_neox.md
|
||||
docs/source/en/model_doc/gpt_neox_japanese.md
|
||||
docs/source/en/model_doc/gptj.md
|
||||
docs/source/en/model_doc/gptsan-japanese.md
|
||||
docs/source/en/model_doc/graphormer.md
|
||||
docs/source/en/model_doc/groupvit.md
|
||||
docs/source/en/model_doc/herbert.md
|
||||
docs/source/en/model_doc/hubert.md
|
||||
docs/source/en/model_doc/ibert.md
|
||||
docs/source/en/model_doc/idefics.md
|
||||
docs/source/en/model_doc/imagegpt.md
|
||||
docs/source/en/model_doc/informer.md
|
||||
docs/source/en/model_doc/instructblip.md
|
||||
docs/source/en/model_doc/jukebox.md
|
||||
docs/source/en/model_doc/layoutlm.md
|
||||
docs/source/en/model_doc/layoutlmv2.md
|
||||
docs/source/en/model_doc/layoutlmv3.md
|
||||
docs/source/en/model_doc/layoutxlm.md
|
||||
docs/source/en/model_doc/led.md
|
||||
docs/source/en/model_doc/levit.md
|
||||
docs/source/en/model_doc/lilt.md
|
||||
docs/source/en/model_doc/llama.md
|
||||
docs/source/en/model_doc/llama2.md
|
||||
docs/source/en/model_doc/llava.md
|
||||
docs/source/en/model_doc/llava_next.md
|
||||
docs/source/en/model_doc/longformer.md
|
||||
docs/source/en/model_doc/longt5.md
|
||||
docs/source/en/model_doc/luke.md
|
||||
docs/source/en/model_doc/lxmert.md
|
||||
docs/source/en/model_doc/m2m_100.md
|
||||
docs/source/en/model_doc/madlad-400.md
|
||||
docs/source/en/model_doc/marian.md
|
||||
docs/source/en/model_doc/mask2former.md
|
||||
docs/source/en/model_doc/maskformer.md
|
||||
docs/source/en/model_doc/matcha.md
|
||||
docs/source/en/model_doc/mbart.md
|
||||
docs/source/en/model_doc/mctct.md
|
||||
docs/source/en/model_doc/mega.md
|
||||
docs/source/en/model_doc/megatron-bert.md
|
||||
docs/source/en/model_doc/megatron_gpt2.md
|
||||
docs/source/en/model_doc/mgp-str.md
|
||||
docs/source/en/model_doc/mistral.md
|
||||
docs/source/en/model_doc/mixtral.md
|
||||
docs/source/en/model_doc/mluke.md
|
||||
docs/source/en/model_doc/mms.md
|
||||
docs/source/en/model_doc/mobilebert.md
|
||||
docs/source/en/model_doc/mobilenet_v1.md
|
||||
docs/source/en/model_doc/mobilenet_v2.md
|
||||
docs/source/en/model_doc/mobilevit.md
|
||||
docs/source/en/model_doc/mobilevitv2.md
|
||||
docs/source/en/model_doc/mpnet.md
|
||||
docs/source/en/model_doc/mpt.md
|
||||
docs/source/en/model_doc/mra.md
|
||||
docs/source/en/model_doc/mt5.md
|
||||
docs/source/en/model_doc/musicgen.md
|
||||
docs/source/en/model_doc/musicgen_melody.md
|
||||
docs/source/en/model_doc/mvp.md
|
||||
docs/source/en/model_doc/nat.md
|
||||
docs/source/en/model_doc/nezha.md
|
||||
docs/source/en/model_doc/nllb-moe.md
|
||||
docs/source/en/model_doc/nllb.md
|
||||
docs/source/en/model_doc/nystromformer.md
|
||||
docs/source/en/model_doc/oneformer.md
|
||||
docs/source/en/model_doc/open-llama.md
|
||||
docs/source/en/model_doc/openai-gpt.md
|
||||
docs/source/en/model_doc/opt.md
|
||||
docs/source/en/model_doc/owlvit.md
|
||||
docs/source/en/model_doc/pegasus.md
|
||||
docs/source/en/model_doc/pegasus_x.md
|
||||
docs/source/en/model_doc/perceiver.md
|
||||
docs/source/en/model_doc/phobert.md
|
||||
docs/source/en/model_doc/pix2struct.md
|
||||
docs/source/en/model_doc/plbart.md
|
||||
docs/source/en/model_doc/poolformer.md
|
||||
docs/source/en/model_doc/pop2piano.md
|
||||
docs/source/en/model_doc/prophetnet.md
|
||||
docs/source/en/model_doc/pvt.md
|
||||
docs/source/en/model_doc/qdqbert.md
|
||||
docs/source/en/model_doc/qwen2.md
|
||||
docs/source/en/model_doc/qwen2_moe.md
|
||||
docs/source/en/model_doc/rag.md
|
||||
docs/source/en/model_doc/realm.md
|
||||
docs/source/en/model_doc/reformer.md
|
||||
docs/source/en/model_doc/regnet.md
|
||||
docs/source/en/model_doc/rembert.md
|
||||
docs/source/en/model_doc/resnet.md
|
||||
docs/source/en/model_doc/retribert.md
|
||||
docs/source/en/model_doc/roberta-prelayernorm.md
|
||||
docs/source/en/model_doc/roberta.md
|
||||
docs/source/en/model_doc/roc_bert.md
|
||||
docs/source/en/model_doc/roformer.md
|
||||
docs/source/en/model_doc/rwkv.md
|
||||
docs/source/en/model_doc/sam.md
|
||||
docs/source/en/model_doc/sam_hq.md
|
||||
docs/source/en/model_doc/segformer.md
|
||||
docs/source/en/model_doc/sew-d.md
|
||||
docs/source/en/model_doc/sew.md
|
||||
docs/source/en/model_doc/speech-encoder-decoder.md
|
||||
docs/source/en/model_doc/speech_to_text_2.md
|
||||
docs/source/en/model_doc/speecht5.md
|
||||
docs/source/en/model_doc/splinter.md
|
||||
docs/source/en/model_doc/squeezebert.md
|
||||
docs/source/en/model_doc/swiftformer.md
|
||||
docs/source/en/model_doc/swin.md
|
||||
docs/source/en/model_doc/swin2sr.md
|
||||
docs/source/en/model_doc/swinv2.md
|
||||
docs/source/en/model_doc/table-transformer.md
|
||||
docs/source/en/model_doc/tapas.md
|
||||
docs/source/en/model_doc/time_series_transformer.md
|
||||
docs/source/en/model_doc/timesformer.md
|
||||
docs/source/en/model_doc/trajectory_transformer.md
|
||||
docs/source/en/model_doc/transfo-xl.md
|
||||
docs/source/en/model_doc/trocr.md
|
||||
docs/source/en/model_doc/tvlt.md
|
||||
docs/source/en/model_doc/ul2.md
|
||||
docs/source/en/model_doc/umt5.md
|
||||
docs/source/en/model_doc/unispeech-sat.md
|
||||
docs/source/en/model_doc/unispeech.md
|
||||
docs/source/en/model_doc/upernet.md
|
||||
docs/source/en/model_doc/van.md
|
||||
docs/source/en/model_doc/videomae.md
|
||||
docs/source/en/model_doc/vilt.md
|
||||
docs/source/en/model_doc/vipllava.md
|
||||
docs/source/en/model_doc/vision-encoder-decoder.md
|
||||
docs/source/en/model_doc/vision-text-dual-encoder.md
|
||||
docs/source/en/model_doc/visual_bert.md
|
||||
docs/source/en/model_doc/vit.md
|
||||
docs/source/en/model_doc/vit_hybrid.md
|
||||
docs/source/en/model_doc/vit_mae.md
|
||||
docs/source/en/model_doc/vit_msn.md
|
||||
docs/source/en/model_doc/vivit.md
|
||||
docs/source/en/model_doc/wav2vec2-conformer.md
|
||||
docs/source/en/model_doc/wav2vec2.md
|
||||
docs/source/en/model_doc/wav2vec2_phoneme.md
|
||||
docs/source/en/model_doc/wavlm.md
|
||||
docs/source/en/model_doc/whisper.md
|
||||
docs/source/en/model_doc/xclip.md
|
||||
docs/source/en/model_doc/xglm.md
|
||||
docs/source/en/model_doc/xlm-prophetnet.md
|
||||
docs/source/en/model_doc/xlm-roberta-xl.md
|
||||
docs/source/en/model_doc/xlm-roberta.md
|
||||
docs/source/en/model_doc/xlm-v.md
|
||||
docs/source/en/model_doc/xlm.md
|
||||
docs/source/en/model_doc/xlnet.md
|
||||
docs/source/en/model_doc/xls_r.md
|
||||
docs/source/en/model_doc/xlsr_wav2vec2.md
|
||||
docs/source/en/model_doc/xmod.md
|
||||
docs/source/en/model_doc/yolos.md
|
||||
docs/source/en/model_doc/yoso.md
|
||||
docs/source/en/model_memory_anatomy.md
|
||||
docs/source/en/model_sharing.md
|
||||
docs/source/en/notebooks.md
|
||||
docs/source/en/pad_truncation.md
|
||||
docs/source/en/peft.md
|
||||
docs/source/en/perf_hardware.md
|
||||
docs/source/en/perf_infer_cpu.md
|
||||
docs/source/en/perf_infer_gpu_one.md
|
||||
docs/source/en/perf_torch_compile.md
|
||||
docs/source/en/perf_train_cpu.md
|
||||
docs/source/en/perf_train_cpu_many.md
|
||||
docs/source/en/perf_train_gpu_many.md
|
||||
docs/source/en/perf_train_gpu_one.md
|
||||
docs/source/en/perf_train_special.md
|
||||
docs/source/en/perplexity.md
|
||||
docs/source/en/philosophy.md
|
||||
docs/source/en/pipeline_webserver.md
|
||||
docs/source/en/pr_checks.md
|
||||
docs/source/en/run_scripts.md
|
||||
docs/source/en/serialization.md
|
||||
docs/source/en/tasks/asr.md
|
||||
docs/source/en/tasks/audio_classification.md
|
||||
docs/source/en/tasks/document_question_answering.md
|
||||
docs/source/en/tasks/idefics.md
|
||||
docs/source/en/tasks/image_captioning.md
|
||||
docs/source/en/tasks/image_classification.md
|
||||
docs/source/en/tasks/language_modeling.md
|
||||
docs/source/en/tasks/masked_language_modeling.md
|
||||
docs/source/en/tasks/monocular_depth_estimation.md
|
||||
docs/source/en/tasks/multiple_choice.md
|
||||
docs/source/en/tasks/object_detection.md
|
||||
docs/source/en/tasks/question_answering.md
|
||||
docs/source/en/tasks/semantic_segmentation.md
|
||||
docs/source/en/tasks/sequence_classification.md
|
||||
docs/source/en/tasks/summarization.md
|
||||
docs/source/en/tasks/text-to-speech.md
|
||||
docs/source/en/tasks/token_classification.md
|
||||
docs/source/en/tasks/translation.md
|
||||
docs/source/en/tasks/video_classification.md
|
||||
docs/source/en/tasks/visual_question_answering.md
|
||||
docs/source/en/tasks/zero_shot_image_classification.md
|
||||
docs/source/en/tasks/zero_shot_object_detection.md
|
||||
docs/source/en/tokenizer_summary.md
|
||||
docs/source/en/torchscript.md
|
||||
docs/source/en/training.md
|
||||
docs/source/en/troubleshooting.md
|
||||
src/transformers/activations.py
|
||||
src/transformers/audio_utils.py
|
||||
src/transformers/commands/add_new_model_like.py
|
||||
src/transformers/commands/download.py
|
||||
src/transformers/commands/env.py
|
||||
src/transformers/commands/run.py
|
||||
src/transformers/commands/serving.py
|
||||
src/transformers/commands/transformers_cli.py
|
||||
src/transformers/configuration_utils.py
|
||||
src/transformers/convert_slow_tokenizer.py
|
||||
src/transformers/convert_slow_tokenizers_checkpoints_to_fast.py
|
||||
src/transformers/data/data_collator.py
|
||||
src/transformers/data/datasets/glue.py
|
||||
src/transformers/data/datasets/language_modeling.py
|
||||
src/transformers/data/datasets/squad.py
|
||||
src/transformers/data/metrics/squad_metrics.py
|
||||
src/transformers/data/processors/glue.py
|
||||
src/transformers/data/processors/squad.py
|
||||
src/transformers/data/processors/utils.py
|
||||
src/transformers/data/processors/xnli.py
|
||||
src/transformers/debug_utils.py
|
||||
src/transformers/dependency_versions_check.py
|
||||
src/transformers/dependency_versions_table.py
|
||||
src/transformers/dynamic_module_utils.py
|
||||
src/transformers/feature_extraction_sequence_utils.py
|
||||
src/transformers/feature_extraction_utils.py
|
||||
src/transformers/file_utils.py
|
||||
src/transformers/hf_argparser.py
|
||||
src/transformers/hyperparameter_search.py
|
||||
src/transformers/image_processing_utils.py
|
||||
src/transformers/image_transforms.py
|
||||
src/transformers/image_utils.py
|
||||
src/transformers/integrations/bitsandbytes.py
|
||||
src/transformers/integrations/deepspeed.py
|
||||
src/transformers/integrations/integration_utils.py
|
||||
src/transformers/integrations/peft.py
|
||||
src/transformers/modelcard.py
|
||||
src/transformers/modeling_outputs.py
|
||||
src/transformers/modeling_utils.py
|
||||
src/transformers/models/align/configuration_align.py
|
||||
src/transformers/models/align/modeling_align.py
|
||||
src/transformers/models/altclip/configuration_altclip.py
|
||||
src/transformers/models/altclip/modeling_altclip.py
|
||||
src/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py
|
||||
src/transformers/models/audio_spectrogram_transformer/convert_audio_spectrogram_transformer_original_to_pytorch.py
|
||||
src/transformers/models/auto/auto_factory.py
|
||||
src/transformers/models/auto/configuration_auto.py
|
||||
src/transformers/models/auto/modeling_auto.py
|
||||
src/transformers/models/autoformer/configuration_autoformer.py
|
||||
src/transformers/models/autoformer/modeling_autoformer.py
|
||||
src/transformers/models/bark/convert_suno_to_hf.py
|
||||
src/transformers/models/bart/convert_bart_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/beit/convert_beit_unilm_to_pytorch.py
|
||||
src/transformers/models/bert_generation/modeling_bert_generation.py
|
||||
src/transformers/models/biogpt/configuration_biogpt.py
|
||||
src/transformers/models/biogpt/convert_biogpt_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/biogpt/modeling_biogpt.py
|
||||
src/transformers/models/bit/configuration_bit.py
|
||||
src/transformers/models/bit/convert_bit_to_pytorch.py
|
||||
src/transformers/models/bit/modeling_bit.py
|
||||
src/transformers/models/blenderbot/convert_blenderbot_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/blip/configuration_blip.py
|
||||
src/transformers/models/blip/convert_blip_original_pytorch_to_hf.py
|
||||
src/transformers/models/blip/modeling_blip_text.py
|
||||
src/transformers/models/blip_2/configuration_blip_2.py
|
||||
src/transformers/models/blip_2/convert_blip_2_original_to_pytorch.py
|
||||
src/transformers/models/blip_2/modeling_blip_2.py
|
||||
src/transformers/models/bloom/convert_bloom_original_checkpoint_to_pytorch.py
|
||||
src/transformers/models/bloom/modeling_bloom.py
|
||||
src/transformers/models/bridgetower/configuration_bridgetower.py
|
||||
src/transformers/models/bridgetower/modeling_bridgetower.py
|
||||
src/transformers/models/bros/convert_bros_to_pytorch.py
|
||||
src/transformers/models/camembert/modeling_camembert.py
|
||||
src/transformers/models/chinese_clip/configuration_chinese_clip.py
|
||||
src/transformers/models/chinese_clip/convert_chinese_clip_original_pytorch_to_hf.py
|
||||
src/transformers/models/chinese_clip/modeling_chinese_clip.py
|
||||
src/transformers/models/clap/convert_clap_original_pytorch_to_hf.py
|
||||
src/transformers/models/clip/convert_clip_original_pytorch_to_hf.py
|
||||
src/transformers/models/clip/modeling_clip.py
|
||||
src/transformers/models/clipseg/configuration_clipseg.py
|
||||
src/transformers/models/clipseg/convert_clipseg_original_pytorch_to_hf.py
|
||||
src/transformers/models/codegen/modeling_codegen.py
|
||||
src/transformers/models/conditional_detr/convert_conditional_detr_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/convbert/modeling_convbert.py
|
||||
src/transformers/models/convnext/convert_convnext_to_pytorch.py
|
||||
src/transformers/models/convnextv2/configuration_convnextv2.py
|
||||
src/transformers/models/convnextv2/convert_convnextv2_to_pytorch.py
|
||||
src/transformers/models/convnextv2/modeling_convnextv2.py
|
||||
src/transformers/models/cpmant/configuration_cpmant.py
|
||||
src/transformers/models/cpmant/modeling_cpmant.py
|
||||
src/transformers/models/cpmant/tokenization_cpmant.py
|
||||
src/transformers/models/cvt/convert_cvt_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/data2vec/convert_data2vec_audio_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/data2vec/convert_data2vec_text_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/data2vec/convert_data2vec_vision_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/data2vec/modeling_data2vec_text.py
|
||||
src/transformers/models/decision_transformer/modeling_decision_transformer.py
|
||||
src/transformers/models/deformable_detr/convert_deformable_detr_to_pytorch.py
|
||||
src/transformers/models/deit/convert_deit_timm_to_pytorch.py
|
||||
src/transformers/models/deprecated/bort/convert_bort_original_gluonnlp_checkpoint_to_pytorch.py
|
||||
src/transformers/models/deprecated/mctct/configuration_mctct.py
|
||||
src/transformers/models/deprecated/mctct/feature_extraction_mctct.py
|
||||
src/transformers/models/deprecated/mctct/modeling_mctct.py
|
||||
src/transformers/models/deprecated/mctct/processing_mctct.py
|
||||
src/transformers/models/deprecated/mmbt/configuration_mmbt.py
|
||||
src/transformers/models/deprecated/mmbt/modeling_mmbt.py
|
||||
src/transformers/models/deprecated/open_llama/configuration_open_llama.py
|
||||
src/transformers/models/deprecated/open_llama/modeling_open_llama.py
|
||||
src/transformers/models/deprecated/retribert/configuration_retribert.py
|
||||
src/transformers/models/deprecated/retribert/modeling_retribert.py
|
||||
src/transformers/models/deprecated/retribert/tokenization_retribert.py
|
||||
src/transformers/models/deprecated/retribert/tokenization_retribert_fast.py
|
||||
src/transformers/models/deprecated/tapex/tokenization_tapex.py
|
||||
src/transformers/models/deprecated/trajectory_transformer/configuration_trajectory_transformer.py
|
||||
src/transformers/models/deprecated/trajectory_transformer/convert_trajectory_transformer_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/deprecated/trajectory_transformer/modeling_trajectory_transformer.py
|
||||
src/transformers/models/deprecated/transfo_xl/modeling_transfo_xl.py
|
||||
src/transformers/models/deprecated/transfo_xl/modeling_transfo_xl_utilities.py
|
||||
src/transformers/models/deprecated/van/configuration_van.py
|
||||
src/transformers/models/deprecated/van/convert_van_to_pytorch.py
|
||||
src/transformers/models/deprecated/van/modeling_van.py
|
||||
src/transformers/models/detr/convert_detr_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/detr/convert_detr_to_pytorch.py
|
||||
src/transformers/models/dialogpt/convert_dialogpt_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/dinov2/configuration_dinov2.py
|
||||
src/transformers/models/dinov2/convert_dinov2_to_hf.py
|
||||
src/transformers/models/dinov2/modeling_dinov2.py
|
||||
src/transformers/models/distilbert/modeling_distilbert.py
|
||||
src/transformers/models/dit/convert_dit_unilm_to_pytorch.py
|
||||
src/transformers/models/donut/configuration_donut_swin.py
|
||||
src/transformers/models/donut/convert_donut_to_pytorch.py
|
||||
src/transformers/models/donut/modeling_donut_swin.py
|
||||
src/transformers/models/dpr/convert_dpr_original_checkpoint_to_pytorch.py
|
||||
src/transformers/models/dpr/modeling_dpr.py
|
||||
src/transformers/models/dpt/configuration_dpt.py
|
||||
src/transformers/models/dpt/convert_dpt_hybrid_to_pytorch.py
|
||||
src/transformers/models/dpt/convert_dpt_to_pytorch.py
|
||||
src/transformers/models/efficientnet/configuration_efficientnet.py
|
||||
src/transformers/models/efficientnet/convert_efficientnet_to_pytorch.py
|
||||
src/transformers/models/efficientnet/modeling_efficientnet.py
|
||||
src/transformers/models/encodec/configuration_encodec.py
|
||||
src/transformers/models/encodec/convert_encodec_checkpoint_to_pytorch.py
|
||||
src/transformers/models/encoder_decoder/modeling_encoder_decoder.py
|
||||
src/transformers/models/ernie/modeling_ernie.py
|
||||
src/transformers/models/esm/configuration_esm.py
|
||||
src/transformers/models/esm/convert_esm.py
|
||||
src/transformers/models/esm/modeling_esm.py
|
||||
src/transformers/models/esm/modeling_esmfold.py
|
||||
src/transformers/models/esm/openfold_utils/chunk_utils.py
|
||||
src/transformers/models/esm/openfold_utils/data_transforms.py
|
||||
src/transformers/models/esm/openfold_utils/feats.py
|
||||
src/transformers/models/esm/openfold_utils/loss.py
|
||||
src/transformers/models/esm/openfold_utils/protein.py
|
||||
src/transformers/models/esm/openfold_utils/residue_constants.py
|
||||
src/transformers/models/esm/openfold_utils/rigid_utils.py
|
||||
src/transformers/models/esm/openfold_utils/tensor_utils.py
|
||||
src/transformers/models/falcon/configuration_falcon.py
|
||||
src/transformers/models/falcon/modeling_falcon.py
|
||||
src/transformers/models/flaubert/configuration_flaubert.py
|
||||
src/transformers/models/flaubert/modeling_flaubert.py
|
||||
src/transformers/models/flava/convert_dalle_to_flava_codebook.py
|
||||
src/transformers/models/flava/convert_flava_original_pytorch_to_hf.py
|
||||
src/transformers/models/flava/modeling_flava.py
|
||||
src/transformers/models/fnet/modeling_fnet.py
|
||||
src/transformers/models/focalnet/configuration_focalnet.py
|
||||
src/transformers/models/focalnet/convert_focalnet_to_hf_format.py
|
||||
src/transformers/models/focalnet/modeling_focalnet.py
|
||||
src/transformers/models/fsmt/convert_fsmt_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/fsmt/modeling_fsmt.py
|
||||
src/transformers/models/funnel/configuration_funnel.py
|
||||
src/transformers/models/funnel/modeling_funnel.py
|
||||
src/transformers/models/fuyu/convert_fuyu_model_weights_to_hf.py
|
||||
src/transformers/models/gemma/configuration_gemma.py
|
||||
src/transformers/models/gemma/convert_gemma_weights_to_hf.py
|
||||
src/transformers/models/gemma/modeling_gemma.py
|
||||
src/transformers/models/git/configuration_git.py
|
||||
src/transformers/models/git/convert_git_to_pytorch.py
|
||||
src/transformers/models/glpn/configuration_glpn.py
|
||||
src/transformers/models/glpn/convert_glpn_to_pytorch.py
|
||||
src/transformers/models/gpt2/CONVERSION.md
|
||||
src/transformers/models/gpt_bigcode/configuration_gpt_bigcode.py
|
||||
src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
|
||||
src/transformers/models/gpt_neo/modeling_gpt_neo.py
|
||||
src/transformers/models/gpt_neox/modeling_gpt_neox.py
|
||||
src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
|
||||
src/transformers/models/gpt_sw3/convert_megatron_to_pytorch.py
|
||||
src/transformers/models/gptj/configuration_gptj.py
|
||||
src/transformers/models/groupvit/configuration_groupvit.py
|
||||
src/transformers/models/groupvit/convert_groupvit_nvlab_to_hf.py
|
||||
src/transformers/models/hubert/configuration_hubert.py
|
||||
src/transformers/models/hubert/convert_distilhubert_original_s3prl_checkpoint_to_pytorch.py
|
||||
src/transformers/models/hubert/convert_hubert_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/hubert/convert_hubert_original_s3prl_checkpoint_to_pytorch.py
|
||||
src/transformers/models/ibert/configuration_ibert.py
|
||||
src/transformers/models/ibert/modeling_ibert.py
|
||||
src/transformers/models/ibert/quant_modules.py
|
||||
src/transformers/models/idefics/configuration_idefics.py
|
||||
src/transformers/models/idefics/image_processing_idefics.py
|
||||
src/transformers/models/idefics/modeling_idefics.py
|
||||
src/transformers/models/idefics/perceiver.py
|
||||
src/transformers/models/idefics/processing_idefics.py
|
||||
src/transformers/models/idefics/vision.py
|
||||
src/transformers/models/informer/configuration_informer.py
|
||||
src/transformers/models/informer/modeling_informer.py
|
||||
src/transformers/models/instructblip/configuration_instructblip.py
|
||||
src/transformers/models/instructblip/convert_instructblip_original_to_pytorch.py
|
||||
src/transformers/models/instructblip/modeling_instructblip.py
|
||||
src/transformers/models/instructblip/processing_instructblip.py
|
||||
src/transformers/models/jamba/configuration_jamba.py
|
||||
src/transformers/models/jamba/modeling_jamba.py
|
||||
src/transformers/models/kosmos2/convert_kosmos2_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/led/configuration_led.py
|
||||
src/transformers/models/led/modeling_led.py
|
||||
src/transformers/models/levit/convert_levit_timm_to_pytorch.py
|
||||
src/transformers/models/levit/modeling_levit.py
|
||||
src/transformers/models/lilt/configuration_lilt.py
|
||||
src/transformers/models/llama/configuration_llama.py
|
||||
src/transformers/models/llama/convert_llama_weights_to_hf.py
|
||||
src/transformers/models/llama/modeling_llama.py
|
||||
src/transformers/models/llava/configuration_llava.py
|
||||
src/transformers/models/llava/modeling_llava.py
|
||||
src/transformers/models/llava_next/configuration_llava_next.py
|
||||
src/transformers/models/llava_next/modeling_llava_next.py
|
||||
src/transformers/models/longformer/configuration_longformer.py
|
||||
src/transformers/models/longformer/convert_longformer_original_pytorch_lightning_to_pytorch.py
|
||||
src/transformers/models/longt5/configuration_longt5.py
|
||||
src/transformers/models/luke/configuration_luke.py
|
||||
src/transformers/models/luke/convert_luke_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/luke/modeling_luke.py
|
||||
src/transformers/models/lxmert/configuration_lxmert.py
|
||||
src/transformers/models/lxmert/modeling_lxmert.py
|
||||
src/transformers/models/m2m_100/convert_m2m100_original_checkpoint_to_pytorch.py
|
||||
src/transformers/models/m2m_100/modeling_m2m_100.py
|
||||
src/transformers/models/marian/configuration_marian.py
|
||||
src/transformers/models/marian/convert_marian_tatoeba_to_pytorch.py
|
||||
src/transformers/models/marian/convert_marian_to_pytorch.py
|
||||
src/transformers/models/markuplm/configuration_markuplm.py
|
||||
src/transformers/models/markuplm/feature_extraction_markuplm.py
|
||||
src/transformers/models/mask2former/convert_mask2former_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/maskformer/configuration_maskformer_swin.py
|
||||
src/transformers/models/maskformer/convert_maskformer_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/maskformer/convert_maskformer_resnet_to_pytorch.py
|
||||
src/transformers/models/maskformer/convert_maskformer_swin_to_pytorch.py
|
||||
src/transformers/models/maskformer/modeling_maskformer_swin.py
|
||||
src/transformers/models/mbart/convert_mbart_original_checkpoint_to_pytorch.py
|
||||
src/transformers/models/megatron_bert/convert_megatron_bert_checkpoint.py
|
||||
src/transformers/models/megatron_bert/modeling_megatron_bert.py
|
||||
src/transformers/models/megatron_gpt2/checkpoint_reshaping_and_interoperability.py
|
||||
src/transformers/models/megatron_gpt2/convert_megatron_gpt2_checkpoint.py
|
||||
src/transformers/models/mgp_str/configuration_mgp_str.py
|
||||
src/transformers/models/mgp_str/modeling_mgp_str.py
|
||||
src/transformers/models/mistral/configuration_mistral.py
|
||||
src/transformers/models/mistral/modeling_mistral.py
|
||||
src/transformers/models/mixtral/configuration_mixtral.py
|
||||
src/transformers/models/mixtral/modeling_mixtral.py
|
||||
src/transformers/models/mluke/convert_mluke_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/mobilenet_v1/configuration_mobilenet_v1.py
|
||||
src/transformers/models/mobilenet_v2/configuration_mobilenet_v2.py
|
||||
src/transformers/models/mobilevit/configuration_mobilevit.py
|
||||
src/transformers/models/mobilevit/convert_mlcvnets_to_pytorch.py
|
||||
src/transformers/models/mobilevitv2/convert_mlcvnets_to_pytorch.py
|
||||
src/transformers/models/mpnet/configuration_mpnet.py
|
||||
src/transformers/models/mpnet/modeling_mpnet.py
|
||||
src/transformers/models/mpt/configuration_mpt.py
|
||||
src/transformers/models/mpt/modeling_mpt.py
|
||||
src/transformers/models/mra/configuration_mra.py
|
||||
src/transformers/models/mra/convert_mra_pytorch_to_pytorch.py
|
||||
src/transformers/models/mra/modeling_mra.py
|
||||
src/transformers/models/mt5/configuration_mt5.py
|
||||
src/transformers/models/mt5/modeling_mt5.py
|
||||
src/transformers/models/musicgen/convert_musicgen_transformers.py
|
||||
src/transformers/models/musicgen_melody/convert_musicgen_melody_transformers.py
|
||||
src/transformers/models/mvp/modeling_mvp.py
|
||||
src/transformers/models/nllb_moe/configuration_nllb_moe.py
|
||||
src/transformers/models/nllb_moe/convert_nllb_moe_sharded_original_checkpoint_to_pytorch.py
|
||||
src/transformers/models/nllb_moe/modeling_nllb_moe.py
|
||||
src/transformers/models/nougat/convert_nougat_to_hf.py
|
||||
src/transformers/models/nystromformer/configuration_nystromformer.py
|
||||
src/transformers/models/nystromformer/convert_nystromformer_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/nystromformer/modeling_nystromformer.py
|
||||
src/transformers/models/oneformer/convert_to_hf_oneformer.py
|
||||
src/transformers/models/openai/modeling_openai.py
|
||||
src/transformers/models/opt/convert_opt_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/owlvit/configuration_owlvit.py
|
||||
src/transformers/models/pegasus_x/modeling_pegasus_x.py
|
||||
src/transformers/models/perceiver/configuration_perceiver.py
|
||||
src/transformers/models/perceiver/convert_perceiver_haiku_to_pytorch.py
|
||||
src/transformers/models/persimmon/convert_persimmon_weights_to_hf.py
|
||||
src/transformers/models/persimmon/modeling_persimmon.py
|
||||
src/transformers/models/pix2struct/configuration_pix2struct.py
|
||||
src/transformers/models/pix2struct/convert_pix2struct_original_pytorch_to_hf.py
|
||||
src/transformers/models/pix2struct/image_processing_pix2struct.py
|
||||
src/transformers/models/pix2struct/processing_pix2struct.py
|
||||
src/transformers/models/plbart/convert_plbart_original_checkpoint_to_torch.py
|
||||
src/transformers/models/poolformer/convert_poolformer_original_to_pytorch.py
|
||||
src/transformers/models/pop2piano/convert_pop2piano_weights_to_hf.py
|
||||
src/transformers/models/pop2piano/feature_extraction_pop2piano.py
|
||||
src/transformers/models/pop2piano/processing_pop2piano.py
|
||||
src/transformers/models/pop2piano/tokenization_pop2piano.py
|
||||
src/transformers/models/prophetnet/configuration_prophetnet.py
|
||||
src/transformers/models/prophetnet/convert_prophetnet_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/prophetnet/modeling_prophetnet.py
|
||||
src/transformers/models/pvt/configuration_pvt.py
|
||||
src/transformers/models/pvt/convert_pvt_to_pytorch.py
|
||||
src/transformers/models/pvt/image_processing_pvt.py
|
||||
src/transformers/models/pvt/modeling_pvt.py
|
||||
src/transformers/models/qwen2/configuration_qwen2.py
|
||||
src/transformers/models/qwen2/modeling_qwen2.py
|
||||
src/transformers/models/qwen2/tokenization_qwen2.py
|
||||
src/transformers/models/qwen2/tokenization_qwen2_fast.py
|
||||
src/transformers/models/qwen2_moe/configuration_qwen2_moe.py
|
||||
src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
|
||||
src/transformers/models/rag/configuration_rag.py
|
||||
src/transformers/models/rag/modeling_rag.py
|
||||
src/transformers/models/rag/retrieval_rag.py
|
||||
src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
|
||||
src/transformers/models/reformer/convert_reformer_trax_checkpoint_to_pytorch.py
|
||||
src/transformers/models/regnet/configuration_regnet.py
|
||||
src/transformers/models/regnet/convert_regnet_seer_10b_to_pytorch.py
|
||||
src/transformers/models/regnet/convert_regnet_to_pytorch.py
|
||||
src/transformers/models/rembert/configuration_rembert.py
|
||||
src/transformers/models/rembert/modeling_rembert.py
|
||||
src/transformers/models/resnet/convert_resnet_to_pytorch.py
|
||||
src/transformers/models/roberta/convert_roberta_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/roberta_prelayernorm/convert_roberta_prelayernorm_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/roc_bert/configuration_roc_bert.py
|
||||
src/transformers/models/roformer/modeling_roformer.py
|
||||
src/transformers/models/rwkv/configuration_rwkv.py
|
||||
src/transformers/models/rwkv/convert_rwkv_checkpoint_to_hf.py
|
||||
src/transformers/models/rwkv/modeling_rwkv.py
|
||||
src/transformers/models/sam/configuration_sam.py
|
||||
src/transformers/models/sam/convert_sam_to_hf.py
|
||||
src/transformers/models/sam/image_processing_sam.py
|
||||
src/transformers/models/sam/modeling_sam.py
|
||||
src/transformers/models/sam/processing_sam.py
|
||||
src/transformers/models/seamless_m4t/convert_fairseq2_to_hf.py
|
||||
src/transformers/models/seamless_m4t_v2/convert_fairseq2_to_hf.py
|
||||
src/transformers/models/segformer/configuration_segformer.py
|
||||
src/transformers/models/segformer/convert_segformer_original_to_pytorch.py
|
||||
src/transformers/models/sew/convert_sew_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/sew_d/convert_sew_d_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py
|
||||
src/transformers/models/speech_encoder_decoder/convert_mbart_wav2vec2_seq2seq_original_to_pytorch.py
|
||||
src/transformers/models/speech_encoder_decoder/convert_speech_to_text_wav2vec2_seq2seq_original_to_pytorch.py
|
||||
src/transformers/models/speecht5/configuration_speecht5.py
|
||||
src/transformers/models/speecht5/convert_hifigan.py
|
||||
src/transformers/models/speecht5/convert_speecht5_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/speecht5/number_normalizer.py
|
||||
src/transformers/models/splinter/configuration_splinter.py
|
||||
src/transformers/models/splinter/modeling_splinter.py
|
||||
src/transformers/models/squeezebert/modeling_squeezebert.py
|
||||
src/transformers/models/stablelm/modeling_stablelm.py
|
||||
src/transformers/models/starcoder2/modeling_starcoder2.py
|
||||
src/transformers/models/swiftformer/configuration_swiftformer.py
|
||||
src/transformers/models/swiftformer/convert_swiftformer_original_to_hf.py
|
||||
src/transformers/models/swiftformer/modeling_swiftformer.py
|
||||
src/transformers/models/swin/convert_swin_simmim_to_pytorch.py
|
||||
src/transformers/models/swin/convert_swin_timm_to_pytorch.py
|
||||
src/transformers/models/swin2sr/configuration_swin2sr.py
|
||||
src/transformers/models/swin2sr/convert_swin2sr_original_to_pytorch.py
|
||||
src/transformers/models/swinv2/convert_swinv2_timm_to_pytorch.py
|
||||
src/transformers/models/swinv2/modeling_swinv2.py
|
||||
src/transformers/models/switch_transformers/configuration_switch_transformers.py
|
||||
src/transformers/models/switch_transformers/convert_big_switch.py
|
||||
src/transformers/models/switch_transformers/modeling_switch_transformers.py
|
||||
src/transformers/models/t5/configuration_t5.py
|
||||
src/transformers/models/t5/convert_t5x_checkpoint_to_pytorch.py
|
||||
src/transformers/models/t5/modeling_t5.py
|
||||
src/transformers/models/table_transformer/configuration_table_transformer.py
|
||||
src/transformers/models/table_transformer/convert_table_transformer_to_hf.py
|
||||
src/transformers/models/table_transformer/convert_table_transformer_to_hf_no_timm.py
|
||||
src/transformers/models/tapas/configuration_tapas.py
|
||||
src/transformers/models/tapas/modeling_tapas.py
|
||||
src/transformers/models/timesformer/convert_timesformer_to_pytorch.py
|
||||
src/transformers/models/timm_backbone/configuration_timm_backbone.py
|
||||
src/transformers/models/timm_backbone/modeling_timm_backbone.py
|
||||
src/transformers/models/trocr/convert_trocr_unilm_to_pytorch.py
|
||||
src/transformers/models/umt5/configuration_umt5.py
|
||||
src/transformers/models/umt5/convert_umt5_checkpoint_to_pytorch.py
|
||||
src/transformers/models/umt5/modeling_umt5.py
|
||||
src/transformers/models/unispeech/convert_unispeech_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/unispeech_sat/configuration_unispeech_sat.py
|
||||
src/transformers/models/unispeech_sat/convert_unispeech_original_s3prl_checkpoint_to_pytorch.py
|
||||
src/transformers/models/unispeech_sat/convert_unispeech_sat_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/upernet/configuration_upernet.py
|
||||
src/transformers/models/upernet/convert_convnext_upernet_to_pytorch.py
|
||||
src/transformers/models/upernet/convert_swin_upernet_to_pytorch.py
|
||||
src/transformers/models/videomae/configuration_videomae.py
|
||||
src/transformers/models/videomae/convert_videomae_to_pytorch.py
|
||||
src/transformers/models/vilt/configuration_vilt.py
|
||||
src/transformers/models/vilt/convert_vilt_original_to_pytorch.py
|
||||
src/transformers/models/vipllava/configuration_vipllava.py
|
||||
src/transformers/models/vipllava/modeling_vipllava.py
|
||||
src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py
|
||||
src/transformers/models/visual_bert/convert_visual_bert_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/visual_bert/modeling_visual_bert.py
|
||||
src/transformers/models/vit/convert_dino_to_pytorch.py
|
||||
src/transformers/models/vit/convert_vit_timm_to_pytorch.py
|
||||
src/transformers/models/vit_mae/convert_vit_mae_to_pytorch.py
|
||||
src/transformers/models/vit_msn/configuration_vit_msn.py
|
||||
src/transformers/models/vit_msn/convert_msn_to_pytorch.py
|
||||
src/transformers/models/vivit/configuration_vivit.py
|
||||
src/transformers/models/vivit/image_processing_vivit.py
|
||||
src/transformers/models/vivit/modeling_vivit.py
|
||||
src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/wav2vec2/convert_wav2vec2_original_s3prl_checkpoint_to_pytorch.py
|
||||
src/transformers/models/wav2vec2_bert/convert_wav2vec2_seamless_checkpoint.py
|
||||
src/transformers/models/wav2vec2_conformer/convert_wav2vec2_conformer_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/wavlm/convert_wavlm_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/wavlm/convert_wavlm_original_s3prl_checkpoint_to_pytorch.py
|
||||
src/transformers/models/whisper/convert_openai_to_hf.py
|
||||
src/transformers/models/whisper/english_normalizer.py
|
||||
src/transformers/models/x_clip/configuration_x_clip.py
|
||||
src/transformers/models/x_clip/convert_x_clip_original_pytorch_to_hf.py
|
||||
src/transformers/models/xglm/configuration_xglm.py
|
||||
src/transformers/models/xglm/convert_xglm_original_ckpt_to_trfms.py
|
||||
src/transformers/models/xglm/modeling_xglm.py
|
||||
src/transformers/models/xlm/convert_xlm_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/xlm/modeling_xlm.py
|
||||
src/transformers/models/xlm_roberta/modeling_xlm_roberta.py
|
||||
src/transformers/models/xlm_roberta_xl/convert_xlm_roberta_xl_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py
|
||||
src/transformers/models/xlnet/modeling_xlnet.py
|
||||
src/transformers/models/xmod/convert_xmod_original_pytorch_checkpoint_to_pytorch.py
|
||||
src/transformers/models/yolos/convert_yolos_to_pytorch.py
|
||||
src/transformers/models/yoso/convert_yoso_pytorch_to_pytorch.py
|
||||
src/transformers/models/yoso/modeling_yoso.py
|
||||
src/transformers/models/zamba/configuration_zamba.py
|
||||
src/transformers/models/zamba/modeling_zamba.py
|
||||
src/transformers/onnx/__main__.py
|
||||
src/transformers/onnx/config.py
|
||||
src/transformers/onnx/convert.py
|
||||
src/transformers/onnx/features.py
|
||||
src/transformers/onnx/utils.py
|
||||
src/transformers/optimization.py
|
||||
src/transformers/pipelines/audio_classification.py
|
||||
src/transformers/pipelines/audio_utils.py
|
||||
src/transformers/pipelines/automatic_speech_recognition.py
|
||||
src/transformers/pipelines/base.py
|
||||
src/transformers/pipelines/depth_estimation.py
|
||||
src/transformers/pipelines/document_question_answering.py
|
||||
src/transformers/pipelines/feature_extraction.py
|
||||
src/transformers/pipelines/fill_mask.py
|
||||
src/transformers/pipelines/image_classification.py
|
||||
src/transformers/pipelines/image_segmentation.py
|
||||
src/transformers/pipelines/image_to_text.py
|
||||
src/transformers/pipelines/mask_generation.py
|
||||
src/transformers/pipelines/object_detection.py
|
||||
src/transformers/pipelines/pt_utils.py
|
||||
src/transformers/pipelines/question_answering.py
|
||||
src/transformers/pipelines/table_question_answering.py
|
||||
src/transformers/pipelines/text_classification.py
|
||||
src/transformers/pipelines/token_classification.py
|
||||
src/transformers/pipelines/video_classification.py
|
||||
src/transformers/pipelines/visual_question_answering.py
|
||||
src/transformers/pipelines/zero_shot_audio_classification.py
|
||||
src/transformers/pipelines/zero_shot_classification.py
|
||||
src/transformers/pipelines/zero_shot_image_classification.py
|
||||
src/transformers/pipelines/zero_shot_object_detection.py
|
||||
src/transformers/processing_utils.py
|
||||
src/transformers/pytorch_utils.py
|
||||
src/transformers/quantizers/auto.py
|
||||
src/transformers/quantizers/base.py
|
||||
src/transformers/quantizers/quantizer_awq.py
|
||||
src/transformers/quantizers/quantizer_bnb_4bit.py
|
||||
src/transformers/quantizers/quantizer_bnb_8bit.py
|
||||
src/transformers/quantizers/quantizer_gptq.py
|
||||
src/transformers/quantizers/quantizers_utils.py
|
||||
src/transformers/sagemaker/trainer_sm.py
|
||||
src/transformers/sagemaker/training_args_sm.py
|
||||
src/transformers/testing_utils.py
|
||||
src/transformers/time_series_utils.py
|
||||
src/transformers/tokenization_utils.py
|
||||
src/transformers/tokenization_utils_base.py
|
||||
src/transformers/tokenization_utils_fast.py
|
||||
src/transformers/trainer.py
|
||||
src/transformers/trainer_callback.py
|
||||
src/transformers/trainer_pt_utils.py
|
||||
src/transformers/trainer_seq2seq.py
|
||||
src/transformers/trainer_utils.py
|
||||
src/transformers/training_args.py
|
||||
src/transformers/training_args_seq2seq.py
|
||||
src/transformers/utils/backbone_utils.py
|
||||
src/transformers/utils/bitsandbytes.py
|
||||
src/transformers/utils/constants.py
|
||||
src/transformers/utils/doc.py
|
||||
src/transformers/utils/dummy_detectron2_objects.py
|
||||
src/transformers/utils/dummy_essentia_and_librosa_and_pretty_midi_and_scipy_and_torch_objects.py
|
||||
src/transformers/utils/dummy_music_objects.py
|
||||
src/transformers/utils/dummy_pt_objects.py
|
||||
src/transformers/utils/dummy_sentencepiece_and_tokenizers_objects.py
|
||||
src/transformers/utils/dummy_sentencepiece_objects.py
|
||||
src/transformers/utils/dummy_speech_objects.py
|
||||
src/transformers/utils/dummy_tokenizers_objects.py
|
||||
src/transformers/utils/dummy_vision_objects.py
|
||||
src/transformers/utils/fx.py
|
||||
src/transformers/utils/generic.py
|
||||
src/transformers/utils/hp_naming.py
|
||||
src/transformers/utils/hub.py
|
||||
src/transformers/utils/import_utils.py
|
||||
src/transformers/utils/logging.py
|
||||
src/transformers/utils/model_parallel_utils.py
|
||||
src/transformers/utils/notebook.py
|
||||
src/transformers/utils/peft_utils.py
|
||||
src/transformers/utils/quantization_config.py
|
||||
src/transformers/utils/sentencepiece_model_pb2.py
|
||||
src/transformers/utils/sentencepiece_model_pb2_new.py
|
||||
src/transformers/utils/versions.py
|
||||
1605
transformers/utils/notification_service.py
Normal file
1605
transformers/utils/notification_service.py
Normal file
File diff suppressed because it is too large
Load Diff
384
transformers/utils/notification_service_doc_tests.py
Normal file
384
transformers/utils/notification_service_doc_tests.py
Normal file
@@ -0,0 +1,384 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
|
||||
from get_ci_error_statistics import get_jobs
|
||||
from slack_sdk import WebClient
|
||||
|
||||
|
||||
# Slack client used by every chat_postMessage call below; requires the
# CI_SLACK_BOT_TOKEN environment variable to be set (CI secret).
client = WebClient(token=os.environ["CI_SLACK_BOT_TOKEN"])
|
||||
|
||||
|
||||
def handle_test_results(test_results):
    """Parse a pytest summary line into ``(n_failed, n_passed, time_spent)``.

    `test_results` is the final pytest status line, e.g.
    ``"= 2 failed, 3 passed in 0:00:05 ="``.
    """
    tokens = test_results.split(" ")

    # When the output is short enough, pytest wraps it in '=' signs ("== OUTPUT ==");
    # when it is too long, those signs are absent and the duration is the last token.
    duration = tokens[-1] if "=" not in tokens[-1] else tokens[-2]

    n_failed, n_passed = 0, 0
    for idx, token in enumerate(tokens):
        # Counts precede their labels, e.g. "2 failed," / "3 passed".
        if "failed" in token:
            n_failed += int(tokens[idx - 1])
        if "passed" in token:
            n_passed += int(tokens[idx - 1])

    return n_failed, n_passed, duration
|
||||
|
||||
|
||||
def extract_first_line_failure(failures_short_lines):
    """Map each failing doctest file to the first line of its error message.

    Scans pytest's "failures short" report: a header line containing
    ``_ [doctest]`` introduces a failing file (its path is the third
    space-separated token); the first subsequent line that does not start
    with a line number is recorded as that file's error summary.
    """
    failures = {}
    current_file = None
    awaiting_error = False

    for line in failures_short_lines.split("\n"):
        if re.search(r"_ \[doctest\]", line):
            # Header of a new failure section; remember which file it belongs to.
            awaiting_error = True
            current_file = line.split(" ")[2]
        elif awaiting_error and not line.split(" ")[0].isdigit():
            # Lines starting with a digit are quoted source lines; skip those.
            failures[current_file] = line
            awaiting_error = False

    return failures
|
||||
|
||||
|
||||
class Message:
    """Builds the Slack report for a doc-test CI run and posts it, plus per-job replies.

    `doc_test_results` maps a job name to a dict with keys `n_success`,
    `n_failures`, `time_spent`, `failed` (list of failed test names),
    `failures` (test name -> first error line) and `job_link`.
    """

    def __init__(self, title: str, doc_test_results: dict):
        self.title = title

        # Aggregate pass/fail counts over every job.
        self.n_success = sum(job_result["n_success"] for job_result in doc_test_results.values())
        self.n_failures = sum(job_result["n_failures"] for job_result in doc_test_results.values())
        self.n_tests = self.n_success + self.n_failures

        # Failures and success of the modeling tests
        self.doc_test_results = doc_test_results

        # Fix: initialize so `post_reply` raises the intended ValueError instead of
        # an AttributeError when it is called before `post`.
        self.thread_ts = None

    @property
    def time(self) -> str:
        """Total wall-clock time of all jobs, formatted as "XhYmZs"."""
        all_results = [*self.doc_test_results.values()]
        time_spent = [r["time_spent"].split(", ")[0] for r in all_results if len(r["time_spent"])]
        total_secs = 0

        for time in time_spent:
            time_parts = time.split(":")

            # Time can be formatted as xx:xx:xx, as .xx, or as x.xx if the time spent was less than a minute.
            if len(time_parts) == 1:
                time_parts = [0, 0, time_parts[0]]

            hours, minutes, seconds = int(time_parts[0]), int(time_parts[1]), float(time_parts[2])
            total_secs += hours * 3600 + minutes * 60 + seconds

        hours, minutes, seconds = total_secs // 3600, (total_secs % 3600) // 60, total_secs % 60
        return f"{int(hours)}h{int(minutes)}m{int(seconds)}s"

    @property
    def header(self) -> dict:
        """Slack header block carrying the report title."""
        return {"type": "header", "text": {"type": "plain_text", "text": self.title}}

    @property
    def no_failures(self) -> dict:
        """Slack section block used when every test passed."""
        return {
            "type": "section",
            "text": {
                "type": "plain_text",
                "text": f"🌞 There were no failures: all {self.n_tests} tests passed. The suite ran in {self.time}.",
                "emoji": True,
            },
            "accessory": {
                "type": "button",
                "text": {"type": "plain_text", "text": "Check Action results", "emoji": True},
                "url": f"https://github.com/huggingface/transformers/actions/runs/{os.environ['GITHUB_RUN_ID']}",
            },
        }

    @property
    def failures(self) -> dict:
        """Slack section block summarizing how many tests failed."""
        return {
            "type": "section",
            "text": {
                "type": "plain_text",
                "text": (
                    f"There were {self.n_failures} failures, out of {self.n_tests} tests.\nThe suite ran in"
                    f" {self.time}."
                ),
                "emoji": True,
            },
            "accessory": {
                "type": "button",
                "text": {"type": "plain_text", "text": "Check Action results", "emoji": True},
                "url": f"https://github.com/huggingface/transformers/actions/runs/{os.environ['GITHUB_RUN_ID']}",
            },
        }

    @property
    def category_failures(self) -> list[dict]:
        """One Slack section block per category, listing its failed tests.

        Each block is truncated so its text stays under Slack's 3000-character
        limit for a section.
        """
        failure_blocks = []

        MAX_ERROR_TEXT = 3000 - len("The following examples had failures:\n\n\n\n") - len("[Truncated]\n")
        line_length = 40
        # Fix: read from the instance instead of the module-level `doc_test_results`
        # global, so the property also works outside the __main__ script.
        category_failures = {k: v["failed"] for k, v in self.doc_test_results.items() if isinstance(v, dict)}

        def single_category_failures(category, failures):
            # Render one category's failed test names, truncating when too long.
            text = ""
            if len(failures) == 0:
                return ""
            text += f"*{category} failures*:".ljust(line_length // 2).rjust(line_length // 2) + "\n"

            for idx, failure in enumerate(failures):
                new_text = text + f"`{failure}`\n"
                if len(new_text) > MAX_ERROR_TEXT:
                    text = text + "[Truncated]\n"
                    break
                text = new_text

            return text

        for category, failures in category_failures.items():
            report = single_category_failures(category, failures)
            if len(report) == 0:
                continue
            block = {
                "type": "section",
                "text": {
                    "type": "mrkdwn",
                    "text": f"The following examples had failures:\n\n\n{report}\n",
                },
            }
            failure_blocks.append(block)

        return failure_blocks

    @property
    def payload(self) -> str:
        """JSON-encoded list of Slack blocks for the top-level message."""
        blocks = [self.header]

        if self.n_failures > 0:
            blocks.append(self.failures)
            blocks.extend(self.category_failures)
        else:
            blocks.append(self.no_failures)

        return json.dumps(blocks)

    @staticmethod
    def error_out():
        """Post a generic "tests could not run" message to the report channel."""
        payload = [
            {
                "type": "section",
                "text": {
                    "type": "plain_text",
                    "text": "There was an issue running the tests.",
                },
                "accessory": {
                    "type": "button",
                    "text": {"type": "plain_text", "text": "Check Action results", "emoji": True},
                    "url": f"https://github.com/huggingface/transformers/actions/runs/{os.environ['GITHUB_RUN_ID']}",
                },
            }
        ]

        print("Sending the following payload")
        # Fix: `payload` is already a Python list; the previous `json.loads(payload)`
        # raised a TypeError (json.loads expects a string).
        print(json.dumps({"blocks": payload}))

        client.chat_postMessage(
            channel=SLACK_REPORT_CHANNEL_ID,
            text="There was an issue running the tests.",
            blocks=payload,
        )

    def post(self):
        """Post the top-level report and remember its thread timestamp for replies."""
        print("Sending the following payload")
        print(json.dumps({"blocks": json.loads(self.payload)}))

        text = f"{self.n_failures} failures out of {self.n_tests} tests," if self.n_failures else "All tests passed."

        self.thread_ts = client.chat_postMessage(
            channel=SLACK_REPORT_CHANNEL_ID,
            blocks=self.payload,
            text=text,
        )

    def get_reply_blocks(self, job_name, job_link, failures, text):
        """Build the Slack blocks for one job's reply message.

        `failures` maps test names to their first error line; the rendered list
        is truncated to respect Slack's 3000-character section limit.
        """
        # `text` must be less than 3001 characters in Slack SDK
        # keep some room for adding "[Truncated]" when necessary
        MAX_ERROR_TEXT = 3000 - len("[Truncated]")

        failure_text = ""
        for key, value in failures.items():
            new_text = failure_text + f"*{key}*\n_{value}_\n\n"
            if len(new_text) > MAX_ERROR_TEXT:
                # `failure_text` here has length <= 3000
                failure_text = failure_text + "[Truncated]"
                break
            # `failure_text` here has length <= MAX_ERROR_TEXT
            failure_text = new_text

        title = job_name
        content = {"type": "section", "text": {"type": "mrkdwn", "text": text}}

        if job_link is not None:
            content["accessory"] = {
                "type": "button",
                "text": {"type": "plain_text", "text": "GitHub Action job", "emoji": True},
                "url": job_link,
            }

        return [
            {"type": "header", "text": {"type": "plain_text", "text": title, "emoji": True}},
            content,
            {"type": "section", "text": {"type": "mrkdwn", "text": failure_text}},
        ]

    def post_reply(self):
        """Post one threaded reply per job that had failures. Requires `post` first."""
        if self.thread_ts is None:
            raise ValueError("Can only post reply if a post has been made.")

        sorted_dict = sorted(self.doc_test_results.items(), key=lambda t: t[0])
        for job_name, job_result in sorted_dict:
            if len(job_result["failures"]) > 0:
                text = f"*Num failures* :{len(job_result['failed'])} \n"
                failures = job_result["failures"]
                blocks = self.get_reply_blocks(job_name, job_result["job_link"], failures, text=text)

                print("Sending the following reply")
                print(json.dumps({"blocks": blocks}))

                client.chat_postMessage(
                    channel=SLACK_REPORT_CHANNEL_ID,
                    text=f"Results for {job_name}",
                    blocks=blocks,
                    thread_ts=self.thread_ts["ts"],
                )

                # Stay well under Slack's rate limit between replies.
                time.sleep(1)
|
||||
|
||||
|
||||
def retrieve_artifact(name: str):
    """Read every file of the artifact directory `name` into a dict.

    Keys are file names stripped of their extension, values are the decoded
    text contents. Returns an empty dict when the directory does not exist.
    Raises ValueError when a file is not valid UTF-8 text.
    """
    contents = {}

    if os.path.exists(name):
        for file_name in os.listdir(name):
            full_path = os.path.join(name, file_name)
            try:
                with open(full_path, encoding="utf-8") as handle:
                    contents[file_name.split(".")[0]] = handle.read()
            except UnicodeDecodeError as e:
                raise ValueError(f"Could not open {os.path.join(name, file_name)}.") from e

    return contents
|
||||
|
||||
|
||||
def retrieve_available_artifacts():
    """Index every sub-directory of the current working directory as an artifact.

    Returns a dict mapping each directory name to an `Artifact` record holding
    the path(s) found for that name.
    """

    class Artifact:
        # Lightweight record: one artifact name and the paths found for it.
        def __init__(self, name: str):
            self.name = name
            self.paths = []

        def __str__(self):
            return self.name

        def add_path(self, path: str):
            self.paths.append({"name": self.name, "path": path})

    available: dict[str, Artifact] = {}

    for entry in os.listdir():
        # Only directories count as artifacts; plain files are ignored.
        if not os.path.isdir(entry):
            continue
        if entry not in available:
            available[entry] = Artifact(entry)
        available[entry].add_path(entry)

    return available
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Channel the final report (and its replies) is posted to.
    SLACK_REPORT_CHANNEL_ID = os.environ["SLACK_REPORT_CHANNEL"]

    github_actions_jobs = get_jobs(
        workflow_run_id=os.environ["GITHUB_RUN_ID"], token=os.environ["ACCESS_REPO_INFO_TOKEN"]
    )

    # Map each artifact name back to the GitHub Actions job that produced it,
    # via the job step named "Test suite reports artifacts: <name>".
    artifact_name_to_job_map = {}
    for job in github_actions_jobs:
        for step in job["steps"]:
            if step["name"].startswith("Test suite reports artifacts: "):
                artifact_name = step["name"][len("Test suite reports artifacts: ") :]
                artifact_name_to_job_map[artifact_name] = job
                break

    available_artifacts = retrieve_available_artifacts()

    doc_test_results = {}
    # `artifact_key` is the artifact path
    for artifact_obj in available_artifacts.values():
        # NOTE(review): assumes each artifact has exactly one path — confirm.
        artifact_path = artifact_obj.paths[0]
        # Only doc-test report artifacts are of interest here.
        if not artifact_path["path"].startswith("doc_tests_gpu_test_reports_"):
            continue

        # change "_" back to "/" (to show the job name as path)
        job_name = artifact_path["path"].replace("doc_tests_gpu_test_reports_", "").replace("_", "/")

        # This dict (for each job) will contain all the information relative to each doc test job, in particular:
        # - failed: list of failed tests
        # - failures: dict in the format 'test': 'error_message'
        job_result = {}
        doc_test_results[job_name] = job_result

        job = artifact_name_to_job_map[artifact_path["path"]]
        job_result["job_link"] = job["html_url"]
        job_result["category"] = "Python Examples" if job_name.startswith("src/") else "MD Examples"

        artifact = retrieve_artifact(artifact_path["path"])
        if "stats" in artifact:
            # Parse pytest's summary line into counts and the run duration.
            failed, success, time_spent = handle_test_results(artifact["stats"])
            job_result["n_failures"] = failed
            job_result["n_success"] = success
            # Strip the surrounding characters and append ", " for `Message.time`.
            job_result["time_spent"] = time_spent[1:-1] + ", "
            job_result["failed"] = []
            job_result["failures"] = {}

            all_failures = extract_first_line_failure(artifact["failures_short"])
            for line in artifact["summary_short"].split("\n"):
                if re.search("FAILED", line):
                    line = line.replace("FAILED ", "")
                    line = line.split()[0].replace("\n", "")

                    # Lines look like "path/to/file::test_name"; bare paths keep the path as the test name.
                    if "::" in line:
                        file_path, test = line.split("::")
                    else:
                        file_path, test = line, line

                    job_result["failed"].append(test)
                    failure = all_failures.get(test, "N/A")
                    job_result["failures"][test] = failure

    # Save and to be uploaded as artifact
    os.makedirs("doc_test_results", exist_ok=True)
    with open("doc_test_results/doc_test_results.json", "w", encoding="UTF-8") as fp:
        json.dump(doc_test_results, fp, ensure_ascii=False, indent=4)

    # Post the aggregated report to Slack, then one threaded reply per failing job.
    message = Message("🤗 Results of the doc tests.", doc_test_results)
    message.post()
    message.post_reply()
|
||||
156
transformers/utils/patch_helper.py
Normal file
156
transformers/utils/patch_helper.py
Normal file
@@ -0,0 +1,156 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This should help you prepare a patch, automatically extracting the commits to cherry-pick
|
||||
in chronological order to avoid merge conflicts. An equivalent way to do this is to use
|
||||
`git log --pretty=oneline HEAD...v4.41.0` and grep.
|
||||
|
||||
Potential TODO: automatically cherry-picks them.
|
||||
|
||||
Pass in a list of PR:
|
||||
`python utils/patch_helper.py --prs 31108 31054 31008 31010 31004`
|
||||
will produce the following:
|
||||
```bash
|
||||
Skipping invalid version tag: list
|
||||
Skipping invalid version tag: localattn1
|
||||
Git cherry-pick commands to run:
|
||||
git cherry-pick 03935d300d60110bb86edb49d2315089cfb19789 #2024-05-24 11:00:59+02:00
|
||||
git cherry-pick bdb9106f247fca48a71eb384be25dbbd29b065a8 #2024-05-24 19:02:55+02:00
|
||||
git cherry-pick 84c4b72ee99e8e65a8a5754a5f9d6265b45cf67e #2024-05-27 10:34:14+02:00
|
||||
git cherry-pick 936ab7bae5e040ec58994cb722dd587b9ab26581 #2024-05-28 11:56:05+02:00
|
||||
git cherry-pick 0bef4a273825d2cfc52ddfe62ba486ee61cc116f #2024-05-29 13:33:26+01:00
|
||||
```
|
||||
"""
|
||||
|
||||
import json
|
||||
import subprocess
|
||||
|
||||
import transformers
|
||||
|
||||
|
||||
# GitHub label that marks PRs to cherry-pick onto the release branch.
LABEL = "for patch"  # Replace with your label
# Repository slug; only needed if the `gh` CLI is not already scoped to it.
REPO = "huggingface/transformers"  # Optional if already in correct repo
|
||||
|
||||
|
||||
def get_release_branch_name():
    """Return the release branch name (e.g. "v4.41-release") for the installed transformers version."""
    version_parts = transformers.__version__.split(".")
    major, minor = int(version_parts[0]), int(version_parts[1])

    if minor == 0:
        # Handle major version rollback, e.g., from 5.0 to 4.latest (if ever needed)
        major -= 1
        # You'll need logic to determine the last minor of the previous major version
        raise ValueError("Minor version is 0; need logic to find previous major version's last minor")

    return f"v{major}.{minor}-release"
|
||||
|
||||
|
||||
def checkout_branch(branch):
    """Switch the working tree to `branch`, exiting the process if git fails."""
    completed = subprocess.run(["git", "checkout", branch], check=False)
    if completed.returncode != 0:
        print(f"❌ Failed to checkout branch: {branch}. Does it exist?")
        exit(1)
    print(f"✅ Checked out branch: {branch}")
|
||||
|
||||
|
||||
def get_prs_by_label(label):
    """Fetch up to 100 PRs carrying `label` via the `gh` CLI.

    Merged PRs get their merge commit SHA attached under the "oid" key.
    Raises CalledProcessError when the `gh` command fails.
    """
    command = [
        "gh",
        "pr",
        "list",
        "--label",
        label,
        "--state",
        "all",
        "--json",
        "number,title,mergeCommit,url",
        "--limit",
        "100",
    ]
    completed = subprocess.run(command, check=False, capture_output=True, text=True)
    completed.check_returncode()

    prs = json.loads(completed.stdout)
    for pr in prs:
        # Unmerged PRs have no mergeCommit; only merged ones get an "oid".
        merge_commit = pr.get("mergeCommit", {})
        if merge_commit:
            pr["oid"] = merge_commit.get("oid")
    return prs
|
||||
|
||||
|
||||
def get_commit_timestamp(commit_sha):
    """Return the committer UNIX timestamp of `commit_sha` (via `git show --format=%ct`)."""
    completed = subprocess.run(
        ["git", "show", "-s", "--format=%ct", commit_sha], check=False, capture_output=True, text=True
    )
    # Surface git failures (unknown SHA, not a repo) as CalledProcessError.
    completed.check_returncode()
    return int(completed.stdout.strip())
|
||||
|
||||
|
||||
def cherry_pick_commit(sha):
    """Cherry-pick `sha` onto the current branch, reporting (not raising) on conflict."""
    completed = subprocess.run(["git", "cherry-pick", sha], check=False)
    if completed.returncode == 0:
        print(f"✅ Cherry-picked commit {sha}")
    else:
        # Conflicts are left for the operator to resolve by hand.
        print(f"⚠️ Failed to cherry-pick {sha}. Manual intervention required.")
|
||||
|
||||
|
||||
def commit_in_history(commit_sha, base_branch="HEAD"):
    """Return True if `commit_sha` is already an ancestor of `base_branch`."""
    # `git merge-base --is-ancestor` exits 0 exactly when the commit is in the history.
    status = subprocess.run(
        ["git", "merge-base", "--is-ancestor", commit_sha, base_branch],
        check=False,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    ).returncode
    return status == 0
|
||||
|
||||
|
||||
def main(verbose=False):
    """Check out the release branch and cherry-pick, in chronological order, every
    labelled PR whose merge commit is not yet in the branch history."""
    branch = get_release_branch_name()
    checkout_branch(branch)
    prs = get_prs_by_label(LABEL)
    # Attach commit timestamps
    for pr in prs:
        sha = pr.get("oid")
        if sha:
            pr["timestamp"] = get_commit_timestamp(sha)
        else:
            # NOTE(review): this fires for PRs without a merge commit (e.g. not yet
            # merged); the "NOT in main" wording may overstate that — confirm intent.
            print("\n" + "=" * 80)
            print(f"⚠️ WARNING: PR #{pr['number']} ({sha}) is NOT in main!")
            print("⚠️ A core maintainer must review this before cherry-picking.")
            print("=" * 80 + "\n")
    # Sort by commit timestamp (ascending)
    prs = [pr for pr in prs if pr.get("timestamp") is not None]
    prs.sort(key=lambda pr: pr["timestamp"])
    for pr in prs:
        sha = pr.get("oid")
        if sha:
            if commit_in_history(sha):
                if verbose:
                    print(f"🔁 PR #{pr['number']} ({pr['title']}) already in history. Skipping.")
            else:
                print(f"🚀 PR #{pr['number']} ({pr['title']}) not in history. Cherry-picking...")
                cherry_pick_commit(sha)
|
||||
|
||||
|
||||
# Script entry point: cherry-pick all PRs labelled LABEL onto the release branch.
if __name__ == "__main__":
    main()
|
||||
172
transformers/utils/pr_slow_ci_models.py
Normal file
172
transformers/utils/pr_slow_ci_models.py
Normal file
@@ -0,0 +1,172 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script is used to get the models for which to run slow CI.
|
||||
|
||||
A new model added in a pull request will be included, as well as models specified in a GitHub pull request's comment
|
||||
with a prefix `run-slow`, `run_slow` or `run slow`. For example, the commit message `run_slow: bert, gpt2` will give
|
||||
`bert` and `gpt2`.
|
||||
|
||||
Usage:
|
||||
|
||||
```bash
|
||||
python utils/pr_slow_ci_models.py
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os.path
|
||||
import re
|
||||
import string
|
||||
from pathlib import Path
|
||||
|
||||
from git import Repo
|
||||
|
||||
|
||||
# Root of the transformers repository (this file lives in <root>/utils/).
PATH_TO_REPO = Path(__file__).parent.parent.resolve()
|
||||
|
||||
|
||||
def get_new_python_files_between_commits(base_commit: str, commits: list[str]) -> list[str]:
    """
    Collect the python files added between `base_commit` and one or several commits.

    Args:
        base_commit (`str`):
            The commit reference of where to compare for the diff. This is the current commit, not the branching point!
        commits (`List[str]`):
            The list of commits with which to compare the repo at `base_commit` (so the branching point).

    Returns:
        `List[str]`: The list of python files added between a base commit and one or several commits.
    """
    added_python_files = []
    for commit in commits:
        for diff_obj in commit.diff(base_commit):
            # Only keep files that are newly added ("A") and are python sources.
            if diff_obj.change_type == "A" and diff_obj.b_path.endswith(".py"):
                added_python_files.append(diff_obj.b_path)

    return added_python_files
|
||||
|
||||
|
||||
def get_new_python_files(diff_with_last_commit=False) -> list[str]:
    """
    Return a list of python files that have been added between the current head and the main branch.

    Returns:
        `List[str]`: The list of python files added.
    """
    repo = Repo(PATH_TO_REPO)

    try:
        # For the cases where the main branch exists locally
        main = repo.refs.main
    except AttributeError:
        # On GitHub Actions runners, it doesn't have local main branch
        main = repo.remotes.origin.refs.main

    if diff_with_last_commit:
        print(f"main is at {main.commit}")
        # Compare against the parents of main's tip (i.e. the last commit's diff).
        comparison_commits = main.commit.parents
        for commit in comparison_commits:
            print(f"Parent commit: {commit}")
    else:
        print(f"main is at {main.commit}")
        print(f"Current head is at {repo.head.commit}")

        # Compare against the branching point(s) between main and the current head.
        comparison_commits = repo.merge_base(main, repo.head)
        for commit in comparison_commits:
            print(f"Branching commit: {commit}")

    return get_new_python_files_between_commits(repo.head.commit, comparison_commits)
|
||||
|
||||
|
||||
def get_new_model(diff_with_last_commit=False):
    """Return the model name of a newly added modeling file, or "" if none was added."""
    modeling_file_pattern = re.compile(r"src/transformers/models/(.*)/modeling_.*\.py")

    for added_file in get_new_python_files(diff_with_last_commit):
        matches = modeling_file_pattern.findall(added_file)
        if matches:
            # It's unlikely we have 2 new modeling files in a pull request.
            return matches[0]
    return ""
|
||||
|
||||
|
||||
def parse_message(message: str) -> str:
    """
    Parses a GitHub pull request's comment to find the models specified in it to run slow CI.

    Args:
        message (`str`): The body of a GitHub pull request's comment.

    Returns:
        `str`: The substring in `message` after `run-slow`, run_slow` or run slow`. If no such prefix is found, the
        empty string is returned.
    """
    if message is None:
        return ""

    remainder = message.strip().lower()

    # Only comments of the form "run-slow: model_1, model_2" are considered.
    if not remainder.startswith(("run-slow", "run_slow", "run slow")):
        return ""

    # All three accepted prefixes have the same length, so one slice handles them all.
    remainder = remainder[len("run slow") :]
    # remove leading `:`
    while remainder.strip().startswith(":"):
        remainder = remainder.strip()[1:]

    return remainder
|
||||
|
||||
|
||||
def get_models(message: str):
    """Extract the list of model names from a `run-slow`-style comment."""
    # Accept both comma- and space-separated model lists.
    model_spec = parse_message(message)
    return model_spec.replace(",", " ").split()
|
||||
|
||||
|
||||
def check_model_names(model_name: str):
    """Return True when `model_name` uses only [A-Za-z0-9_] and does not start or end with "_"."""
    if model_name.startswith("_") or model_name.endswith("_"):
        return False
    permitted = string.ascii_letters + string.digits + "_"
    return all(ch in permitted for ch in model_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--message", type=str, default="", help="The content of a comment.")
    parser.add_argument("--quantization", action="store_true", help="If we collect quantization tests")
    args = parser.parse_args()

    # A model is selected either because its modeling file is new in this PR
    # or because a `run-slow` comment requested it explicitly.
    new_model = get_new_model()
    specified_models = get_models(args.message)
    models = ([] if new_model == "" else [new_model]) + specified_models
    # a guard for strange model names
    models = [model for model in models if check_model_names(model)]

    # Add prefix
    final_list = []
    for model in models:
        if not args.quantization:
            # Regular runs: prefer tests/models/<model>, fall back to a top-level tests/<model> suite.
            if os.path.isdir(f"tests/models/{model}"):
                final_list.append(f"models/{model}")
            elif os.path.isdir(f"tests/{model}") and model != "quantization":
                final_list.append(model)
        elif os.path.isdir(f"tests/quantization/{model}"):
            final_list.append(f"quantization/{model}")

    # Printed (sorted, de-duplicated) for the CI workflow to consume.
    print(sorted(set(final_list)))
|
||||
76
transformers/utils/print_env.py
Normal file
76
transformers/utils/print_env.py
Normal file
@@ -0,0 +1,76 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# coding=utf-8
|
||||
# Copyright 2020 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# this script dumps information about the environment
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
import transformers
|
||||
from transformers import is_torch_hpu_available, is_torch_xpu_available
|
||||
|
||||
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||
|
||||
# Basic interpreter / library versions.
print("Python version:", sys.version)
print("transformers version:", transformers.__version__)

try:
    import torch

    print("Torch version:", torch.__version__)
    # Detect which accelerator backend (if any) this torch build can reach.
    accelerator = "NA"
    if torch.cuda.is_available():
        accelerator = "CUDA"
    elif is_torch_xpu_available():
        accelerator = "XPU"
    elif is_torch_hpu_available():
        accelerator = "HPU"

    print("Torch accelerator:", accelerator)

    # Backend-specific details.
    if accelerator == "CUDA":
        print("Cuda version:", torch.version.cuda)
        print("CuDNN version:", torch.backends.cudnn.version())
        print("Number of GPUs available:", torch.cuda.device_count())
        print("NCCL version:", torch.cuda.nccl.version())
    elif accelerator == "XPU":
        print("SYCL version:", torch.version.xpu)
        print("Number of XPUs available:", torch.xpu.device_count())
    elif accelerator == "HPU":
        # HPU builds append the Habana stack version after `+` in torch's version string.
        print("HPU version:", torch.__version__.split("+")[-1])
        print("Number of HPUs available:", torch.hpu.device_count())
except ImportError:
    print("Torch version:", None)

try:
    import deepspeed

    print("DeepSpeed version:", deepspeed.__version__)
except ImportError:
    print("DeepSpeed version:", None)


try:
    import torchcodec

    # NOTE(review): relies on the private `torchcodec._core` API — may break
    # across torchcodec releases (hence the broad fallback below).
    versions = torchcodec._core.get_ffmpeg_library_versions()
    print("FFmpeg version:", versions["ffmpeg_version"])
except ImportError:
    print("FFmpeg version:", None)
except (AttributeError, KeyError, RuntimeError):
    # torchcodec is installed but the private API changed or the lookup failed.
    print("Failed to get FFmpeg version")
|
||||
130
transformers/utils/process_bad_commit_report.py
Normal file
130
transformers/utils/process_bad_commit_report.py
Normal file
@@ -0,0 +1,130 @@
|
||||
"""An internal script to process `new_failures_with_bad_commit.json` produced by `utils/check_bad_commit.py`.
|
||||
|
||||
This is used by `.github/workflows/check_failed_model_tests.yml` to produce a slack report of the following form
|
||||
|
||||
```
|
||||
<{url}|New failed tests>
|
||||
{
|
||||
"GH_ydshieh": {
|
||||
"vit": 1
|
||||
}
|
||||
}
|
||||
```
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from collections import Counter
|
||||
from copy import deepcopy
|
||||
|
||||
from get_previous_daily_ci import get_last_daily_ci_run
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
|
||||
if __name__ == "__main__":
    api = HfApi()

    # e.g. the CI job whose failures we are processing; also used in file paths.
    job_name = os.environ.get("JOB_NAME")

    # Output of `utils/check_bad_commit.py`: model -> device -> list of failed-test dicts.
    with open("new_failures_with_bad_commit.json") as fp:
        data = json.load(fp)

    with open(f"ci_results_{job_name}/job_links.json") as fp:
        job_links = json.load(fp)

    # TODO: extend
    team_members = [
        "ArthurZucker",
        "Cyrilvallez",
        "LysandreJik",
        "MekkCyber",
        "Rocketknight1",
        "SunMarc",
        "ebezzam",
        "eustlb",
        "gante",
        "itazap",
        "ivarflakstad",
        "molbap",
        "muellerzr",
        "remi-or",
        "stevhliu",
        "vasqu",
        "ydshieh",
        "zucchini-nlp",
    ]

    # Counting the number of failures grouped by authors
    new_data = {}
    for model, model_result in data.items():
        for device, failed_tests in model_result.items():
            for failed_test in failed_tests:
                author = failed_test["author"]

                # Attribute failures by an external author to the team member
                # who merged the offending commit instead.
                if author not in team_members:
                    author = failed_test["merged_by"]

                if author not in new_data:
                    new_data[author] = Counter()
                new_data[author].update([model])
    for author in new_data:
        new_data[author] = dict(new_data[author])

    # Group by author
    # Each author gets a private deep copy of the full report, then entries not
    # attributed to that author are filtered out below.
    new_data_full = {author: deepcopy(data) for author in new_data}
    for author, _data in new_data_full.items():
        for model, model_result in _data.items():
            for device, failed_tests in model_result.items():
                # prepare job_link and add it to each entry of new failed test information.
                # need to change from `single-gpu` to `single` and same for `multi-gpu` to match `job_link`.
                key = model
                if list(job_links.keys()) == [job_name]:
                    key = job_name
                job_link = job_links[key][device.replace("-gpu", "")]

                failed_tests = [x for x in failed_tests if x["author"] == author or x["merged_by"] == author]
                for x in failed_tests:
                    x.update({"job_link": job_link})
                model_result[device] = failed_tests
            # Drop devices/models that end up with no failures for this author.
            _data[model] = {k: v for k, v in model_result.items() if len(v) > 0}
        new_data_full[author] = {k: v for k, v in _data.items() if len(v) > 0}

    # Upload to Hub and get the url
    # if it is not a scheduled run, upload the reports to a subfolder under `report_repo_folder`
    report_repo_subfolder = ""
    if os.getenv("GITHUB_EVENT_NAME") != "schedule":
        report_repo_subfolder = f"{os.getenv('GITHUB_RUN_NUMBER')}-{os.getenv('GITHUB_RUN_ID')}"
        report_repo_subfolder = f"runs/{report_repo_subfolder}"

    workflow_run = get_last_daily_ci_run(
        token=os.environ["ACCESS_REPO_INFO_TOKEN"], workflow_run_id=os.getenv("GITHUB_RUN_ID")
    )
    workflow_run_created_time = workflow_run["created_at"]

    # Reports are grouped by the (ISO) date of the workflow run.
    report_repo_folder = workflow_run_created_time.split("T")[0]

    if report_repo_subfolder:
        report_repo_folder = f"{report_repo_folder}/{report_repo_subfolder}"

    report_repo_id = os.getenv("REPORT_REPO_ID")

    with open("new_failures_with_bad_commit_grouped_by_authors.json", "w") as fp:
        json.dump(new_data_full, fp, ensure_ascii=False, indent=4)
    commit_info = api.upload_file(
        path_or_fileobj="new_failures_with_bad_commit_grouped_by_authors.json",
        path_in_repo=f"{report_repo_folder}/ci_results_{job_name}/new_failures_with_bad_commit_grouped_by_authors.json",
        repo_id=report_repo_id,
        repo_type="dataset",
        token=os.environ.get("TRANSFORMERS_CI_RESULTS_UPLOAD_TOKEN", None),
    )
    url = f"https://huggingface.co/datasets/{report_repo_id}/raw/{commit_info.oid}/{report_repo_folder}/ci_results_{job_name}/new_failures_with_bad_commit_grouped_by_authors.json"

    # Add `GH_` prefix as keyword mention
    output = {}
    for author, item in new_data.items():
        author = f"GH_{author}"
        output[author] = item

    # Escape quotes/newlines so the report can be embedded in a Slack payload.
    report = f"<{url}|New failed tests>\\n\\n"
    report += json.dumps(output, indent=4).replace('"', '\\"').replace("\n", "\\n")
    print(report)
|
||||
85
transformers/utils/process_circleci_workflow_test_reports.py
Normal file
85
transformers/utils/process_circleci_workflow_test_reports.py
Normal file
@@ -0,0 +1,85 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--workflow_id", type=str, required=True)
    args = parser.parse_args()
    workflow_id = args.workflow_id

    # List all jobs of the given CircleCI workflow.
    r = requests.get(
        f"https://circleci.com/api/v2/workflow/{workflow_id}/job",
        headers={"Circle-Token": os.environ.get("CIRCLE_TOKEN", "")},
    )
    jobs = r.json()["items"]

    os.makedirs("outputs", exist_ok=True)

    workflow_summary = {}
    # for each job, download artifacts
    for job in jobs:
        project_slug = job["project_slug"]
        # Only test-running jobs produce the summary artifacts we need.
        if job["name"].startswith(("tests_", "examples_", "pipelines_")):
            url = f"https://circleci.com/api/v2/project/{project_slug}/{job['job_number']}/artifacts"
            r = requests.get(url, headers={"Circle-Token": os.environ.get("CIRCLE_TOKEN", "")})
            job_artifacts = r.json()["items"]

            os.makedirs(job["name"], exist_ok=True)
            os.makedirs(f"outputs/{job['name']}", exist_ok=True)

            # One pytest `summary_short.txt` per parallel node of the job.
            job_test_summaries = {}
            for artifact in job_artifacts:
                if artifact["path"].startswith("reports/") and artifact["path"].endswith("/summary_short.txt"):
                    node_index = artifact["node_index"]
                    url = artifact["url"]
                    r = requests.get(url, headers={"Circle-Token": os.environ.get("CIRCLE_TOKEN", "")})
                    test_summary = r.text
                    job_test_summaries[node_index] = test_summary

            # Merge per-node summaries into test -> "passed"/"failed".
            summary = {}
            for node_index, node_test_summary in job_test_summaries.items():
                for line in node_test_summary.splitlines():
                    if line.startswith("PASSED "):
                        test = line[len("PASSED ") :]
                        summary[test] = "passed"
                    elif line.startswith("FAILED "):
                        # FAILED lines may carry a trailing failure message; keep only the test id.
                        test = line[len("FAILED ") :].split()[0]
                        summary[test] = "failed"
            # failed before passed
            summary = dict(sorted(summary.items(), key=lambda x: (x[1], x[0])))
            workflow_summary[job["name"]] = summary

            # collected version
            with open(f"outputs/{job['name']}/test_summary.json", "w") as fp:
                json.dump(summary, fp, indent=4)

    # Invert the mapping: test -> {job_name: status}.
    new_workflow_summary = {}
    for job_name, job_summary in workflow_summary.items():
        for test, status in job_summary.items():
            if test not in new_workflow_summary:
                new_workflow_summary[test] = {}
            new_workflow_summary[test][job_name] = status

    # Deterministic ordering for stable diffs between workflow runs.
    for test, result in new_workflow_summary.items():
        new_workflow_summary[test] = dict(sorted(result.items()))
    new_workflow_summary = dict(sorted(new_workflow_summary.items()))

    with open("outputs/test_summary.json", "w") as fp:
        json.dump(new_workflow_summary, fp, indent=4)
|
||||
75
transformers/utils/process_test_artifacts.py
Normal file
75
transformers/utils/process_test_artifacts.py
Normal file
@@ -0,0 +1,75 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
This helper computes the "ideal" number of nodes to use in circle CI.
|
||||
For each job, we compute this parameter and pass it to the `generated_config.yaml`.
|
||||
"""
|
||||
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
|
||||
|
||||
# Hard cap on CircleCI parallelism for any single job.
MAX_PARALLEL_NODES = 8  # TODO create a mapping!
# Heuristic: assumed number of tests each node can absorb when sizing a job.
AVERAGE_TESTS_PER_NODES = 5
|
||||
|
||||
|
||||
def count_lines(filepath):
    """Count the number of lines in a file.

    Args:
        filepath (`str`): Path of the file to inspect.

    Returns:
        `int`: The number of lines in the file, or 0 if the file does not exist.
    """
    try:
        with open(filepath, "r") as f:
            # `splitlines()` ignores a trailing newline; the previous
            # `split("\n")` counted one extra (empty) line for files ending
            # in a newline, inflating the parallelism computed downstream.
            return len(f.read().splitlines())
    except FileNotFoundError:
        return 0
|
||||
|
||||
|
||||
def compute_parallel_nodes(line_count, max_tests_per_node=10):
    """Compute the number of parallel CI nodes required for *line_count* tests.

    Very small jobs (fewer than 4 tests) run on a single node; otherwise the
    node count is derived from `AVERAGE_TESTS_PER_NODES` and capped at
    `MAX_PARALLEL_NODES`.

    NOTE(review): `max_tests_per_node` is currently unused — kept for
    backward compatibility with existing callers; confirm before removing.
    """
    if line_count < 4:
        # Tiny test lists do not benefit from parallelism.
        return 1
    required_nodes = math.ceil(line_count / AVERAGE_TESTS_PER_NODES)
    return min(MAX_PARALLEL_NODES, required_nodes)
|
||||
|
||||
|
||||
def process_artifacts(input_file, output_file):
    """Transform CircleCI artifact metadata into the mapping consumed by CI config.

    Reads the artifact listing JSON at *input_file*, keeps entries whose path
    contains "test_list", and writes to *output_file* a dict mapping each test
    list name to its artifact URL, plus a "<job>_parallelism" entry holding the
    node count computed from the local test list's length.
    """
    with open(input_file, "r") as f:
        artifact_data = json.load(f)

    transformed_data = {}
    for artifact in artifact_data.get("items", []):
        if "test_list" not in artifact["path"]:
            continue
        key = os.path.splitext(os.path.basename(artifact["path"]))[0]
        transformed_data[key] = artifact["url"]
        # Derive the parallelism hint from the number of tests collected
        # locally for this job.
        parallel_key = key.split("_test")[0] + "_parallelism"
        local_list = os.path.join("test_preparation", f"{key}.txt")
        transformed_data[parallel_key] = compute_parallel_nodes(count_lines(local_list))

    # The "generated_config" artifact is not a test list; drop it if present.
    transformed_data.pop("generated_config", None)

    with open(output_file, "w") as f:
        json.dump(transformed_data, f, indent=2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Fixed locations used by the CircleCI test-preparation workflow.
    process_artifacts(
        "test_preparation/artifacts.json",
        "test_preparation/transformed_artifacts.json",
    )
|
||||
218
transformers/utils/release.py
Normal file
218
transformers/utils/release.py
Normal file
@@ -0,0 +1,218 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Utility that prepares the repository for releases (or patches) by updating all versions in the relevant places. It
|
||||
also performs some post-release cleanup, by updating the links in the main README to respective model doc pages (from
|
||||
main to stable).
|
||||
|
||||
To prepare for a release, use from the root of the repo on the release branch with:
|
||||
|
||||
```bash
|
||||
python release.py
|
||||
```
|
||||
|
||||
or use `make pre-release`.
|
||||
|
||||
To prepare for a patch release, use from the root of the repo on the release branch with:
|
||||
|
||||
```bash
|
||||
python release.py --patch
|
||||
```
|
||||
|
||||
or use `make pre-patch`.
|
||||
|
||||
To do the post-release cleanup, use from the root of the repo on the main branch with:
|
||||
|
||||
```bash
|
||||
python release.py --post_release
|
||||
```
|
||||
|
||||
or use `make post-release`.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import packaging.version
|
||||
|
||||
|
||||
# All paths are defined with the intent that this script should be run from the root of the repo.
PATH_TO_EXAMPLES = "examples/"
PATH_TO_MODELS = "src/transformers/models"
# This maps a type of file to the pattern to look for when searching where the version is defined, as well as the
# template to follow when replacing it with the new version.
REPLACE_PATTERNS = {
    "examples": (re.compile(r'^check_min_version\("[^"]+"\)\s*$', re.MULTILINE), 'check_min_version("VERSION")\n'),
    "init": (re.compile(r'^__version__\s+=\s+"([^"]+)"\s*$', re.MULTILINE), '__version__ = "VERSION"\n'),
    "setup": (re.compile(r'^(\s*)version\s*=\s*"[^"]+",', re.MULTILINE), r'\1version="VERSION",'),
    # UV example scripts declare their dependencies in an inline metadata
    # header: pin to the released version...
    "uv_script_release": (
        re.compile(r'^# "transformers(\[.+\])?.*$', re.MULTILINE),
        r'# "transformers\g<1>==VERSION",',
    ),
    # ...or point back at the git main branch while developing.
    "uv_script_dev": (
        re.compile(r'^# "transformers(\[.+\])?.*$', re.MULTILINE),
        r'# "transformers\g<1> @ git+https://github.com/huggingface/transformers.git",',
    ),
}
# This maps a type of file to its path in Transformers
REPLACE_FILES = {
    "init": "src/transformers/__init__.py",
    "setup": "setup.py",
}
README_FILE = "README.md"
# Marker identifying a PEP 723 inline-script metadata block in example files.
UV_SCRIPT_MARKER = "# /// script"
|
||||
|
||||
|
||||
def update_version_in_file(fname: str, version: str, file_type: str):
    """
    Update the version of Transformers in one file.

    Args:
        fname (`str`): The path to the file where we want to update the version.
        version (`str`): The new version to set in the file.
        file_type (`str`): The type of the file (should be a key in `REPLACE_PATTERNS`).
    """
    # `newline="\n"` disables newline translation so the file is rewritten
    # byte-for-byte except for the version substitution.
    with open(fname, "r", encoding="utf-8", newline="\n") as f:
        content = f.read()
    pattern, template = REPLACE_PATTERNS[file_type]
    content = pattern.sub(template.replace("VERSION", version), content)
    with open(fname, "w", encoding="utf-8", newline="\n") as f:
        f.write(content)
|
||||
|
||||
|
||||
def update_version_in_examples(version: str, patch: bool = False):
    """
    Update the version in all examples files.

    Args:
        version (`str`): The new version to set in the examples.
        patch (`bool`, *optional*, defaults to `False`): Whether or not this is a patch release.
    """
    for folder, directories, fnames in os.walk(PATH_TO_EXAMPLES):
        # Removing some of the folders with non-actively maintained examples from the walk
        if "legacy" in directories:
            directories.remove("legacy")
        for fname in fnames:
            if not fname.endswith(".py"):
                continue
            full_path = os.path.join(folder, fname)
            if UV_SCRIPT_MARKER in Path(folder, fname).read_text():
                # Update the dependencies declared in UV inline-script headers.
                script_type = "uv_script_dev" if ".dev" in version else "uv_script_release"
                update_version_in_file(full_path, version, file_type=script_type)
            if not patch:
                # We don't update the version in the examples for patch releases.
                update_version_in_file(full_path, version, file_type="examples")
|
||||
|
||||
|
||||
def global_version_update(version: str, patch: bool = False):
    """
    Update the version in all needed files.

    Args:
        version (`str`): The new version to set everywhere.
        patch (`bool`, *optional*, defaults to `False`): Whether or not this is a patch release.
    """
    # First the files with a fixed location (main __init__, setup.py)...
    for file_type, path in REPLACE_FILES.items():
        update_version_in_file(path, version, file_type)
    # ...then every example script.
    update_version_in_examples(version, patch=patch)
|
||||
|
||||
|
||||
def remove_conversion_scripts():
    """
    Delete the scripts that convert models from older, unsupported formats. We don't want to include these
    in release wheels because they often have to open insecure file types (pickle, Torch .bin models). This results in
    vulnerability scanners flagging us and can cause compliance issues for users with strict security policies.
    """
    # Materialize the glob before unlinking so we never delete entries out
    # from under the directory scan.
    for conversion_script in list(Path(PATH_TO_MODELS).glob("**/convert*.py")):
        conversion_script.unlink()
|
||||
|
||||
|
||||
def get_version() -> packaging.version.Version:
    """
    Reads the current version in the main __init__.
    """
    with open(REPLACE_FILES["init"], "r") as f:
        init_code = f.read()
    # The "init" pattern's first group captures the quoted version string.
    version_string = REPLACE_PATTERNS["init"][0].search(init_code).groups()[0]
    return packaging.version.parse(version_string)
|
||||
|
||||
|
||||
def pre_release_work(patch: bool = False):
    """
    Do all the necessary pre-release steps:
    - figure out the next minor release version and ask confirmation
    - update the version everywhere
    - clean-up the model list in the main README

    Args:
        patch (`bool`, *optional*, defaults to `False`): Whether or not this is a patch release.
    """
    # Default version: base version if we are on a dev version, otherwise a
    # minor (or micro, for patches) bump of the current release.
    current = get_version()
    if patch and current.is_devrelease:
        raise ValueError("Can't create a patch version from the dev branch, checkout a released version!")
    if current.is_devrelease:
        default_version = current.base_version
    elif patch:
        default_version = f"{current.major}.{current.minor}.{current.micro + 1}"
    else:
        default_version = f"{current.major}.{current.minor + 1}.0"

    # Confirm (or let the user override) the computed version.
    version = input(f"Which version are you releasing? [{default_version}]") or default_version

    print(f"Updating version to {version}.")
    global_version_update(version, patch=patch)
    print("Deleting conversion scripts.")
    remove_conversion_scripts()
|
||||
|
||||
|
||||
def post_release_work():
    """
    Do all the necessary post-release steps:
    - figure out the next dev version and ask confirmation
    - update the version everywhere
    - clean-up the model list in the main README
    """
    # The next development version bumps the minor and appends `.dev0`.
    current = get_version()
    dev_version = f"{current.major}.{current.minor + 1}.0.dev0"

    # Check with the user we got that right.
    version = input(f"Which version are we developing now? [{dev_version}]") or dev_version

    print(f"Updating version to {version}.")
    global_version_update(version)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--post_release", action="store_true", help="Whether this is pre or post release.")
    parser.add_argument("--patch", action="store_true", help="Whether or not this is a patch release.")
    args = parser.parse_args()
    if args.post_release:
        # After a patch there is no dev-version bump to perform.
        if args.patch:
            print("Nothing to do after a patch :-)")
        else:
            post_release_work()
    else:
        pre_release_work(patch=args.patch)
|
||||
199
transformers/utils/scan_skipped_tests.py
Normal file
199
transformers/utils/scan_skipped_tests.py
Normal file
@@ -0,0 +1,199 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
# Assumes the script is launched from the repository root.
REPO_ROOT = Path().cwd()

# Files whose `test_*` functions are considered "common" tests, each tagged
# with the origin recorded in the scan output.
COMMON_TEST_FILES: list[tuple[Path, str]] = [
    (REPO_ROOT / "tests/test_modeling_common.py", "common"),
    (REPO_ROOT / "tests/generation/test_utils.py", "GenerationMixin"),
]

# Per-model test folders live here.
MODELS_DIR = REPO_ROOT / "tests/models"
|
||||
|
||||
|
||||
def get_common_tests(file_paths_with_origin: list[tuple[Path, str]]) -> dict[str, str]:
|
||||
"""Extract all common test function names (e.g., 'test_forward')."""
|
||||
tests_with_origin: dict[str, str] = {}
|
||||
for file_path, origin_tag in file_paths_with_origin:
|
||||
if not file_path.is_file():
|
||||
continue
|
||||
content = file_path.read_text(encoding="utf-8")
|
||||
for test_name in re.findall(r"^\s*def\s+(test_[A-Za-z0-9_]+)", content, re.MULTILINE):
|
||||
tests_with_origin[test_name] = origin_tag
|
||||
return tests_with_origin
|
||||
|
||||
|
||||
def get_models_and_test_files(models_dir: Path) -> tuple[list[str], list[Path]]:
    """Collect model names and their `test_modeling_*.py` files under *models_dir*.

    Raises:
        FileNotFoundError: If *models_dir* is not a directory.
    """
    if not models_dir.is_dir():
        raise FileNotFoundError(f"Models directory not found at {models_dir}")
    test_files = sorted(models_dir.rglob("test_modeling_*.py"))
    # A model's name is the directory holding its test files; dedupe via a set.
    model_names = sorted({test_file.parent.name for test_file in test_files})
    return model_names, test_files
|
||||
|
||||
|
||||
def _extract_reason_from_decorators(decorators_block: str) -> str:
|
||||
"""Extracts the reason string from a decorator block, if any."""
|
||||
reason_match = re.search(r'reason\s*=\s*["\'](.*?)["\']', decorators_block)
|
||||
if reason_match:
|
||||
return reason_match.group(1)
|
||||
reason_match = re.search(r'\((?:.*?,\s*)?["\'](.*?)["\']\)', decorators_block)
|
||||
if reason_match:
|
||||
return reason_match.group(1)
|
||||
return decorators_block.strip().split("\n")[-1].strip()
|
||||
|
||||
|
||||
def extract_test_info(file_content: str) -> dict[str, tuple[str, str]]:
    """
    Parse a test file once and return a mapping of test functions to their
    status and skip reason, e.g. {'test_forward': ('SKIPPED', 'too slow')}.
    """
    info: dict[str, tuple[str, str]] = {}
    # Capture each test function together with the decorator lines right above it.
    test_pattern = re.compile(r"((?:^\s*@.*?\n)*?)^\s*def\s+(test_[A-Za-z0-9_]+)\b", re.MULTILINE)
    for decorators_block, test_name in test_pattern.findall(file_content):
        if "skip" in decorators_block:
            info[test_name] = ("SKIPPED", _extract_reason_from_decorators(decorators_block))
        else:
            info[test_name] = ("RAN", "")
    return info
|
||||
|
||||
|
||||
def build_model_overrides(model_test_files: list[Path]) -> dict[str, dict[str, tuple[str, str]]]:
    """Return *model_name → {test_name → (status, reason)}* mapping."""
    model_overrides: dict[str, dict[str, tuple[str, str]]] = {}
    for test_file in model_test_files:
        # A model may ship several test_modeling_* files; merge their entries.
        model_name = test_file.parent.name
        per_model = model_overrides.setdefault(model_name, {})
        per_model.update(extract_test_info(test_file.read_text(encoding="utf-8")))
    return model_overrides
|
||||
|
||||
|
||||
def save_json(obj: dict, output_path: Path) -> None:
    """Serialize *obj* as pretty-printed JSON at *output_path*, creating parent dirs."""
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with output_path.open("w", encoding="utf-8") as fp:
        fp.write(json.dumps(obj, indent=2))
|
||||
|
||||
|
||||
def summarize_single_test(
    test_name: str,
    model_names: list[str],
    model_overrides: dict[str, dict[str, tuple[str, str]]],
) -> dict[str, object]:
    """Print a concise terminal summary for *test_name* and return the raw data."""
    # Models without an explicit override are assumed to run the common test.
    status_by_model = {
        model_name: model_overrides.get(model_name, {}).get(test_name, ("RAN", ""))
        for model_name in model_names
    }
    models_skipped = [m for m, (status, _) in status_by_model.items() if status == "SKIPPED"]
    models_ran = [m for m, (status, _) in status_by_model.items() if status != "SKIPPED"]
    reasons_for_skipping = [f"{m}: {status_by_model[m][1]}" for m in models_skipped]

    total_models = len(model_names)
    skipped_ratio = len(models_skipped) / total_models if total_models else 0.0

    print(f"\n== {test_name} ==")
    print(f"Ran : {len(models_ran)}/{total_models}")
    print(f"Skipped : {len(models_skipped)}/{total_models} ({skipped_ratio:.1%})")
    # Cap the printed reasons so huge skip lists stay readable.
    for reason_entry in reasons_for_skipping[:10]:
        print(f" - {reason_entry}")
    if len(reasons_for_skipping) > 10:
        print(" - ...")

    return {
        "models_ran": sorted(models_ran),
        "models_skipped": sorted(models_skipped),
        "skipped_proportion": round(skipped_ratio, 4),
        "reasons_skipped": sorted(reasons_for_skipping),
    }
|
||||
|
||||
|
||||
def summarize_all_tests(
    tests_with_origin: dict[str, str],
    model_names: list[str],
    model_overrides: dict[str, dict[str, tuple[str, str]]],
) -> dict[str, object]:
    """Return aggregated data for every discovered common test."""
    results: dict[str, object] = {}
    total_models = len(model_names)
    test_names = list(tests_with_origin)

    print(f"📝 Aggregating {len(test_names)} tests...")
    for index, test_fn in enumerate(test_names, 1):
        # `end="\r"` keeps the progress indicator on a single terminal line.
        print(f" ({index}/{len(test_names)}) {test_fn}", end="\r")
        # Models without an explicit override are assumed to run the common test.
        statuses = {
            model_name: model_overrides.get(model_name, {}).get(test_fn, ("RAN", ""))
            for model_name in model_names
        }
        models_skipped = [m for m, (status, _) in statuses.items() if status == "SKIPPED"]
        models_ran = [m for m, (status, _) in statuses.items() if status != "SKIPPED"]
        reasons_for_skipping = [f"{m}: {statuses[m][1]}" for m in models_skipped]

        skipped_ratio = len(models_skipped) / total_models if total_models else 0.0
        results[test_fn] = {
            "origin": tests_with_origin[test_fn],
            "models_ran": sorted(models_ran),
            "models_skipped": sorted(models_skipped),
            "skipped_proportion": round(skipped_ratio, 4),
            "reasons_skipped": sorted(reasons_for_skipping),
        }
    print("\n✅ Scan complete.")
    return results
|
||||
|
||||
|
||||
def main() -> None:
    """CLI entry point: scan model test files and dump the scan results as JSON."""
    parser = argparse.ArgumentParser(
        description="Scan model tests for overridden or skipped common or generate tests.",
    )
    parser.add_argument(
        "--output_dir",
        default=".",
        help="Directory for JSON output (default: %(default)s)",
    )
    parser.add_argument(
        "--test_method_name",
        help="Scan only this test method (single‑test mode)",
    )
    args = parser.parse_args()

    output_dir = Path(args.output_dir).expanduser()
    single_test = args.test_method_name

    tests_with_origin = get_common_tests(COMMON_TEST_FILES)
    if single_test:
        # Restrict the scan to the requested test; unknown names are tagged "unknown".
        tests_with_origin = {single_test: tests_with_origin.get(single_test, "unknown")}

    model_names, model_test_files = get_models_and_test_files(MODELS_DIR)
    print(f"🔬 Parsing {len(model_test_files)} model test files once each...")
    model_overrides = build_model_overrides(model_test_files)

    if single_test:
        data = summarize_single_test(single_test, model_names, model_overrides)
        json_path = output_dir / f"scan_{single_test}.json"
    else:
        data = summarize_all_tests(tests_with_origin, model_names, model_overrides)
        json_path = output_dir / "all_tests_scan_result.json"
    save_json(data, json_path)
    print(f"\n📄 JSON saved to {json_path.resolve()}")
|
||||
|
||||
|
||||
# Run the scan when invoked as a script.
if __name__ == "__main__":
    main()
|
||||
26
transformers/utils/set_cuda_devices_for_ci.py
Normal file
26
transformers/utils/set_cuda_devices_for_ci.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""A simple script to set flexibly CUDA_VISIBLE_DEVICES in GitHub Actions CI workflow files."""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--test_folder",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The test folder name of the model being tested. For example, `models/cohere`.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# `test_eager_matches_sdpa_generate` for `cohere` needs a lot of GPU memory!
|
||||
# This depends on the runners. At this moment we are targeting our AWS CI runners.
|
||||
if args.test_folder == "models/cohere":
|
||||
cuda_visible_devices = "0,1,2,3"
|
||||
elif "CUDA_VISIBLE_DEVICES" in os.environ:
|
||||
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
|
||||
else:
|
||||
cuda_visible_devices = "0"
|
||||
|
||||
print(cuda_visible_devices)
|
||||
15
transformers/utils/slow_documentation_tests.txt
Normal file
15
transformers/utils/slow_documentation_tests.txt
Normal file
@@ -0,0 +1,15 @@
|
||||
docs/source/en/generation_strategies.md
|
||||
docs/source/en/model_doc/code_llama.md
|
||||
docs/source/en/model_doc/ctrl.md
|
||||
docs/source/en/model_doc/kosmos-2.md
|
||||
docs/source/en/model_doc/seamless_m4t.md
|
||||
docs/source/en/model_doc/seamless_m4t_v2.md
|
||||
docs/source/en/tasks/prompting.md
|
||||
docs/source/ja/model_doc/code_llama.md
|
||||
src/transformers/models/blip_2/modeling_blip_2.py
|
||||
src/transformers/models/ctrl/modeling_ctrl.py
|
||||
src/transformers/models/fuyu/modeling_fuyu.py
|
||||
src/transformers/models/idefics2/modeling_idefics2.py
|
||||
src/transformers/models/kosmos2/modeling_kosmos2.py
|
||||
src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
|
||||
src/transformers/models/musicgen_melody/processing_musicgen_melody.py
|
||||
125
transformers/utils/sort_auto_mappings.py
Normal file
125
transformers/utils/sort_auto_mappings.py
Normal file
@@ -0,0 +1,125 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Utility that sorts the names in the auto mappings defines in the auto modules in alphabetical order.
|
||||
|
||||
Use from the root of the repo with:
|
||||
|
||||
```bash
|
||||
python utils/sort_auto_mappings.py
|
||||
```
|
||||
|
||||
to auto-fix all the auto mappings (used in `make style`).
|
||||
|
||||
To only check if the mappings are properly sorted (as used in `make quality`), do:
|
||||
|
||||
```bash
|
||||
python utils/sort_auto_mappings.py --check_only
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
|
||||
# Path are set with the intent you should run this script from the root of the repo.
PATH_TO_AUTO_MODULE = "src/transformers/models/auto"


# re pattern that matches mapping introductions:
# SUPER_MODEL_MAPPING_NAMES = OrderedDict or SUPER_MODEL_MAPPING = OrderedDict
_re_intro_mapping = re.compile(r"[A-Z_]+_MAPPING(\s+|_[A-Z_]+\s+)=\s+OrderedDict")
# re pattern that matches identifiers in mappings
_re_identifier = re.compile(r'\s*\(\s*"(\S[^"]+)"')


def sort_auto_mapping(fname: str, overwrite: bool = False) -> Optional[bool]:
    """
    Sort all auto mappings in a file.

    Args:
        fname (`str`): The name of the file where we want to sort auto-mappings.
        overwrite (`bool`, *optional*, defaults to `False`): Whether or not to fix and overwrite the file.

    Returns:
        `Optional[bool]`: Returns `None` if `overwrite=True`. Otherwise returns `True` if the file has an auto-mapping
        improperly sorted, `False` if the file is okay.
    """
    with open(fname, "r", encoding="utf-8") as f:
        original = f.read()

    source_lines = original.split("\n")
    rewritten = []
    idx = 0
    while idx < len(source_lines):
        if _re_intro_mapping.search(source_lines[idx]) is None:
            # Outside any mapping: copy the line through unchanged.
            rewritten.append(source_lines[idx])
            idx += 1
            continue

        # Start of a new mapping! Entries sit 8 spaces past the intro line's indentation.
        entry_indent = len(re.search(r"^(\s*)\S", source_lines[idx]).groups()[0]) + 8
        entry_prefix = " " * entry_indent
        # Copy the mapping header lines until the first entry.
        while not source_lines[idx].startswith(entry_prefix + "("):
            rewritten.append(source_lines[idx])
            idx += 1

        entries = []
        while source_lines[idx].strip() != "]":
            # Entries either fit on one line or span a parenthesized group of lines.
            if source_lines[idx].strip() == "(":
                first = idx
                while not source_lines[idx].startswith(entry_prefix + ")"):
                    idx += 1
                entries.append("\n".join(source_lines[first : idx + 1]))
            else:
                entries.append(source_lines[idx])
            idx += 1

        # Emit the entries in alphabetical order of their quoted identifiers. The closing
        # "]" line is copied by the next pass through the outer loop (idx still points at it).
        rewritten.extend(sorted(entries, key=lambda entry: _re_identifier.search(entry).groups()[0]))

    if overwrite:
        with open(fname, "w", encoding="utf-8") as f:
            f.write("\n".join(rewritten))
    else:
        return "\n".join(rewritten) != original
|
||||
|
||||
|
||||
def sort_all_auto_mappings(overwrite: bool = False):
    """
    Sort all auto mappings in the library.

    Args:
        overwrite (`bool`, *optional*, defaults to `False`): Whether or not to fix and overwrite the file.
    """
    module_files = [
        os.path.join(PATH_TO_AUTO_MODULE, f) for f in os.listdir(PATH_TO_AUTO_MODULE) if f.endswith(".py")
    ]
    needs_sorting = [sort_auto_mapping(module_file, overwrite=overwrite) for module_file in module_files]

    # In check-only mode, raise with the list of offending files so CI output is actionable.
    if not overwrite and any(needs_sorting):
        failures = [f for f, flagged in zip(module_files, needs_sorting) if flagged]
        raise ValueError(
            f"The following files have auto mappings that need sorting: {', '.join(failures)}. Run `make style` to fix"
            " this."
        )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--check_only", action="store_true", help="Whether to only check or fix style.")
    cli_args = parser.parse_args()

    # `--check_only` means "report, don't rewrite"; otherwise files are fixed in place.
    sort_all_auto_mappings(overwrite=not cli_args.check_only)
|
||||
98
transformers/utils/split_doctest_jobs.py
Normal file
98
transformers/utils/split_doctest_jobs.py
Normal file
@@ -0,0 +1,98 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script is used to get the files against which we will run doc testing.
|
||||
This uses `tests_fetcher.get_all_doctest_files` then groups the test files by their directory paths.
|
||||
|
||||
The files in `docs/source/en/model_doc` or `docs/source/en/tasks` are **NOT** grouped together with other files in the
|
||||
same directory: the objective is to run doctest against them in independent GitHub Actions jobs.
|
||||
|
||||
Assume we are under `transformers` root directory:
|
||||
To get a map (dictionary) between directory (or file) paths and the corresponding files
|
||||
```bash
|
||||
python utils/split_doctest_jobs.py
|
||||
```
|
||||
or to get a list of lists of directory (or file) paths
|
||||
```bash
|
||||
python utils/split_doctest_jobs.py --only_return_keys --num_splits 4
|
||||
```
|
||||
(this is used to allow GitHub Actions to generate more than 256 jobs using matrix)
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
from tests_fetcher import get_all_doctest_files
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--only_return_keys",
        action="store_true",
        help="if to only return the keys (which is a list of list of files' directory or file paths).",
    )
    parser.add_argument(
        "--num_splits",
        type=int,
        default=1,
        help="the number of splits into which the (flat) list of directory/file paths will be split. This has effect only if `only_return_keys` is `True`.",
    )
    args = parser.parse_args()

    all_doctest_files = get_all_doctest_files()

    # Group the doctest files by their parent directory.
    grouped_by_dir = defaultdict(list)
    for doc_file in all_doctest_files:
        parent_dir = "/".join(Path(doc_file).parents[0].parts)

        # not to run files in `src/` for now as it is completely broken at this moment. See issues/39159 and
        # https://github.com/huggingface/transformers/actions/runs/15988670157
        # TODO (ydshieh): fix the error, ideally before 2025/09
        if parent_dir.startswith("src/"):
            continue

        grouped_by_dir[parent_dir].append(doc_file)

    # Files under `model_doc` / `tasks` each get an independent job; every other
    # directory runs all of its files in a single job.
    refined = {}
    for parent_dir, doc_files in grouped_by_dir.items():
        if parent_dir in ["docs/source/en/model_doc", "docs/source/en/tasks"]:
            for doc_file in doc_files:
                refined[doc_file] = doc_file
        else:
            refined[parent_dir] = " ".join(sorted(doc_files))

    sorted_keys = sorted(refined)
    test_collection_map = {key: refined[key] for key in sorted_keys}

    # Split the keys into `num_splits` nearly-equal consecutive slices; the first
    # `num_jobs % num_splits` slices absorb the remainder.
    num_jobs = len(test_collection_map)
    base_size = num_jobs // args.num_splits

    file_directory_splits = []
    end = 0
    for idx in range(args.num_splits):
        start = end
        end = start + base_size + (1 if idx < num_jobs % args.num_splits else 0)
        file_directory_splits.append(sorted_keys[start:end])

    if args.only_return_keys:
        print(file_directory_splits)
    else:
        print(dict(test_collection_map))
|
||||
77
transformers/utils/split_model_tests.py
Normal file
77
transformers/utils/split_model_tests.py
Normal file
@@ -0,0 +1,77 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script is used to get the list of folders under `tests/models` and split the list into `NUM_SLICES` splits.
|
||||
The main use case is a GitHub Actions workflow file calling this script to get the (nested) list of folders allowing it
|
||||
to split the list of jobs to run into multiple slices each containing a smaller number of jobs. This way, we can bypass
|
||||
the maximum of 256 jobs in a matrix.
|
||||
|
||||
See the `setup` and `run_models_gpu` jobs defined in the workflow file `.github/workflows/self-scheduled.yml` for more
|
||||
details.
|
||||
|
||||
Usage:
|
||||
|
||||
This script is required to be run under `tests` folder of `transformers` root directory.
|
||||
|
||||
Assume we are under `transformers` root directory:
|
||||
```bash
|
||||
cd tests
|
||||
python ../utils/split_model_tests.py --num_splits 64
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import ast
|
||||
import os
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        type=str,
        default="",
        help="the list of pre-computed model names.",
    )
    parser.add_argument(
        "--num_splits",
        type=int,
        default=1,
        help="the number of splits into which the (flat) list of folders will be split.",
    )
    args = parser.parse_args()

    # Collect test folders relative to the current (tests) directory: per-model
    # folders first, then the remaining top-level folders.
    tests_dir = os.getcwd()
    model_tests = os.listdir(os.path.join(tests_dir, "models"))
    top_level_dirs = sorted(filter(os.path.isdir, os.listdir(tests_dir)))
    per_model_dirs = sorted(filter(os.path.isdir, [f"models/{x}" for x in model_tests]))
    top_level_dirs.remove("models")
    folders = per_model_dirs + top_level_dirs

    if args.models != "":
        # A pre-computed list of model names replaces the discovered folders entirely.
        requested = ast.literal_eval(args.models)
        folders = sorted(filter(os.path.isdir, [f"models/{x}" for x in requested]))

    # Split into `num_splits` nearly-equal consecutive slices; the first
    # `num_jobs % num_splits` slices absorb the remainder.
    num_jobs = len(folders)
    base_size = num_jobs // args.num_splits

    model_splits = []
    end = 0
    for idx in range(args.num_splits):
        start = end
        end = start + base_size + (1 if idx < num_jobs % args.num_splits else 0)
        model_splits.append(folders[start:end])

    print(model_splits)
|
||||
0
transformers/utils/test_module/__init__.py
Normal file
0
transformers/utils/test_module/__init__.py
Normal file
9
transformers/utils/test_module/custom_configuration.py
Normal file
9
transformers/utils/test_module/custom_configuration.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
|
||||
class CustomConfig(PretrainedConfig):
    """A minimal custom configuration used to exercise dynamic module loading in tests."""

    model_type = "custom"

    def __init__(self, attribute=1, **kwargs):
        # Store the custom attribute before delegating everything else to the base config.
        self.attribute = attribute
        super().__init__(**kwargs)
|
||||
@@ -0,0 +1,5 @@
|
||||
from transformers import Wav2Vec2FeatureExtractor
|
||||
|
||||
|
||||
class CustomFeatureExtractor(Wav2Vec2FeatureExtractor):
    """A custom feature extractor that simply reuses the Wav2Vec2 implementation unchanged."""
|
||||
@@ -0,0 +1,5 @@
|
||||
from transformers import CLIPImageProcessor
|
||||
|
||||
|
||||
class CustomImageProcessor(CLIPImageProcessor):
    """A custom image processor that simply reuses the CLIP implementation unchanged."""
|
||||
19
transformers/utils/test_module/custom_modeling.py
Normal file
19
transformers/utils/test_module/custom_modeling.py
Normal file
@@ -0,0 +1,19 @@
|
||||
import torch
|
||||
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
from .custom_configuration import CustomConfig
|
||||
|
||||
|
||||
class CustomModel(PreTrainedModel):
    """A tiny model (one linear layer) used to exercise dynamic module loading in tests."""

    config_class = CustomConfig

    def __init__(self, config):
        super().__init__(config)
        # A square layer keeps the input and output feature dimensions identical.
        self.linear = torch.nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, x):
        """Apply the linear projection to `x`."""
        return self.linear(x)

    def _init_weights(self, module):
        # Intentionally a no-op: weights keep their default initialization.
        pass
|
||||
33
transformers/utils/test_module/custom_pipeline.py
Normal file
33
transformers/utils/test_module/custom_pipeline.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import numpy as np
|
||||
|
||||
from transformers import Pipeline
|
||||
|
||||
|
||||
def softmax(outputs):
    """Numerically stable softmax over the last axis of `outputs`."""
    # Subtracting the row-wise maximum avoids overflow in `exp` without changing the result.
    shifted = np.exp(outputs - np.max(outputs, axis=-1, keepdims=True))
    return shifted / shifted.sum(axis=-1, keepdims=True)
|
||||
|
||||
|
||||
class PairClassificationPipeline(Pipeline):
    """A pipeline classifying a pair of texts, used to exercise dynamic pipeline loading."""

    def _sanitize_parameters(self, **kwargs):
        # Only `second_text` is routed to preprocessing; forward/postprocess take no options.
        preprocess_kwargs = {}
        if "second_text" in kwargs:
            preprocess_kwargs["second_text"] = kwargs["second_text"]
        return preprocess_kwargs, {}, {}

    def preprocess(self, text, second_text=None):
        # Encode the pair as a single tokenized input.
        return self.tokenizer(text, text_pair=second_text, return_tensors="pt")

    def _forward(self, model_inputs):
        return self.model(**model_inputs)

    def postprocess(self, model_outputs):
        logits = model_outputs.logits[0].numpy()
        probabilities = softmax(logits)

        # Map the highest-probability class back to its label and confidence.
        best_class = np.argmax(probabilities)
        label = self.model.config.id2label[best_class]
        score = probabilities[best_class].item()
        return {"label": label, "score": score, "logits": logits.tolist()}
|
||||
6
transformers/utils/test_module/custom_processing.py
Normal file
6
transformers/utils/test_module/custom_processing.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from transformers import ProcessorMixin
|
||||
|
||||
|
||||
class CustomProcessor(ProcessorMixin):
    """A custom processor pairing an auto feature extractor with an auto tokenizer."""

    feature_extractor_class = "AutoFeatureExtractor"
    tokenizer_class = "AutoTokenizer"
|
||||
5
transformers/utils/test_module/custom_tokenization.py
Normal file
5
transformers/utils/test_module/custom_tokenization.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from transformers import BertTokenizer
|
||||
|
||||
|
||||
class CustomTokenizer(BertTokenizer):
    """A custom slow tokenizer that simply reuses the BERT implementation unchanged."""
|
||||
@@ -0,0 +1,8 @@
|
||||
from transformers import BertTokenizerFast
|
||||
|
||||
from .custom_tokenization import CustomTokenizer
|
||||
|
||||
|
||||
class CustomTokenizerFast(BertTokenizerFast):
    """A custom fast tokenizer whose slow counterpart is `CustomTokenizer`."""

    # Links the fast tokenizer to its slow equivalent for save/convert round-trips.
    slow_tokenizer_class = CustomTokenizer
|
||||
@@ -0,0 +1,5 @@
|
||||
from transformers import LlavaOnevisionVideoProcessor
|
||||
|
||||
|
||||
class CustomVideoProcessor(LlavaOnevisionVideoProcessor):
    """A custom video processor that simply reuses the LLaVA-OneVision implementation unchanged."""
|
||||
1189
transformers/utils/tests_fetcher.py
Normal file
1189
transformers/utils/tests_fetcher.py
Normal file
File diff suppressed because it is too large
Load Diff
245
transformers/utils/tf_ops/onnx.json
Normal file
245
transformers/utils/tf_ops/onnx.json
Normal file
@@ -0,0 +1,245 @@
|
||||
{
|
||||
"opsets": {
|
||||
"1": [
|
||||
"Abs",
|
||||
"Add",
|
||||
"AddV2",
|
||||
"ArgMax",
|
||||
"ArgMin",
|
||||
"AvgPool",
|
||||
"AvgPool3D",
|
||||
"BatchMatMul",
|
||||
"BatchMatMulV2",
|
||||
"BatchToSpaceND",
|
||||
"BiasAdd",
|
||||
"BiasAddV1",
|
||||
"Cast",
|
||||
"Ceil",
|
||||
"CheckNumerics",
|
||||
"ComplexAbs",
|
||||
"Concat",
|
||||
"ConcatV2",
|
||||
"Const",
|
||||
"ConstV2",
|
||||
"Conv1D",
|
||||
"Conv2D",
|
||||
"Conv2DBackpropInput",
|
||||
"Conv3D",
|
||||
"Conv3DBackpropInputV2",
|
||||
"DepthToSpace",
|
||||
"DepthwiseConv2d",
|
||||
"DepthwiseConv2dNative",
|
||||
"Div",
|
||||
"Dropout",
|
||||
"Elu",
|
||||
"Equal",
|
||||
"Erf",
|
||||
"Exp",
|
||||
"ExpandDims",
|
||||
"Flatten",
|
||||
"Floor",
|
||||
"Gather",
|
||||
"GatherNd",
|
||||
"GatherV2",
|
||||
"Greater",
|
||||
"Identity",
|
||||
"IdentityN",
|
||||
"If",
|
||||
"LRN",
|
||||
"LSTMBlockCell",
|
||||
"LeakyRelu",
|
||||
"Less",
|
||||
"Log",
|
||||
"LogSoftmax",
|
||||
"LogicalAnd",
|
||||
"LogicalNot",
|
||||
"LogicalOr",
|
||||
"LookupTableSizeV2",
|
||||
"MatMul",
|
||||
"Max",
|
||||
"MaxPool",
|
||||
"MaxPool3D",
|
||||
"MaxPoolV2",
|
||||
"Maximum",
|
||||
"Mean",
|
||||
"Min",
|
||||
"Minimum",
|
||||
"MirrorPad",
|
||||
"Mul",
|
||||
"Neg",
|
||||
"NoOp",
|
||||
"NotEqual",
|
||||
"OneHot",
|
||||
"Pack",
|
||||
"Pad",
|
||||
"PadV2",
|
||||
"Placeholder",
|
||||
"PlaceholderV2",
|
||||
"PlaceholderWithDefault",
|
||||
"Pow",
|
||||
"Prod",
|
||||
"RFFT",
|
||||
"RandomNormal",
|
||||
"RandomNormalLike",
|
||||
"RandomUniform",
|
||||
"RandomUniformLike",
|
||||
"RealDiv",
|
||||
"Reciprocal",
|
||||
"Relu",
|
||||
"Relu6",
|
||||
"Reshape",
|
||||
"Rsqrt",
|
||||
"Selu",
|
||||
"Shape",
|
||||
"Sigmoid",
|
||||
"Sign",
|
||||
"Size",
|
||||
"Slice",
|
||||
"Softmax",
|
||||
"Softplus",
|
||||
"Softsign",
|
||||
"SpaceToBatchND",
|
||||
"SpaceToDepth",
|
||||
"Split",
|
||||
"SplitV",
|
||||
"Sqrt",
|
||||
"Square",
|
||||
"SquaredDifference",
|
||||
"Squeeze",
|
||||
"StatelessIf",
|
||||
"StopGradient",
|
||||
"StridedSlice",
|
||||
"StringJoin",
|
||||
"Sub",
|
||||
"Sum",
|
||||
"Tanh",
|
||||
"Tile",
|
||||
"TopKV2",
|
||||
"Transpose",
|
||||
"TruncateDiv",
|
||||
"Unpack",
|
||||
"ZerosLike"
|
||||
],
|
||||
"2": [],
|
||||
"3": [],
|
||||
"4": [],
|
||||
"5": [],
|
||||
"6": [
|
||||
"AddN",
|
||||
"All",
|
||||
"Any",
|
||||
"FloorDiv",
|
||||
"FusedBatchNorm",
|
||||
"FusedBatchNormV2",
|
||||
"FusedBatchNormV3"
|
||||
],
|
||||
"7": [
|
||||
"Acos",
|
||||
"Asin",
|
||||
"Atan",
|
||||
"Cos",
|
||||
"Fill",
|
||||
"FloorMod",
|
||||
"GreaterEqual",
|
||||
"LessEqual",
|
||||
"Loop",
|
||||
"MatrixBandPart",
|
||||
"Multinomial",
|
||||
"Range",
|
||||
"ResizeBilinear",
|
||||
"ResizeNearestNeighbor",
|
||||
"Scan",
|
||||
"Select",
|
||||
"SelectV2",
|
||||
"Sin",
|
||||
"SoftmaxCrossEntropyWithLogits",
|
||||
"SparseSoftmaxCrossEntropyWithLogits",
|
||||
"StatelessWhile",
|
||||
"Tan",
|
||||
"TensorListFromTensor",
|
||||
"TensorListGetItem",
|
||||
"TensorListLength",
|
||||
"TensorListReserve",
|
||||
"TensorListResize",
|
||||
"TensorListSetItem",
|
||||
"TensorListStack",
|
||||
"While"
|
||||
],
|
||||
"8": [
|
||||
"BroadcastTo",
|
||||
"ClipByValue",
|
||||
"FIFOQueueV2",
|
||||
"HashTableV2",
|
||||
"IteratorGetNext",
|
||||
"IteratorV2",
|
||||
"LookupTableFindV2",
|
||||
"MaxPoolWithArgmax",
|
||||
"QueueDequeueManyV2",
|
||||
"QueueDequeueUpToV2",
|
||||
"QueueDequeueV2",
|
||||
"ReverseSequence"
|
||||
],
|
||||
"9": [
|
||||
"SegmentMax",
|
||||
"SegmentMean",
|
||||
"SegmentMin",
|
||||
"SegmentProd",
|
||||
"SegmentSum",
|
||||
"Sinh",
|
||||
"SparseSegmentMean",
|
||||
"SparseSegmentMeanWithNumSegments",
|
||||
"SparseSegmentSqrtN",
|
||||
"SparseSegmentSqrtNWithNumSegments",
|
||||
"SparseSegmentSum",
|
||||
"SparseSegmentSumWithNumSegments",
|
||||
"UnsortedSegmentMax",
|
||||
"UnsortedSegmentMin",
|
||||
"UnsortedSegmentProd",
|
||||
"UnsortedSegmentSum",
|
||||
"Where"
|
||||
],
|
||||
"10": [
|
||||
"CropAndResize",
|
||||
"CudnnRNN",
|
||||
"DynamicStitch",
|
||||
"FakeQuantWithMinMaxArgs",
|
||||
"IsFinite",
|
||||
"IsInf",
|
||||
"NonMaxSuppressionV2",
|
||||
"NonMaxSuppressionV3",
|
||||
"NonMaxSuppressionV4",
|
||||
"NonMaxSuppressionV5",
|
||||
"ParallelDynamicStitch",
|
||||
"ReverseV2",
|
||||
"Roll"
|
||||
],
|
||||
"11": [
|
||||
"Bincount",
|
||||
"Cumsum",
|
||||
"InvertPermutation",
|
||||
"LeftShift",
|
||||
"MatrixDeterminant",
|
||||
"MatrixDiagPart",
|
||||
"MatrixDiagPartV2",
|
||||
"MatrixDiagPartV3",
|
||||
"RaggedRange",
|
||||
"RightShift",
|
||||
"Round",
|
||||
"ScatterNd",
|
||||
"SparseFillEmptyRows",
|
||||
"SparseReshape",
|
||||
"SparseToDense",
|
||||
"TensorScatterUpdate",
|
||||
"Unique"
|
||||
],
|
||||
"12": [
|
||||
"Einsum",
|
||||
"MatrixDiag",
|
||||
"MatrixDiagV2",
|
||||
"MatrixDiagV3",
|
||||
"MatrixSetDiagV3",
|
||||
"SquaredDistance"
|
||||
],
|
||||
"13": []
|
||||
}
|
||||
}
|
||||
349
transformers/utils/update_metadata.py
Executable file
349
transformers/utils/update_metadata.py
Executable file
@@ -0,0 +1,349 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Utility that updates the metadata of the Transformers library in the repository `huggingface/transformers-metadata`.
|
||||
|
||||
Usage for an update (as used by the GitHub action `update_metadata`):
|
||||
|
||||
```bash
|
||||
python utils/update_metadata.py --token <token> --commit_sha <commit_sha>
|
||||
```
|
||||
|
||||
Usage to check all pipelines are properly defined in the constant `PIPELINE_TAGS_AND_AUTO_MODELS` of this script, so
|
||||
that new pipelines are properly added as metadata (as used in `make repo-consistency`):
|
||||
|
||||
```bash
|
||||
python utils/update_metadata.py --check-only
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import collections
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
|
||||
import pandas as pd
|
||||
from datasets import Dataset
|
||||
from huggingface_hub import hf_hub_download, upload_folder
|
||||
|
||||
from transformers.utils import direct_transformers_import
|
||||
|
||||
|
||||
# All paths are set with the intent you should run this script from the root of the repo with the command
|
||||
# python utils/update_metadata.py
|
||||
TRANSFORMERS_PATH = "src/transformers"
|
||||
|
||||
|
||||
# This is to make sure the transformers module imported is the one in the repo.
|
||||
transformers_module = direct_transformers_import(TRANSFORMERS_PATH)
|
||||
|
||||
|
||||
# Regexes that match model names
|
||||
_re_pt_models = re.compile(r"(.*)(?:Model|Encoder|Decoder|ForConditionalGeneration|ForRetrieval)")
|
||||
|
||||
|
||||
# Fill this with tuples (pipeline_tag, model_mapping, auto_model)
|
||||
PIPELINE_TAGS_AND_AUTO_MODELS = [
|
||||
("pretraining", "MODEL_FOR_PRETRAINING_MAPPING_NAMES", "AutoModelForPreTraining"),
|
||||
("feature-extraction", "MODEL_MAPPING_NAMES", "AutoModel"),
|
||||
("image-feature-extraction", "MODEL_FOR_IMAGE_MAPPING_NAMES", "AutoModel"),
|
||||
("audio-classification", "MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES", "AutoModelForAudioClassification"),
|
||||
("text-generation", "MODEL_FOR_CAUSAL_LM_MAPPING_NAMES", "AutoModelForCausalLM"),
|
||||
("automatic-speech-recognition", "MODEL_FOR_CTC_MAPPING_NAMES", "AutoModelForCTC"),
|
||||
("image-classification", "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES", "AutoModelForImageClassification"),
|
||||
("image-segmentation", "MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES", "AutoModelForImageSegmentation"),
|
||||
("image-text-to-text", "MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES", "AutoModelForImageTextToText"),
|
||||
("image-to-image", "MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES", "AutoModelForImageToImage"),
|
||||
("fill-mask", "MODEL_FOR_MASKED_LM_MAPPING_NAMES", "AutoModelForMaskedLM"),
|
||||
("object-detection", "MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES", "AutoModelForObjectDetection"),
|
||||
(
|
||||
"zero-shot-object-detection",
|
||||
"MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES",
|
||||
"AutoModelForZeroShotObjectDetection",
|
||||
),
|
||||
("question-answering", "MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES", "AutoModelForQuestionAnswering"),
|
||||
("text2text-generation", "MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES", "AutoModelForSeq2SeqLM"),
|
||||
("text-classification", "MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES", "AutoModelForSequenceClassification"),
|
||||
("automatic-speech-recognition", "MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES", "AutoModelForSpeechSeq2Seq"),
|
||||
(
|
||||
"table-question-answering",
|
||||
"MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES",
|
||||
"AutoModelForTableQuestionAnswering",
|
||||
),
|
||||
("token-classification", "MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES", "AutoModelForTokenClassification"),
|
||||
("multiple-choice", "MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES", "AutoModelForMultipleChoice"),
|
||||
(
|
||||
"next-sentence-prediction",
|
||||
"MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES",
|
||||
"AutoModelForNextSentencePrediction",
|
||||
),
|
||||
(
|
||||
"audio-frame-classification",
|
||||
"MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING_NAMES",
|
||||
"AutoModelForAudioFrameClassification",
|
||||
),
|
||||
("audio-xvector", "MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES", "AutoModelForAudioXVector"),
|
||||
(
|
||||
"document-question-answering",
|
||||
"MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES",
|
||||
"AutoModelForDocumentQuestionAnswering",
|
||||
),
|
||||
(
|
||||
"visual-question-answering",
|
||||
"MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES",
|
||||
"AutoModelForVisualQuestionAnswering",
|
||||
),
|
||||
("image-to-text", "MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES", "AutoModelForVision2Seq"),
|
||||
(
|
||||
"zero-shot-image-classification",
|
||||
"MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES",
|
||||
"AutoModelForZeroShotImageClassification",
|
||||
),
|
||||
("depth-estimation", "MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES", "AutoModelForDepthEstimation"),
|
||||
("video-classification", "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING_NAMES", "AutoModelForVideoClassification"),
|
||||
("mask-generation", "MODEL_FOR_MASK_GENERATION_MAPPING_NAMES", "AutoModelForMaskGeneration"),
|
||||
("text-to-audio", "MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES", "AutoModelForTextToSpectrogram"),
|
||||
("text-to-audio", "MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES", "AutoModelForTextToWaveform"),
|
||||
("keypoint-matching", "MODEL_FOR_KEYPOINT_MATCHING_MAPPING_NAMES", "AutoModelForKeypointMatching"),
|
||||
]
|
||||
|
||||
|
||||
def camel_case_split(identifier: str) -> list[str]:
    """
    Split a camel-cased name into words.

    Args:
        identifier (`str`): The camel-cased name to parse.

    Returns:
        `List[str]`: The list of words in the identifier (as separated by capital letters).

    Example:

    ```py
    >>> camel_case_split("CamelCasedClass")
    ["Camel", "Cased", "Class"]
    ```
    """
    # Regex thanks to https://stackoverflow.com/questions/29916065/how-to-do-camelcase-split-in-python
    # A word ends before a lower→upper transition, before the last capital of an
    # acronym run (upper followed by upper+lower), or at the end of the string.
    word_pattern = ".+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)"
    return [match.group(0) for match in re.finditer(word_pattern, identifier)]
|
||||
|
||||
|
||||
def get_frameworks_table() -> pd.DataFrame:
    """
    Generates a dataframe containing the supported auto classes for each model type, using the content of the auto
    modules.
    """
    # Dictionary model names to config.
    config_mapping_names = transformers_module.models.auto.configuration_auto.CONFIG_MAPPING_NAMES
    model_prefix_to_model_type = {
        config.replace("Config", ""): model_type for model_type, config in config_mapping_names.items()
    }

    pt_models = collections.defaultdict(bool)

    # Let's lookup through all transformers object (once) and find if models are supported by a given backend.
    for attr_name in dir(transformers_module):
        pt_match = _re_pt_models.match(attr_name)
        if pt_match is None:
            continue
        # Strip the class-name suffix (Model, Encoder, ...), then peel words off the end
        # until the remaining prefix matches a known model type.
        candidate = pt_match.groups()[0]
        while len(candidate) > 0:
            if candidate in model_prefix_to_model_type:
                pt_models[model_prefix_to_model_type[candidate]] = True
                break
            candidate = "".join(camel_case_split(candidate)[:-1])

    all_models = sorted(pt_models.keys())

    data = {"model_type": all_models}
    data["pytorch"] = [pt_models[t] for t in all_models]

    # Now let's find the right processing class for each model. In order we check if there is a Processor, then a
    # Tokenizer, then a FeatureExtractor, then an ImageProcessor
    auto_module = transformers_module.models.auto
    processors = {}
    for model_type in all_models:
        if model_type in auto_module.processing_auto.PROCESSOR_MAPPING_NAMES:
            processors[model_type] = "AutoProcessor"
        elif model_type in auto_module.tokenization_auto.TOKENIZER_MAPPING_NAMES:
            processors[model_type] = "AutoTokenizer"
        elif model_type in auto_module.image_processing_auto.IMAGE_PROCESSOR_MAPPING_NAMES:
            processors[model_type] = "AutoImageProcessor"
        elif model_type in auto_module.feature_extraction_auto.FEATURE_EXTRACTOR_MAPPING_NAMES:
            processors[model_type] = "AutoFeatureExtractor"
        else:
            # Default to AutoTokenizer if a model has nothing, for backward compatibility.
            processors[model_type] = "AutoTokenizer"

    data["processor"] = [processors[t] for t in all_models]

    return pd.DataFrame(data)
|
||||
|
||||
|
||||
def update_pipeline_and_auto_class_table(table: dict[str, tuple[str, str]]) -> dict[str, tuple[str, str]]:
    """
    Update the table mapping models to pipelines and auto classes without removing old keys if they don't exist
    anymore.

    Args:
        table (`Dict[str, Tuple[str, str]]`):
            The existing table mapping model names to a tuple containing the pipeline tag and the auto-class name
            with which they should be used.

    Returns:
        `Dict[str, Tuple[str, str]]`: The updated table in the same format.
    """
    auto_module = transformers_module.models.auto.modeling_auto
    for tag, mapping_name, auto_class in PIPELINE_TAGS_AND_AUTO_MODELS:
        # Skip mappings that do not exist in the current version of the library.
        if not hasattr(auto_module, mapping_name):
            continue
        # Gather every model class name in the mapping; values may be a single name or a collection of names.
        names = []
        for value in getattr(auto_module, mapping_name).values():
            if isinstance(value, str):
                names.append(value)
            else:
                names.extend(list(value))

        # Register the pipeline tag and auto model class for all of those models.
        table.update(dict.fromkeys(names, (tag, auto_class)))

    return table
|
||||
|
||||
|
||||
def update_metadata(token: str, commit_sha: str):
    """
    Update the metadata for the Transformers repo in `huggingface/transformers-metadata`.

    Builds the frameworks table and the pipeline-tags table, compares both to the copies currently hosted on the
    Hub, and uploads them only when at least one of the two files changed.

    Args:
        token (`str`): A valid token giving write access to `huggingface/transformers-metadata`.
        commit_sha (`str`): The commit SHA on Transformers corresponding to this update.
    """
    # Frameworks table built from the current state of the library.
    frameworks_table = get_frameworks_table()
    frameworks_dataset = Dataset.from_pandas(frameworks_table)

    # Start from the pipeline-tags file currently on the Hub, so entries for removed models are preserved.
    resolved_tags_file = hf_hub_download(
        "huggingface/transformers-metadata", "pipeline_tags.json", repo_type="dataset", token=token
    )
    tags_dataset = Dataset.from_json(resolved_tags_file)
    # Re-key the dataset rows as model_class -> (pipeline_tag, auto_class).
    table = {
        tags_dataset[i]["model_class"]: (tags_dataset[i]["pipeline_tag"], tags_dataset[i]["auto_class"])
        for i in range(len(tags_dataset))
    }
    table = update_pipeline_and_auto_class_table(table)

    # Sort the model classes to avoid some nondeterministic updates to create false update commits.
    model_classes = sorted(table.keys())
    tags_table = pd.DataFrame(
        {
            "model_class": model_classes,
            "pipeline_tag": [table[m][0] for m in model_classes],
            "auto_class": [table[m][1] for m in model_classes],
        }
    )
    tags_dataset = Dataset.from_pandas(tags_table)

    # Fetch the raw contents of both metadata files on the Hub so we can detect a no-op update.
    hub_frameworks_json = hf_hub_download(
        repo_id="huggingface/transformers-metadata",
        filename="frameworks.json",
        repo_type="dataset",
        token=token,
    )
    with open(hub_frameworks_json) as f:
        hub_frameworks_json = f.read()

    hub_pipeline_tags_json = hf_hub_download(
        repo_id="huggingface/transformers-metadata",
        filename="pipeline_tags.json",
        repo_type="dataset",
        token=token,
    )
    with open(hub_pipeline_tags_json) as f:
        hub_pipeline_tags_json = f.read()

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Serialize the new tables, then read them back so the comparison with the Hub copies is done on the
        # exact bytes that would be uploaded.
        frameworks_dataset.to_json(os.path.join(tmp_dir, "frameworks.json"))
        tags_dataset.to_json(os.path.join(tmp_dir, "pipeline_tags.json"))

        with open(os.path.join(tmp_dir, "frameworks.json")) as f:
            frameworks_json = f.read()
        with open(os.path.join(tmp_dir, "pipeline_tags.json")) as f:
            pipeline_tags_json = f.read()

        frameworks_equal = hub_frameworks_json == frameworks_json
        hub_pipeline_tags_equal = hub_pipeline_tags_json == pipeline_tags_json

        # Skip the upload entirely when nothing changed, to avoid empty commits on the dataset repo.
        if frameworks_equal and hub_pipeline_tags_equal:
            print("No updates on the Hub, not pushing the metadata files.")
            return

        # Link the metadata commit back to the Transformers commit that triggered it, when known.
        if commit_sha is not None:
            commit_message = (
                f"Update with commit {commit_sha}\n\nSee: "
                f"https://github.com/huggingface/transformers/commit/{commit_sha}"
            )
        else:
            commit_message = "Update"

        upload_folder(
            repo_id="huggingface/transformers-metadata",
            folder_path=tmp_dir,
            repo_type="dataset",
            token=token,
            commit_message=commit_message,
        )
|
||||
|
||||
|
||||
def check_pipeline_tags():
    """
    Check all pipeline tags are properly defined in the `PIPELINE_TAGS_AND_AUTO_MODELS` constant of this script.
    """
    known_classes = {tag: cls for tag, _, cls in PIPELINE_TAGS_AND_AUTO_MODELS}
    supported_tasks = transformers_module.pipelines.SUPPORTED_TASKS
    missing = []
    for task_name in supported_tasks:
        # A task already listed by its tag is fine.
        if task_name in known_classes:
            continue
        # Otherwise, check whether its default PyTorch model class is registered under some other tag.
        default_model = supported_tasks[task_name]["pt"]
        if isinstance(default_model, (list, tuple)):
            default_model = default_model[0]
        default_model = default_model.__name__
        if default_model not in known_classes.values():
            missing.append(task_name)

    if len(missing) > 0:
        msg = ", ".join(missing)
        raise ValueError(
            "The following pipeline tags are not present in the `PIPELINE_TAGS_AND_AUTO_MODELS` constant inside "
            f"`utils/update_metadata.py`: {msg}. Please add them!"
        )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: either verify the pipeline-tag table, or push updated metadata to the Hub.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--token", type=str, help="The token to use to push to the transformers-metadata dataset.")
    arg_parser.add_argument("--commit_sha", type=str, help="The sha of the commit going with this update.")
    arg_parser.add_argument("--check-only", action="store_true", help="Activate to just check all pipelines are present.")
    parsed_args = arg_parser.parse_args()

    if not parsed_args.check_only:
        update_metadata(parsed_args.token, parsed_args.commit_sha)
    else:
        check_pipeline_tags()
|
||||
171
transformers/utils/update_tiny_models.py
Normal file
171
transformers/utils/update_tiny_models.py
Normal file
@@ -0,0 +1,171 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""A script running `create_dummy_models.py` with a pre-defined set of arguments.
|
||||
|
||||
This file is intended to be used in a CI workflow file without the need of specifying arguments. It creates and uploads
|
||||
tiny models for all model classes (if their tiny versions are not on the Hub yet), as well as produces an updated
|
||||
version of `tests/utils/tiny_model_summary.json`. That updated file should be merged into the `main` branch of
|
||||
`transformers` so the pipeline testing will use the latest created/updated tiny models.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import multiprocessing
|
||||
import os
|
||||
import time
|
||||
|
||||
from create_dummy_models import COMPOSITE_MODELS, create_tiny_models
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
import transformers
|
||||
from transformers import AutoFeatureExtractor, AutoImageProcessor, AutoTokenizer
|
||||
from transformers.image_processing_utils import BaseImageProcessor
|
||||
|
||||
|
||||
def get_all_model_names():
    """Collect the names of all model classes registered in the auto modeling mappings, sorted alphabetically."""
    collected = set()

    auto_module = getattr(transformers.models.auto, "modeling_auto", None)
    if auto_module is not None:
        # all mappings in a single auto modeling file
        for attr in dir(auto_module):
            if not (attr.startswith("MODEL_") and attr.endswith("_MAPPING_NAMES")):
                continue
            mapping = getattr(auto_module, attr)
            if mapping is None:
                continue
            # Mapping values are either a single class name or a collection of class names.
            for entry in mapping.values():
                if isinstance(entry, str):
                    collected.add(entry)
                elif isinstance(entry, (list, tuple)):
                    collected.update(entry)

    return sorted(collected)
|
||||
|
||||
|
||||
def get_tiny_model_names_from_repo():
    """Return the sorted model class names already covered by `tests/utils/tiny_model_summary.json`."""
    with open("tests/utils/tiny_model_summary.json") as fp:
        summary = json.load(fp)

    covered = set()
    for info in summary.values():
        covered.update(info["model_classes"])

    return sorted(covered)
|
||||
|
||||
|
||||
def get_tiny_model_summary_from_hub(output_path):
    """
    Build a summary of the tiny models hosted under the `hf-internal-testing` organization on the Hub and dump it
    to `<output_path>/hub_tiny_model_summary.json`.

    For each repo named `hf-internal-testing/tiny-random-<Model>`, records which tokenizer, processor and model
    classes can actually be loaded from it, together with the repo's current commit sha.

    Args:
        output_path: Directory in which to write `hub_tiny_model_summary.json`.
            NOTE(review): assumed to already exist — confirm callers create it.
    """
    api = HfApi()
    # NOTE(review): presumably names of composite/encoder-decoder tiny models — confirm against COMPOSITE_MODELS.
    special_models = COMPOSITE_MODELS.values()

    # All tiny model base names on Hub
    model_names = get_all_model_names()
    models = api.list_models(author="hf-internal-testing")
    _models = set()
    for x in models:
        model = x.id
        org, model = model.split("/")
        if not model.startswith("tiny-random-"):
            continue
        model = model.replace("tiny-random-", "")
        # Keep only names that look like model class names (leading uppercase letter).
        if not model[0].isupper():
            continue
        # Drop repos whose name matches neither a known model class nor a composite model.
        if model not in model_names and model not in special_models:
            continue
        _models.add(model)

    models = sorted(_models)
    # All tiny model names on Hub
    summary = {}
    for model in models:
        repo_id = f"hf-internal-testing/tiny-random-{model}"
        # NOTE(review): hyphenated names are keyed by their first component — confirm this matches COMPOSITE_MODELS.
        model = model.split("-")[0]
        try:
            repo_info = api.repo_info(repo_id)
            content = {
                "tokenizer_classes": set(),
                "processor_classes": set(),
                "model_classes": set(),
                "sha": repo_info.sha,
            }
        except Exception:
            # Repo unreachable: skip it entirely.
            continue
        # Each loading attempt below is best-effort: a failure just means that class is absent from the summary.
        # The sleeps space out the Hub calls.
        try:
            time.sleep(1)
            tokenizer_fast = AutoTokenizer.from_pretrained(repo_id)
            content["tokenizer_classes"].add(tokenizer_fast.__class__.__name__)
        except Exception:
            pass
        try:
            time.sleep(1)
            tokenizer_slow = AutoTokenizer.from_pretrained(repo_id, use_fast=False)
            content["tokenizer_classes"].add(tokenizer_slow.__class__.__name__)
        except Exception:
            pass
        try:
            time.sleep(1)
            img_p = AutoImageProcessor.from_pretrained(repo_id)
            content["processor_classes"].add(img_p.__class__.__name__)
        except Exception:
            pass
        try:
            time.sleep(1)
            feat_p = AutoFeatureExtractor.from_pretrained(repo_id)
            # Image processors can also load via AutoFeatureExtractor; only keep genuine feature extractors here.
            if not isinstance(feat_p, BaseImageProcessor):
                content["processor_classes"].add(feat_p.__class__.__name__)
        except Exception:
            pass
        try:
            time.sleep(1)
            model_class = getattr(transformers, model)
            m = model_class.from_pretrained(repo_id)
            content["model_classes"].add(m.__class__.__name__)
        except Exception:
            pass

        # JSON cannot serialize sets: convert to sorted lists for stable, diff-friendly output.
        content["tokenizer_classes"] = sorted(content["tokenizer_classes"])
        content["processor_classes"] = sorted(content["processor_classes"])
        content["model_classes"] = sorted(content["model_classes"])

        summary[model] = content
    with open(os.path.join(output_path, "hub_tiny_model_summary.json"), "w") as fp:
        json.dump(summary, fp, ensure_ascii=False, indent=4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_workers", default=1, type=int, help="The number of workers to run.")
    args = parser.parse_args()

    # This has to be `spawn` to avoid hanging forever!
    multiprocessing.set_start_method("spawn")

    # Fixed arguments for `create_tiny_models`: build every model type, skip the models that already have a tiny
    # checkpoint on the Hub, don't run the extra checks, and upload the results to `hf-internal-testing`.
    output_path = "tiny_models"
    build_all = True  # renamed from `all`, which shadowed the builtin
    model_types = None
    models_to_skip = get_tiny_model_names_from_repo()
    no_check = True
    upload = True
    organization = "hf-internal-testing"

    create_tiny_models(
        output_path,
        build_all,
        model_types,
        models_to_skip,
        no_check,
        upload,
        organization,
        token=os.environ.get("TOKEN", None),
        num_workers=args.num_workers,
    )
|
||||
Reference in New Issue
Block a user