Introduction

A simplified PyTorch implementation of image classification that supports multi-GPU training and validation, automatic mixed precision (AMP) training, knowledge distillation, hyperparameter optimization with Optuna, and several datasets such as CIFAR10 and MNIST.

Requirements

torch == 1.8.1
torchvision
torchmetrics == 1.2.0
albumentations
loguru
tqdm
timm == 0.6.12 (optional)
optuna == 4.0.0 (optional)
optuna-integration == 4.0.0 (optional)

Supported models

This repo provides modified ResNets and MobileNetV2 for training on datasets of small-resolution images, e.g. CIFAR10 (32x32) or MNIST (28x28). You can also train on datasets of normal-size images such as ImageNet. Besides ResNets and MobileNetV2, you may also refer to timm [3], which provides hundreds of pretrained models. For example, to train mobilenetv3_small from timm, change the config file to

config.model = 'timm'
config.timm_model = 'mobilenetv3_small_100'

or use command-line arguments

python main.py --model timm --timm_model mobilenetv3_small_100

Details of the configurations can also be found in this file.

Since most timm models downsample the input by a factor of 32, you may need to reduce the downsampling rate of the timm model when training on small-resolution images, so that more spatial detail is retained and better performance is achieved.
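
As an illustration only (not code from this repo), reducing the stem stride of a timm ResNet for 32x32 inputs might look like the sketch below; the attribute names conv1 and maxpool are specific to timm's ResNet implementation:

```python
import timm
import torch
import torch.nn as nn

# Illustrative only: reduce the 32x overall stride of a timm ResNet so that
# 32x32 inputs (e.g. CIFAR10) keep more spatial detail.
model = timm.create_model('resnet18', pretrained=False, num_classes=10)
model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)  # stem stride 2 -> 1
model.maxpool = nn.Identity()                                                   # drop the stride-2 max pool

with torch.no_grad():
    out = model(torch.randn(1, 3, 32, 32))
print(out.shape)  # torch.Size([1, 10])
```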

Supported datasets

If you want to test other datasets from torchvision, refer to this site. Note that the linked documentation is outdated, since torchvision (0.9.1) is bound to torch (1.8.1). If you want to use datasets from a newer version of torchvision, you need to update this codebase to be compatible with a newer torch. Alternatively, if you don't want to update the codebase, you can download the image files and build your own dataset following the style of Custom dataset; a rough sketch is shown below.
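
The sketch below is illustrative only and may differ from the repo's own Custom dataset class in details such as transform handling; it reads images arranged as root/<class_name>/<image>:

```python
import os
from PIL import Image
from torch.utils.data import Dataset

# Illustrative folder-based dataset; not copied from this repo.
class FolderDataset(Dataset):
    def __init__(self, root, transform=None):
        self.samples = []
        self.transform = transform
        classes = sorted(os.listdir(root))
        self.class_to_idx = {c: i for i, c in enumerate(classes)}
        for c in classes:
            for name in sorted(os.listdir(os.path.join(root, c))):
                self.samples.append((os.path.join(root, c, name), self.class_to_idx[c]))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        image = Image.open(path).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, label
```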

Knowledge Distillation

Currently, only the original knowledge distillation method proposed by Geoffrey Hinton et al. [7] is supported.
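
For reference, a minimal sketch of this distillation loss is shown below; the names T (temperature) and alpha are illustrative and do not necessarily match the repo's config entries:

```python
import torch.nn.functional as F

# Sketch of the original Hinton-style distillation loss (illustrative only).
def kd_loss(student_logits, teacher_logits, targets, T=4.0, alpha=0.9):
    # Soft targets: KL divergence between temperature-softened distributions,
    # scaled by T^2 so gradient magnitudes stay comparable across temperatures.
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction='batchmean',
    ) * (T * T)
    # Hard targets: standard cross entropy with the ground-truth labels.
    hard = F.cross_entropy(student_logits, targets)
    return alpha * soft + (1.0 - alpha) * hard
```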

MixUp

This repo provides batch-wise mixup augmentation [8]. You can control the probability of applying mixup through the config.mixup parameter. If you want to perform mixup on individual images rather than whole batches, you will need to implement it yourself.
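
For reference, a minimal sketch of batch-wise mixup is shown below; it is illustrative and not copied from this repo:

```python
import torch

# Batch-wise mixup: the whole batch is mixed with a shuffled copy of itself
# using a single Beta-sampled coefficient.
def mixup_batch(images, targets, alpha=1.0):
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    index = torch.randperm(images.size(0), device=images.device)
    mixed = lam * images + (1.0 - lam) * images[index]
    return mixed, targets, targets[index], lam

# The loss is then combined accordingly:
#   loss = lam * criterion(outputs, targets_a) + (1 - lam) * criterion(outputs, targets_b)
```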

Hyperparameter Optimization

This repo also supports hyperparameter optimization using Optuna [9]. For example, to search hyperparameters for the CIFAR10 dataset with MobileNetV2, simply run

CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 optuna_search.py
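
For context, a single-process Optuna study typically looks like the sketch below; the repo's optuna_search.py wires the search into DDP training, so the search space and the run_training helper here are placeholders, not this repo's API:

```python
import optuna

def objective(trial):
    # Illustrative search space; the repo's actual searched parameters may differ.
    lr = trial.suggest_float('lr', 1e-4, 1e-1, log=True)
    weight_decay = trial.suggest_float('weight_decay', 1e-6, 1e-2, log=True)
    mixup = trial.suggest_float('mixup', 0.0, 1.0)
    # Hypothetical helper: train a model with these settings and return validation accuracy.
    return run_training(lr=lr, weight_decay=weight_decay, mixup=mixup)

study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=50)
print(study.best_params)
```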

How to use

DDP training (recommended)

CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 main.py

DP training

CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py
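
For context, torch.distributed.launch in torch 1.8.1 passes a --local_rank argument to each spawned process; a rough sketch of the DDP setup such a script performs (not necessarily identical to this repo's main.py) is:

```python
import argparse
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=0)  # injected by the launcher
args = parser.parse_args()

torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend='nccl')      # reads env vars set by torch.distributed.launch

model = torch.nn.Linear(10, 10).cuda()       # placeholder model for illustration
model = DDP(model, device_ids=[args.local_rank])
```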

Performances

CIFAR10

Model pretrained kd optuna mixup Epoch Accuracy(%)
ResNet50 1.0 200 95.99
ResNet50 1.0 400 96.62 (teacher)
ResNet18 1.0 200 95.34 (base)
ResNet18 0.0 200 94.25 ⬇️
ResNet18 0.5 200 95.08 ⬇️
ResNet18 1.0 200 95.91 ⬆️
ResNet18 1.0 400 95.95 ⬆️
ResNet18 1.0 200 95.69 ⬆️
ResNet18 1.0 400 96.03 ⬆️
ResNet18 1.0 400 96.12 ⬆️
MobileNetv2 1.0 200 94.88 (base)
MobileNetv2 1.0 200 95.21 ⬆️
MobileNetv2 1.0 400 95.37 ⬆️
MobileNetv2 1.0 200 94.83 ⬇️
MobileNetv2 1.0 400 95.29 ⬆️
MobileNetv2 1.0 400 95.12 ⬆️
MobileNetv2 - - config - 100 96.39 ⬆️

CIFAR100

Model pretrained kd optuna mixup Epoch Accuracy(%)
ResNet50 1.0 400 79.52 (teacher)
ResNet18 1.0 200 75.68 (base)
ResNet18 1.0 200 78.89 ⬆️
ResNet18 1.0 400 78.56 ⬇️
ResNet18 1.0 200 75.82 ⬆️
ResNet18 1.0 400 76.53 ⬆️
ResNet18 1.0 400 76.85 ⬆️
MobileNetv2 1.0 200 76.90 (base)
MobileNetv2 1.0 200 78.41 ⬆️
MobileNetv2 1.0 400 78.37 ⬇️
MobileNetv2 1.0 200 76.81 ⬇️
MobileNetv2 1.0 400 77.30 ⬆️
MobileNetv2 1.0 400 77.85 ⬆️
MobileNetv2 - - config - 100 82.01 ⬆️

MNIST

Model pretrained optuna h_flip mixup Epoch Accuracy(%)
ResNet18 0.5 1.0 200 99.65 (base)
ResNet18 0.0 1.0 200 99.65
ResNet18 0.0 1.0 200 99.65
ResNet18 0.5 1.0 200 99.68 ⬆️
ResNet18 0.0 1.0 400 99.67 ⬆️
ResNet18 0.5 1.0 400 99.69 ⬆️
MobileNetv2 0.5 1.0 200 99.67 (base)
MobileNetv2 0.0 1.0 200 99.64 ⬇️
MobileNetv2 0.0 1.0 200 99.68 ⬆️
MobileNetv2 0.5 1.0 200 99.62 ⬇️
MobileNetv2 0.0 1.0 400 99.64 ⬇️
MobileNetv2 0.5 1.0 400 99.65 ⬇️
MobileNetv2 - config - - 100 99.73 ⬆️

Fashion-MNIST

Model pretrained optuna h_flip mixup Epoch Accuracy(%)
ResNet18 0.5 1.0 200 94.33 (base)
ResNet18 0.0 1.0 200 94.30 ⬇️
ResNet18 0.0 1.0 200 94.59 ⬆️
ResNet18 0.5 1.0 200 94.55 ⬆️
ResNet18 0.0 1.0 400 94.20 ⬇️
ResNet18 0.5 1.0 400 94.41 ⬆️
MobileNetv2 0.5 1.0 200 94.81 (base)
MobileNetv2 0.0 1.0 200 94.96 ⬆️
MobileNetv2 0.0 1.0 200 95.28 ⬆️
MobileNetv2 0.5 1.0 200 95.20 ⬆️
MobileNetv2 0.0 1.0 400 95.05 ⬆️
MobileNetv2 0.5 1.0 400 95.21 ⬆️
MobileNetv2 - config - - 100 95.53 ⬆️

References

1. Deep Residual Learning for Image Recognition
2. MobileNetV2: Inverted Residuals and Linear Bottlenecks
3. PyTorch Image Models
4. The CIFAR-10 dataset
5. The MNIST database of handwritten digits
6. Fashion-MNIST
7. Distilling the Knowledge in a Neural Network
8. mixup: Beyond Empirical Risk Minimization
9. Optuna: A hyperparameter optimization framework