#!/usr/bin/env python
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# /// script
# dependencies = [
#     "transformers @ git+https://github.com/huggingface/transformers.git",
#     "torch>=1.5.0",
#     "torchvision>=0.6.0",
#     "datasets>=1.8.0",
# ]
# ///

import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Optional

import numpy as np
import torch
from datasets import load_dataset
from torchvision.transforms import Compose, Lambda, Normalize, RandomHorizontalFlip, RandomResizedCrop, ToTensor

import transformers
from transformers import (
    CONFIG_MAPPING,
    IMAGE_PROCESSOR_MAPPING,
    MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
    AutoConfig,
    AutoImageProcessor,
    AutoModelForMaskedImageModeling,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version


""" Pre-training a 🤗 Transformers model for simple masked image modeling (SimMIM).
Any model supported by the AutoModelForMaskedImageModeling API can be used.
"""

logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.57.0.dev0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    Using `HfArgumentParser` we can turn this class into argparse arguments to be able to specify
    them on the command line.

    After initialization, `data_files` holds a mapping from split name to the folder given on the
    command line (`train_dir` -> "train", `validation_dir` -> "validation"), or `None` when neither
    folder was provided.
    """

    dataset_name: Optional[str] = field(
        default="cifar10", metadata={"help": "Name of a dataset from the datasets package"}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    image_column_name: Optional[str] = field(
        default=None,
        metadata={"help": "The column name of the images in the files. If not set, will try to use 'image' or 'img'."},
    )
    train_dir: Optional[str] = field(default=None, metadata={"help": "A folder containing the training data."})
    validation_dir: Optional[str] = field(default=None, metadata={"help": "A folder containing the validation data."})
    train_val_split: Optional[float] = field(
        default=0.15, metadata={"help": "Percent to split off of train for validation."}
    )
    mask_patch_size: int = field(default=32, metadata={"help": "The size of the square patches to use for masking."})
    mask_ratio: float = field(
        default=0.6,
        metadata={"help": "Percentage of patches to mask."},
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )

    def __post_init__(self):
        """Collect the user-provided data folders into `self.data_files` (or `None` if there are none)."""
        data_files = {}
        if self.train_dir is not None:
            data_files["train"] = self.train_dir
        if self.validation_dir is not None:
            # The split must be called "validation": that is the name `main()` looks up when deciding
            # whether to carve a validation set out of train and when selecting the eval dataset.
            data_files["validation"] = self.validation_dir
        self.data_files = data_files if data_files else None


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/image processor we are going to pre-train.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Can be a local path to a pytorch_model.bin or a "
                "checkpoint identifier on the hub. "
                "Don't set if you want to train a model from scratch."
            )
        },
    )
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
    )
    config_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    config_overrides: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override some existing default config settings when a model is trained from scratch. Example: "
                "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
            )
        },
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store (cache) the pretrained models/datasets downloaded from the hub"},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    image_processor_name: Optional[str] = field(
        default=None, metadata={"help": "Name or path of preprocessor config."}
    )
    token: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `hf auth login` (stored in `~/.huggingface`)."
            )
        },
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to trust the execution of code from datasets/models defined on the Hub."
                " This option should only be set to `True` for repositories you trust and in which you have read the"
                " code, as it will execute code present on the Hub on your local machine."
            )
        },
    )
    image_size: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "The size (resolution) of each image. If not specified, will use `image_size` of the configuration."
            )
        },
    )
    patch_size: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "The size (resolution) of each patch. If not specified, will use `patch_size` of the configuration."
            )
        },
    )
    encoder_stride: Optional[int] = field(
        default=None,
        metadata={"help": "Stride to use for the encoder."},
    )


class MaskGenerator:
    """
    A class to generate masks for the pretraining task.

    A mask is a 1D integer tensor of shape ((input_size // model_patch_size)**2,) where the value is
    either 0 or 1, where 1 indicates "masked". Masking is decided on a coarse grid of
    `mask_patch_size` x `mask_patch_size` pixel squares and then expanded to the resolution of the
    model patches, so all model patches inside one mask square share the same value.

    Args:
        input_size: Side length of the (square) input image, in pixels.
        mask_patch_size: Side length of one mask square, in pixels. Must divide `input_size`.
        model_patch_size: Side length of one model patch, in pixels. Must divide `mask_patch_size`.
        mask_ratio: Fraction of mask squares to mask; the resulting count is rounded up.

    Raises:
        ValueError: If the divisibility requirements above are not met.
    """

    def __init__(self, input_size=192, mask_patch_size=32, model_patch_size=4, mask_ratio=0.6):
        self.input_size = input_size
        self.mask_patch_size = mask_patch_size
        self.model_patch_size = model_patch_size
        self.mask_ratio = mask_ratio

        if self.input_size % self.mask_patch_size != 0:
            raise ValueError("Input size must be divisible by mask patch size")
        if self.mask_patch_size % self.model_patch_size != 0:
            raise ValueError("Mask patch size must be divisible by model patch size")

        # Number of mask squares per image side, and number of model patches per mask-square side.
        self.rand_size = self.input_size // self.mask_patch_size
        self.scale = self.mask_patch_size // self.model_patch_size

        self.token_count = self.rand_size**2
        self.mask_count = int(np.ceil(self.token_count * self.mask_ratio))

    def __call__(self):
        """Draw a fresh random mask and return it as a flattened tensor."""
        # Pick `mask_count` distinct mask squares uniformly at random.
        mask_idx = np.random.permutation(self.token_count)[: self.mask_count]
        mask = np.zeros(self.token_count, dtype=int)
        mask[mask_idx] = 1

        # Expand every mask square to the `scale` x `scale` model patches it covers.
        mask = mask.reshape((self.rand_size, self.rand_size))
        mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1)

        return torch.tensor(mask.flatten())


def collate_fn(examples):
    """Stack per-example `pixel_values` and `mask` tensors into the batch dict passed to the model.

    The stacked masks are returned under the key `bool_masked_pos`.
    """
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    mask = torch.stack([example["mask"] for example in examples])
    return {"pixel_values": pixel_values, "bool_masked_pos": mask}


def main():
    """Parse the arguments, build dataset/config/model, then train and/or evaluate with `Trainer`."""
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_mim", model_args, data_args)

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Initialize our dataset.
    ds = load_dataset(
        data_args.dataset_name,
        data_args.dataset_config_name,
        data_files=data_args.data_files,
        cache_dir=model_args.cache_dir,
        token=model_args.token,
        trust_remote_code=model_args.trust_remote_code,
    )

    # If we don't have a validation split, split off a percentage of train as validation.
    data_args.train_val_split = None if "validation" in ds else data_args.train_val_split
    if isinstance(data_args.train_val_split, float) and data_args.train_val_split > 0.0:
        split = ds["train"].train_test_split(data_args.train_val_split)
        ds["train"] = split["train"]
        ds["validation"] = split["test"]

    # Create config
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    config_kwargs = {
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "token": model_args.token,
        "trust_remote_code": model_args.trust_remote_code,
    }
    if model_args.config_name_or_path:
        config = AutoConfig.from_pretrained(model_args.config_name_or_path, **config_kwargs)
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
    else:
        if model_args.model_type is None:
            # Fail with an actionable message instead of an opaque `KeyError: None` from the mapping.
            raise ValueError(
                "One of --model_name_or_path, --config_name_or_path or --model_type must be provided. "
                "To train from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)
            )
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")
        if model_args.config_overrides is not None:
            logger.info(f"Overriding config: {model_args.config_overrides}")
            config.update_from_string(model_args.config_overrides)
            logger.info(f"New config: {config}")

    # make sure the decoder_type is "simmim" (only relevant for BEiT)
    if hasattr(config, "decoder_type"):
        config.decoder_type = "simmim"

    # adapt config
    model_args.image_size = model_args.image_size if model_args.image_size is not None else config.image_size
    model_args.patch_size = model_args.patch_size if model_args.patch_size is not None else config.patch_size
    model_args.encoder_stride = (
        model_args.encoder_stride if model_args.encoder_stride is not None else config.encoder_stride
    )

    config.update(
        {
            "image_size": model_args.image_size,
            "patch_size": model_args.patch_size,
            "encoder_stride": model_args.encoder_stride,
        }
    )

    # create image processor
    if model_args.image_processor_name:
        image_processor = AutoImageProcessor.from_pretrained(model_args.image_processor_name, **config_kwargs)
    elif model_args.model_name_or_path:
        image_processor = AutoImageProcessor.from_pretrained(model_args.model_name_or_path, **config_kwargs)
    else:
        IMAGE_PROCESSOR_TYPES = {
            conf.model_type: image_processor_class for conf, image_processor_class in IMAGE_PROCESSOR_MAPPING.items()
        }
        # `--model_type` may be unset when only `--config_name_or_path` was given; the loaded config
        # carries the model type in that case.
        image_processor_model_type = model_args.model_type or config.model_type
        image_processor = IMAGE_PROCESSOR_TYPES[image_processor_model_type][-1]()

    # create model
    if model_args.model_name_or_path:
        model = AutoModelForMaskedImageModeling.from_pretrained(
            model_args.model_name_or_path,
            from_tf=".ckpt" in model_args.model_name_or_path,
            config=config,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            token=model_args.token,
            trust_remote_code=model_args.trust_remote_code,
        )
    else:
        logger.info("Training new model from scratch")
        model = AutoModelForMaskedImageModeling.from_config(config, trust_remote_code=model_args.trust_remote_code)

    if training_args.do_train:
        column_names = ds["train"].column_names
    else:
        column_names = ds["validation"].column_names

    if data_args.image_column_name is not None:
        image_column_name = data_args.image_column_name
    elif "image" in column_names:
        image_column_name = "image"
    elif "img" in column_names:
        image_column_name = "img"
    else:
        image_column_name = column_names[0]

    # transformations as done in original SimMIM paper
    # source: https://github.com/microsoft/SimMIM/blob/main/data/data_simmim.py
    transforms = Compose(
        [
            Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
            RandomResizedCrop(model_args.image_size, scale=(0.67, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0)),
            RandomHorizontalFlip(),
            ToTensor(),
            Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
        ]
    )

    # create mask generator
    mask_generator = MaskGenerator(
        input_size=model_args.image_size,
        mask_patch_size=data_args.mask_patch_size,
        model_patch_size=model_args.patch_size,
        mask_ratio=data_args.mask_ratio,
    )

    def preprocess_images(examples):
        """Preprocess a batch of images by applying transforms + creating a corresponding mask, indicating
        which patches to mask."""
        images = examples[image_column_name]
        examples["pixel_values"] = [transforms(image) for image in images]
        # One independently drawn mask per image.
        examples["mask"] = [mask_generator() for _ in range(len(images))]

        return examples

    if training_args.do_train:
        if "train" not in ds:
            raise ValueError("--do_train requires a train dataset")
        if data_args.max_train_samples is not None:
            ds["train"] = ds["train"].shuffle(seed=training_args.seed).select(range(data_args.max_train_samples))
        # Set the training transforms
        ds["train"].set_transform(preprocess_images)

    if training_args.do_eval:
        if "validation" not in ds:
            raise ValueError("--do_eval requires a validation dataset")
        if data_args.max_eval_samples is not None:
            ds["validation"] = (
                ds["validation"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
            )
        # Set the validation transforms
        ds["validation"].set_transform(preprocess_images)

    # Initialize our trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=ds["train"] if training_args.do_train else None,
        eval_dataset=ds["validation"] if training_args.do_eval else None,
        processing_class=image_processor,
        data_collator=collate_fn,
    )

    # Training
    if training_args.do_train:
        # An explicitly requested checkpoint takes precedence over one auto-detected in the output dir.
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()

    # Evaluation
    if training_args.do_eval:
        metrics = trainer.evaluate()
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # Write model card and (optionally) push to hub
    kwargs = {
        "finetuned_from": model_args.model_name_or_path,
        "tasks": "masked-image-modeling",
        "dataset": data_args.dataset_name,
        "tags": ["masked-image-modeling"],
    }
    if training_args.push_to_hub:
        trainer.push_to_hub(**kwargs)
    else:
        trainer.create_model_card(**kwargs)


if __name__ == "__main__":
    main()